pull/1359/head
SengokuCola 2025-11-13 19:00:59 +08:00
commit d306e40db0
46 changed files with 1000 additions and 1041 deletions

2
bot.py
View File

@ -30,7 +30,7 @@ else:
raise raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
initialize_logging() initialize_logging()

View File

@ -1,15 +1,11 @@
from typing import List, Tuple, Type, Optional from typing import List, Tuple, Type, Optional
from src.plugin_system import ( from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField
)
from src.plugin_system.apis import send_api, frequency_api from src.plugin_system.apis import send_api, frequency_api
class SetTalkFrequencyCommand(BaseCommand): class SetTalkFrequencyCommand(BaseCommand):
"""设置当前聊天的talk_frequency值""" """设置当前聊天的talk_frequency值"""
command_name = "set_talk_frequency" command_name = "set_talk_frequency"
command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>" command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$" command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
@ -19,35 +15,35 @@ class SetTalkFrequencyCommand(BaseCommand):
# 获取命令参数 - 使用命名捕获组 # 获取命令参数 - 使用命名捕获组
if not self.matched_groups or "value" not in self.matched_groups: if not self.matched_groups or "value" not in self.matched_groups:
return False, "命令格式错误", False return False, "命令格式错误", False
value_str = self.matched_groups["value"] value_str = self.matched_groups["value"]
if not value_str: if not value_str:
return False, "无法获取数值参数", False return False, "无法获取数值参数", False
value = float(value_str) value = float(value_str)
# 获取聊天流ID # 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"): if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
return False, "无法获取聊天流信息", False return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id chat_id = self.message.chat_stream.stream_id
# 设置talk_frequency # 设置talk_frequency
frequency_api.set_talk_frequency_adjust(chat_id, value) frequency_api.set_talk_frequency_adjust(chat_id, value)
final_value = frequency_api.get_current_talk_value(chat_id) final_value = frequency_api.get_current_talk_value(chat_id)
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id) adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = final_value / adjust_value base_value = final_value / adjust_value
# 发送反馈消息(不保存到数据库) # 发送反馈消息(不保存到数据库)
await send_api.text_to_stream( await send_api.text_to_stream(
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}", f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
chat_id, chat_id,
storage_message=False storage_message=False,
) )
return True, None, False return True, None, False
except ValueError: except ValueError:
error_msg = "数值格式错误,请输入有效的数字" error_msg = "数值格式错误,请输入有效的数字"
await self.send_text(error_msg, storage_message=False) await self.send_text(error_msg, storage_message=False)
@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand):
class ShowFrequencyCommand(BaseCommand): class ShowFrequencyCommand(BaseCommand):
"""显示当前聊天的频率控制状态""" """显示当前聊天的频率控制状态"""
command_name = "show_frequency" command_name = "show_frequency"
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s" command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
command_pattern = r"^/chat\s+(?:show|s)$" command_pattern = r"^/chat\s+(?:show|s)$"
@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin):
config_file_name: str = "config.toml" config_file_name: str = "config.toml"
# 配置节描述 # 配置节描述
config_section_descriptions = { config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
"plugin": "插件基本信息",
"frequency": "频率控制配置",
"features": "功能开关配置"
}
# 配置Schema定义 # 配置Schema定义
config_schema: dict = { config_schema: dict = {
@ -138,13 +131,14 @@ class BetterFrequencyPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
components = [] components = []
# 根据配置决定是否注册命令组件 # 根据配置决定是否注册命令组件
if self.config.get("features", {}).get("enable_commands", True): if self.config.get("features", {}).get("enable_commands", True):
components.extend([ components.extend(
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand), [
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
]) (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
]
)
return components return components

View File

@ -6,15 +6,16 @@ import sys
import os import os
from datetime import datetime from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
# 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级) # 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path: if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT) sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60 SECONDS_5_MINUTES = 5 * 60
@ -28,16 +29,16 @@ def clean_output_text(text: str) -> str:
""" """
if not text: if not text:
return text return text
# 移除表情包内容:[表情包:...] # 移除表情包内容:[表情包:...]
text = re.sub(r'\[表情包:[^\]]*\]', '', text) text = re.sub(r"\[表情包:[^\]]*\]", "", text)
# 移除回复内容:[回复...],说:... 的完整模式 # 移除回复内容:[回复...],说:... 的完整模式
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text) text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
# 清理多余的空格和换行 # 清理多余的空格和换行
text = re.sub(r'\s+', ' ', text).strip() text = re.sub(r"\s+", " ", text).strip()
return text return text
@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
for msg in messages: for msg in messages:
groups.setdefault(msg.chat_id, []).append(msg) groups.setdefault(msg.chat_id, []).append(msg)
# 保证每个分组内按时间升序 # 保证每个分组内按时间升序
for chat_id, msgs in groups.items(): for _chat_id, msgs in groups.items():
msgs.sort(key=lambda m: m.time or 0) msgs.sort(key=lambda m: m.time or 0)
return groups return groups
@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
continue continue
last = bucket[-1] last = bucket[-1]
same_user = (msg.user_info.user_id == last.user_info.user_id) same_user = msg.user_info.user_id == last.user_info.user_id
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES) close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
if same_user and close_enough: if same_user and close_enough:
bucket.append(msg) bucket.append(msg)
@ -199,38 +200,36 @@ def build_pairs_for_chat(
pairs: List[Tuple[str, str, str]] = [] pairs: List[Tuple[str, str, str]] = []
n_merged = len(merged_messages) n_merged = len(merged_messages)
n_original = len(original_messages) n_original = len(original_messages)
if n_merged == 0 or n_original == 0: if n_merged == 0 or n_original == 0:
return pairs return pairs
# 为每个合并后的消息找到对应的原始消息位置 # 为每个合并后的消息找到对应的原始消息位置
merged_to_original_map = {} merged_to_original_map = {}
original_idx = 0 original_idx = 0
for merged_idx, merged_msg in enumerate(merged_messages): for merged_idx, merged_msg in enumerate(merged_messages):
# 找到这个合并消息对应的第一个原始消息 # 找到这个合并消息对应的第一个原始消息
while (original_idx < n_original and while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
original_messages[original_idx].time < merged_msg.time):
original_idx += 1 original_idx += 1
# 如果找到了时间匹配的原始消息,建立映射 # 如果找到了时间匹配的原始消息,建立映射
if (original_idx < n_original and if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
original_messages[original_idx].time == merged_msg.time):
merged_to_original_map[merged_idx] = original_idx merged_to_original_map[merged_idx] = original_idx
for merged_idx in range(n_merged): for merged_idx in range(n_merged):
merged_msg = merged_messages[merged_idx] merged_msg = merged_messages[merged_idx]
# 如果指定了 target_user_id只处理该用户的消息作为 output # 如果指定了 target_user_id只处理该用户的消息作为 output
if target_user_id and merged_msg.user_info.user_id != target_user_id: if target_user_id and merged_msg.user_info.user_id != target_user_id:
continue continue
# 找到对应的原始消息位置 # 找到对应的原始消息位置
if merged_idx not in merged_to_original_map: if merged_idx not in merged_to_original_map:
continue continue
original_idx = merged_to_original_map[merged_idx] original_idx = merged_to_original_map[merged_idx]
# 选择上下文窗口大小 # 选择上下文窗口大小
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
start = max(0, original_idx - window) start = max(0, original_idx - window)
@ -266,7 +265,7 @@ def build_pairs(
groups = group_by_chat(messages) groups = group_by_chat(messages)
all_pairs: List[Tuple[str, str, str]] = [] all_pairs: List[Tuple[str, str, str]] = []
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用 for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
# 对消息进行合并用于output # 对消息进行合并用于output
merged = merge_adjacent_same_user(msgs) merged = merge_adjacent_same_user(msgs)
# 传递原始消息和合并后消息input使用原始消息output使用合并后消息 # 传递原始消息和合并后消息input使用原始消息output使用合并后消息
@ -385,5 +384,3 @@ def run_interactive() -> int:
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -1,4 +1,3 @@
import time
import sys import sys
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -6,16 +5,17 @@ import matplotlib.dates as mdates
from datetime import datetime from datetime import datetime
from typing import List, Tuple from typing import List, Tuple
import numpy as np import numpy as np
from src.common.database.database_model import Expression, ChatStreams
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams
# 设置中文字体 # 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'] plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams['axes.unicode_minus'] = False plt.rcParams["axes.unicode_minus"] = False
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
@ -39,19 +39,14 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]:
"""获取Expression表中的数据返回(create_date, count, chat_id, expression_type)的列表""" """获取Expression表中的数据返回(create_date, count, chat_id, expression_type)的列表"""
expressions = Expression.select() expressions = Expression.select()
data = [] data = []
for expr in expressions: for expr in expressions:
# 如果create_date为空跳过该记录 # 如果create_date为空跳过该记录
if expr.create_date is None: if expr.create_date is None:
continue continue
data.append(( data.append((expr.create_date, expr.count, expr.chat_id, expr.type))
expr.create_date,
expr.count,
expr.chat_id,
expr.type
))
return data return data
@ -60,71 +55,71 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
if not data: if not data:
print("没有找到有效的表达式数据") print("没有找到有效的表达式数据")
return return
# 分离数据 # 分离数据
create_dates = [item[0] for item in data] create_dates = [item[0] for item in data]
counts = [item[1] for item in data] counts = [item[1] for item in data]
chat_ids = [item[2] for item in data] _chat_ids = [item[2] for item in data]
expression_types = [item[3] for item in data] _expression_types = [item[3] for item in data]
# 转换时间戳为datetime对象 # 转换时间戳为datetime对象
dates = [datetime.fromtimestamp(ts) for ts in create_dates] dates = [datetime.fromtimestamp(ts) for ts in create_dates]
# 计算时间跨度,自动调整显示格式 # 计算时间跨度,自动调整显示格式
time_span = max(dates) - min(dates) time_span = max(dates) - min(dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
# 创建图形 # 创建图形
fig, ax = plt.subplots(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
# 创建散点图 # 创建散点图
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis') scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold') ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator) ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45) plt.xticks(rotation=45)
# 添加网格 # 添加网格
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
# 添加颜色条 # 添加颜色条
cbar = plt.colorbar(scatter) cbar = plt.colorbar(scatter)
cbar.set_label('数据点顺序', fontsize=10) cbar.set_label("数据点顺序", fontsize=10)
# 调整布局 # 调整布局
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 数据统计 ===") print("\n=== 数据统计 ===")
print(f"总数据点数量: {len(data)}") print(f"总数据点数量: {len(data)}")
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}") print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}") print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}")
print(f"平均使用次数: {np.mean(counts):.2f}") print(f"平均使用次数: {np.mean(counts):.2f}")
print(f"中位数使用次数: {np.median(counts):.2f}") print(f"中位数使用次数: {np.median(counts):.2f}")
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n散点图已保存到: {save_path}") print(f"\n散点图已保存到: {save_path}")
# 显示图片 # 显示图片
plt.show() plt.show()
@ -134,7 +129,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
if not data: if not data:
print("没有找到有效的表达式数据") print("没有找到有效的表达式数据")
return return
# 按chat_id分组 # 按chat_id分组
chat_groups = {} chat_groups = {}
for item in data: for item in data:
@ -142,75 +137,82 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
if chat_id not in chat_groups: if chat_id not in chat_groups:
chat_groups[chat_id] = [] chat_groups[chat_id] = []
chat_groups[chat_id].append(item) chat_groups[chat_id].append(item)
# 计算时间跨度,自动调整显示格式 # 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data] all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates) time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
# 创建图形 # 创建图形
fig, ax = plt.subplots(figsize=(14, 10)) fig, ax = plt.subplots(figsize=(14, 10))
# 为每个聊天分配不同颜色 # 为每个聊天分配不同颜色
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups))) colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
for i, (chat_id, chat_data) in enumerate(chat_groups.items()): for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
create_dates = [item[0] for item in chat_data] create_dates = [item[0] for item in chat_data]
counts = [item[1] for item in chat_data] counts = [item[1] for item in chat_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates] dates = [datetime.fromtimestamp(ts) for ts in create_dates]
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
# 截断过长的聊天名称 # 截断过长的聊天名称
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
ax.scatter(dates, counts, alpha=0.7, s=40, ax.scatter(
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)", dates,
edgecolors='black', linewidth=0.5) counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{display_name} ({len(chat_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold') ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator) ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45) plt.xticks(rotation=45)
# 添加图例 # 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8) ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
# 添加网格 # 添加网格
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
# 调整布局 # 调整布局
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 分组统计 ===") print("\n=== 分组统计 ===")
print(f"总聊天数量: {len(chat_groups)}") print(f"总聊天数量: {len(chat_groups)}")
for chat_id, chat_data in chat_groups.items(): for chat_id, chat_data in chat_groups.items():
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
counts = [item[1] for item in chat_data] counts = [item[1] for item in chat_data]
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}") print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n分组散点图已保存到: {save_path}") print(f"\n分组散点图已保存到: {save_path}")
# 显示图片 # 显示图片
plt.show() plt.show()
@ -220,7 +222,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
if not data: if not data:
print("没有找到有效的表达式数据") print("没有找到有效的表达式数据")
return return
# 按type分组 # 按type分组
type_groups = {} type_groups = {}
for item in data: for item in data:
@ -228,69 +230,76 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
if expr_type not in type_groups: if expr_type not in type_groups:
type_groups[expr_type] = [] type_groups[expr_type] = []
type_groups[expr_type].append(item) type_groups[expr_type].append(item)
# 计算时间跨度,自动调整显示格式 # 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data] all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates) time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
# 创建图形 # 创建图形
fig, ax = plt.subplots(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
# 为每个类型分配不同颜色 # 为每个类型分配不同颜色
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups))) colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
for i, (expr_type, type_data) in enumerate(type_groups.items()): for i, (expr_type, type_data) in enumerate(type_groups.items()):
create_dates = [item[0] for item in type_data] create_dates = [item[0] for item in type_data]
counts = [item[1] for item in type_data] counts = [item[1] for item in type_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates] dates = [datetime.fromtimestamp(ts) for ts in create_dates]
ax.scatter(dates, counts, alpha=0.7, s=40, ax.scatter(
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)", dates,
edgecolors='black', linewidth=0.5) counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{expr_type} ({len(type_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold') ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator) ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator) ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45) plt.xticks(rotation=45)
# 添加图例 # 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
# 添加网格 # 添加网格
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
# 调整布局 # 调整布局
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 类型统计 ===") print("\n=== 类型统计 ===")
for expr_type, type_data in type_groups.items(): for expr_type, type_data in type_groups.items():
counts = [item[1] for item in type_data] counts = [item[1] for item in type_data]
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}") print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n类型散点图已保存到: {save_path}") print(f"\n类型散点图已保存到: {save_path}")
# 显示图片 # 显示图片
plt.show() plt.show()
@ -298,35 +307,35 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
def main(): def main():
"""主函数""" """主函数"""
print("开始分析表达式数据...") print("开始分析表达式数据...")
# 获取数据 # 获取数据
data = get_expression_data() data = get_expression_data()
if not data: if not data:
print("没有找到有效的表达式数据create_date不为空的数据") print("没有找到有效的表达式数据create_date不为空的数据")
return return
print(f"找到 {len(data)} 条有效数据") print(f"找到 {len(data)} 条有效数据")
# 创建输出目录 # 创建输出目录
output_dir = os.path.join(project_root, "data", "temp") output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# 生成时间戳用于文件名 # 生成时间戳用于文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# 1. 创建基础散点图 # 1. 创建基础散点图
print("\n1. 创建基础散点图...") print("\n1. 创建基础散点图...")
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png")) create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
# 2. 创建按聊天分组的散点图 # 2. 创建按聊天分组的散点图
print("\n2. 创建按聊天分组的散点图...") print("\n2. 创建按聊天分组的散点图...")
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png")) create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
# 3. 创建按类型分组的散点图 # 3. 创建按类型分组的散点图
print("\n3. 创建按类型分组的散点图...") print("\n3. 创建按类型分组的散点图...")
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png")) create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
print("\n分析完成!") print("\n分析完成!")

View File

@ -945,9 +945,7 @@ class EmojiManager:
prompt, image_base64, "jpg", temperature=0.5 prompt, image_base64, "jpg", temperature=0.5
) )
else: else:
prompt = ( prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
)
description, _ = await self.vlm.generate_response_for_image( description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.5 prompt, image_base64, image_format, temperature=0.5
) )

View File

@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import frequency_api from src.plugin_system.apis import frequency_api
def init_prompt(): def init_prompt():
Prompt( Prompt(
"""{name_block} """{name_block}
@ -28,7 +29,7 @@ def init_prompt():
""", """,
"frequency_adjust_prompt", "frequency_adjust_prompt",
) )
logger = get_logger("frequency_control") logger = get_logger("frequency_control")
@ -40,7 +41,7 @@ class FrequencyControl:
self.chat_id = chat_id self.chat_id = chat_id
# 发言频率调整值 # 发言频率调整值
self.talk_frequency_adjust: float = 1.0 self.talk_frequency_adjust: float = 1.0
self.last_frequency_adjust_time: float = 0.0 self.last_frequency_adjust_time: float = 0.0
self.frequency_model = LLMRequest( self.frequency_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust" model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
@ -53,16 +54,14 @@ class FrequencyControl:
def set_talk_frequency_adjust(self, value: float) -> None: def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值""" """设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value)) self.talk_frequency_adjust = max(0.1, min(5.0, value))
async def trigger_frequency_adjust(self) -> None: async def trigger_frequency_adjust(self) -> None:
msg_list = get_raw_msg_by_timestamp_with_chat( msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_frequency_adjust_time, timestamp_start=self.last_frequency_adjust_time,
timestamp_end=time.time(), timestamp_end=time.time(),
) )
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20: if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
return return
else: else:
@ -73,7 +72,7 @@ class FrequencyControl:
limit=20, limit=20,
limit_mode="latest", limit_mode="latest",
) )
message_str = build_readable_messages( message_str = build_readable_messages(
new_msg_list, new_msg_list,
replace_bot_name=True, replace_bot_name=True,
@ -97,15 +96,15 @@ class FrequencyControl:
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async( response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
prompt, prompt,
) )
# logger.info(f"频率调整 prompt: {prompt}") # logger.info(f"频率调整 prompt: {prompt}")
# logger.info(f"频率调整 response: {response}") # logger.info(f"频率调整 response: {response}")
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"频率调整 prompt: {prompt}") logger.info(f"频率调整 prompt: {prompt}")
logger.info(f"频率调整 response: {response}") logger.info(f"频率调整 response: {response}")
logger.info(f"频率调整 reasoning_content: {reasoning_content}") logger.info(f"频率调整 reasoning_content: {reasoning_content}")
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id) final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
# LLM依然输出过多内容时取消本次调整。合法最多4个字但有的模型可能会输出一些markdown换行符等需要长度宽限 # LLM依然输出过多内容时取消本次调整。合法最多4个字但有的模型可能会输出一些markdown换行符等需要长度宽限
@ -118,7 +117,8 @@ class FrequencyControl:
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2)) self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
self.last_frequency_adjust_time = time.time() self.last_frequency_adjust_time = time.time()
else: else:
logger.info(f"频率调整response不符合要求取消本次调整") logger.info("频率调整response不符合要求取消本次调整")
class FrequencyControlManager: class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例""" """频率控制管理器,管理多个聊天流的频率控制实例"""
@ -143,6 +143,7 @@ class FrequencyControlManager:
"""获取所有有频率控制的聊天ID""" """获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys()) return list(self.frequency_control_dict.keys())
init_prompt() init_prompt()
# 创建全局实例 # 创建全局实例

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
from multiprocessing import context
import time import time
import traceback import traceback
import random import random
@ -102,14 +101,14 @@ class HeartFChatting:
self.is_mute = False self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间 self.last_active_time = time.time() # 记录上一次非noreply时间
self.question_probability_multiplier = 1 self.question_probability_multiplier = 1
self.questioned = False self.questioned = False
# 跟踪连续 no_reply 次数,用于动态调整阈值 # 跟踪连续 no_reply 次数,用于动态调整阈值
self.consecutive_no_reply_count = 0 self.consecutive_no_reply_count = 0
# 聊天内容概括器 # 聊天内容概括器
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id) self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
@ -127,10 +126,10 @@ class HeartFChatting:
self._loop_task = asyncio.create_task(self._main_chat_loop()) self._loop_task = asyncio.create_task(self._main_chat_loop())
self._loop_task.add_done_callback(self._handle_loop_completion) self._loop_task.add_done_callback(self._handle_loop_completion)
# 启动聊天内容概括器的后台定期检查循环 # 启动聊天内容概括器的后台定期检查循环
await self.chat_history_summarizer.start() await self.chat_history_summarizer.start()
logger.info(f"{self.log_prefix} HeartFChatting 启动完成") logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
except Exception as e: except Exception as e:
@ -180,7 +179,7 @@ class HeartFChatting:
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "") + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
) )
async def _loopbody(self): async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat( recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id, chat_id=self.stream_id,
start_time=self.last_read_time, start_time=self.last_read_time,
@ -191,9 +190,6 @@ class HeartFChatting:
filter_command=True, filter_command=True,
) )
# 根据连续 no_reply 次数动态调整阈值 # 根据连续 no_reply 次数动态调整阈值
# 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2 # 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2
# 5次 no_reply 时,提高到 2大于等于两条消息的阈值 # 5次 no_reply 时,提高到 2大于等于两条消息的阈值
@ -204,10 +200,10 @@ class HeartFChatting:
threshold = 2 if random.random() < 0.5 else 1 threshold = 2 if random.random() < 0.5 else 1
else: else:
threshold = 1 threshold = 1
if len(recent_messages_list) >= threshold: if len(recent_messages_list) >= threshold:
# for message in recent_messages_list: # for message in recent_messages_list:
# print(message.processed_plain_text) # print(message.processed_plain_text)
# !处理no_reply_until_call逻辑 # !处理no_reply_until_call逻辑
if self.no_reply_until_call: if self.no_reply_until_call:
for message in recent_messages_list: for message in recent_messages_list:
@ -337,7 +333,7 @@ class HeartFChatting:
# 重置连续 no_reply 计数 # 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0 self.consecutive_no_reply_count = 0
reason = "有人提到了你,进行回复" reason = "有人提到了你,进行回复"
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
action_build_into_prompt=False, action_build_into_prompt=False,
@ -395,15 +391,16 @@ class HeartFChatting:
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None: if recent_messages_list is None:
recent_messages_list = [] recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError _reply_text = "" # 初始化reply_text变量避免UnboundLocalError
start_time = time.time() start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()) asyncio.create_task(
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
)
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
# asyncio.create_task(check_and_make_question(self.stream_id)) # asyncio.create_task(check_and_make_question(self.stream_id))
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却) # 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
@ -411,8 +408,7 @@ class HeartFChatting:
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录 # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
# 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 # 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
# asyncio.create_task(self.chat_history_summarizer.process()) # asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle() cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@ -427,7 +423,7 @@ class HeartFChatting:
# 如果被提及让回复生成和planner并行执行 # 如果被提及让回复生成和planner并行执行
if force_reply_message: if force_reply_message:
logger.info(f"{self.log_prefix} 检测到提及回复生成与planner并行执行") logger.info(f"{self.log_prefix} 检测到提及回复生成与planner并行执行")
# 并行执行planner和回复生成 # 并行执行planner和回复生成
planner_task = asyncio.create_task( planner_task = asyncio.create_task(
self._run_planner_without_reply( self._run_planner_without_reply(
@ -457,7 +453,12 @@ class HeartFChatting:
# 处理回复结果 # 处理回复结果
if isinstance(reply_result, BaseException): if isinstance(reply_result, BaseException):
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}") logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
reply_result = {"action_type": "reply", "success": False, "result": "回复生成异常", "loop_info": None} reply_result = {
"action_type": "reply",
"success": False,
"result": "回复生成异常",
"loop_info": None,
}
else: else:
# 正常流程只执行planner # 正常流程只执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
@ -516,7 +517,7 @@ class HeartFChatting:
# 并行执行所有任务 # 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True) results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 如果有独立的回复结果,添加到结果列表中 # 如果有独立的回复结果,添加到结果列表中
if reply_result: if reply_result:
results = list(results) + [reply_result] results = list(results) + [reply_result]
@ -558,7 +559,7 @@ class HeartFChatting:
"taken_time": time.time(), "taken_time": time.time(),
} }
) )
reply_text = reply_text_from_reply _reply_text = reply_text_from_reply
else: else:
# 没有回复信息构建纯动作的loop_info # 没有回复信息构建纯动作的loop_info
loop_info = { loop_info = {
@ -571,7 +572,7 @@ class HeartFChatting:
"taken_time": time.time(), "taken_time": time.time(),
}, },
} }
reply_text = action_reply_text _reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers) self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers) self.print_cycle_info(cycle_timers)
@ -647,7 +648,6 @@ class HeartFChatting:
result = await action_handler.execute() result = await action_handler.execute()
success, action_text = result success, action_text = result
return success, action_text return success, action_text
except Exception as e: except Exception as e:
@ -655,8 +655,6 @@ class HeartFChatting:
traceback.print_exc() traceback.print_exc()
return False, "" return False, ""
async def _send_response( async def _send_response(
self, self,
reply_set: "ReplySetModel", reply_set: "ReplySetModel",
@ -732,7 +730,6 @@ class HeartFChatting:
action_reasoning=reason, action_reasoning=reason,
) )
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call": elif action_planner_info.action_type == "no_reply_until_call":
@ -753,7 +750,12 @@ class HeartFChatting:
action_name="no_reply_until_call", action_name="no_reply_until_call",
action_reasoning=reason, action_reasoning=reason,
) )
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""} return {
"action_type": "no_reply_until_call",
"success": True,
"result": "保持沉默,直到有人直接叫的名字",
"command": "",
}
elif action_planner_info.action_type == "reply": elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑 # 直接当场执行reply逻辑
@ -783,19 +785,16 @@ class HeartFChatting:
enable_tool=global_config.tool.enable_tool, enable_tool=global_config.tool.enable_tool,
request_type="replyer", request_type="replyer",
from_plugin=False, from_plugin=False,
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()), reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
) )
if not success or not llm_response or not llm_response.reply_set: if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message: if action_planner_info.action_message:
logger.info( logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else: else:
logger.info("回复生成失败") logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
response_set = llm_response.reply_set response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply( loop_info, reply_text, _ = await self._send_and_store_reply(
@ -817,12 +816,12 @@ class HeartFChatting:
# 执行普通动作 # 执行普通动作
with Timer("动作执行", cycle_timers): with Timer("动作执行", cycle_timers):
success, result = await self._handle_action( success, result = await self._handle_action(
action = action_planner_info.action_type, action=action_planner_info.action_type,
action_reasoning = action_planner_info.action_reasoning or "", action_reasoning=action_planner_info.action_reasoning or "",
action_data = action_planner_info.action_data or {}, action_data=action_planner_info.action_data or {},
cycle_timers = cycle_timers, cycle_timers=cycle_timers,
thinking_id = thinking_id, thinking_id=thinking_id,
action_message= action_planner_info.action_message, action_message=action_planner_info.action_message,
) )
self.last_active_time = time.time() self.last_active_time = time.time()

View File

@ -13,10 +13,11 @@ from src.person_info.person_info import Person
from src.common.database.database_model import Images from src.common.database.database_model import Images
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.heart_flow.heartFC_chat import HeartFChatting pass
logger = get_logger("chat") logger = get_logger("chat")
class HeartFCMessageReceiver: class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度""" """心流处理器,负责处理接收到的消息并计算兴趣度"""

View File

@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.person_info.person_info import Person
# 定义日志配置 # 定义日志配置
@ -171,7 +170,11 @@ class ChatBot:
# 撤回事件打印;无法获取被撤回者则省略 # 撤回事件打印;无法获取被撤回者则省略
if sub_type == "recall": if sub_type == "recall":
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None)) op_name = (
getattr(op, "user_cardname", None)
or getattr(op, "user_nickname", None)
or str(getattr(op, "user_id", None))
)
recalled_name = None recalled_name = None
try: try:
if isinstance(recalled, dict): if isinstance(recalled, dict):
@ -189,7 +192,7 @@ class ChatBot:
logger.info(f"{op_name} 撤回了消息") logger.info(f"{op_name} 撤回了消息")
else: else:
logger.debug( logger.debug(
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) " f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) "
f"gid={gid} msg_id={msg_id} recalled={recalled_id}" f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
) )
except Exception: except Exception:
@ -234,7 +237,6 @@ class ChatBot:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
if message_data["message_info"].get("group_info") is not None: if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str( message_data["message_info"]["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"] message_data["message_info"]["group_info"]["group_id"]

View File

@ -143,7 +143,6 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
def find_message_by_id( def find_message_by_id(
@ -306,7 +305,9 @@ class ActionPlanner:
loop_start_time=loop_start_time, loop_start_time=loop_start_time,
) )
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") logger.info(
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
self.add_plan_log(reasoning, actions) self.add_plan_log(reasoning, actions)
@ -316,7 +317,7 @@ class ActionPlanner:
self.plan_log.append((reasoning, time.time(), actions)) self.plan_log.append((reasoning, time.time(), actions))
if len(self.plan_log) > 20: if len(self.plan_log) > 20:
self.plan_log.pop(0) self.plan_log.pop(0)
def add_plan_excute_log(self, result: str): def add_plan_excute_log(self, result: str):
self.plan_log.append(("", time.time(), result)) self.plan_log.append(("", time.time(), result))
if len(self.plan_log) > 20: if len(self.plan_log) > 20:
@ -325,17 +326,17 @@ class ActionPlanner:
def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str: def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str:
""" """
获取计划日志字符串 获取计划日志字符串
Args: Args:
max_action_records: 显示多少条最新的action记录默认2 max_action_records: 显示多少条最新的action记录默认2
max_execution_records: 显示多少条最新执行结果记录默认8 max_execution_records: 显示多少条最新执行结果记录默认8
Returns: Returns:
格式化的日志字符串 格式化的日志字符串
""" """
action_records = [] action_records = []
execution_records = [] execution_records = []
# 从后往前遍历,收集最新的记录 # 从后往前遍历,收集最新的记录
for reasoning, timestamp, content in reversed(self.plan_log): for reasoning, timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
@ -346,13 +347,13 @@ class ActionPlanner:
# 这是执行结果记录 # 这是执行结果记录
if len(execution_records) < max_execution_records: if len(execution_records) < max_execution_records:
execution_records.append((reasoning, timestamp, content, "execution")) execution_records.append((reasoning, timestamp, content, "execution"))
# 合并所有记录并按时间戳排序 # 合并所有记录并按时间戳排序
all_records = action_records + execution_records all_records = action_records + execution_records
all_records.sort(key=lambda x: x[1]) # 按时间戳排序 all_records.sort(key=lambda x: x[1]) # 按时间戳排序
plan_log_str = "" plan_log_str = ""
# 按时间顺序添加所有记录 # 按时间顺序添加所有记录
for reasoning, timestamp, content, record_type in all_records: for reasoning, timestamp, content, record_type in all_records:
time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S") time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
@ -361,21 +362,21 @@ class ActionPlanner:
plan_log_str += f"{time_str}:{reasoning}\n" plan_log_str += f"{time_str}:{reasoning}\n"
else: else:
plan_log_str += f"{time_str}:你执行了action:{content}\n" plan_log_str += f"{time_str}:你执行了action:{content}\n"
return plan_log_str return plan_log_str
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool: def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
""" """
检查是否有连续min_count次以上的no_reply 检查是否有连续min_count次以上的no_reply
Args: Args:
min_count: 需要连续的最少次数默认3 min_count: 需要连续的最少次数默认3
Returns: Returns:
如果有连续min_count次以上no_reply返回True否则返回False 如果有连续min_count次以上no_reply返回True否则返回False
""" """
consecutive_count = 0 consecutive_count = 0
# 从后往前遍历plan_log检查最新的连续记录 # 从后往前遍历plan_log检查最新的连续记录
for _reasoning, _timestamp, content in reversed(self.plan_log): for _reasoning, _timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
@ -387,7 +388,7 @@ class ActionPlanner:
else: else:
# 如果遇到非no_reply的action重置计数 # 如果遇到非no_reply的action重置计数
break break
return False return False
async def build_planner_prompt( async def build_planner_prompt(
@ -402,8 +403,7 @@ class ActionPlanner:
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]: ) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)""" """构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try: try:
actions_before_now_block = self.get_plan_log_str()
actions_before_now_block=self.get_plan_log_str()
# 构建聊天上下文描述 # 构建聊天上下文描述
chat_context_description = "你现在正在一个群聊中" chat_context_description = "你现在正在一个群聊中"
@ -537,7 +537,7 @@ class ActionPlanner:
for require_item in action_info.action_require: for require_item in action_info.action_require:
require_text += f"- {require_item}\n" require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n") require_text = require_text.rstrip("\n")
if not action_info.parallel_action: if not action_info.parallel_action:
parallel_text = "(当选择这个动作时,请不要选择其他动作)" parallel_text = "(当选择这个动作时,请不要选择其他动作)"
else: else:
@ -564,7 +564,7 @@ class ActionPlanner:
filtered_actions: Dict[str, ActionInfo], filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo], available_actions: Dict[str, ActionInfo],
loop_start_time: float, loop_start_time: float,
) -> Tuple[str,List[ActionPlannerInfo]]: ) -> Tuple[str, List[ActionPlannerInfo]]:
"""执行主规划器""" """执行主规划器"""
llm_content = None llm_content = None
actions: List[ActionPlannerInfo] = [] actions: List[ActionPlannerInfo] = []
@ -589,7 +589,7 @@ class ActionPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}",[ return f"LLM 请求失败,模型出现问题: {req_e}", [
ActionPlannerInfo( ActionPlannerInfo(
action_type="no_reply", action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}", reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
@ -608,7 +608,11 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items()) filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects: for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning)) actions.extend(
self._parse_single_action(
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
)
)
else: else:
# 尝试解析为直接的JSON # 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}") logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
@ -631,7 +635,7 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning,actions return extracted_reasoning, actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply""" """创建no_reply"""
@ -674,7 +678,7 @@ class ActionPlanner:
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释 json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip(): if json_str := json_str.strip():
# 尝试按行分割每行可能是一个JSON对象 # 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split('\n') if line.strip()] lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines: for line in lines:
try: try:
# 尝试解析每一行作为独立的JSON对象 # 尝试解析每一行作为独立的JSON对象
@ -688,7 +692,7 @@ class ActionPlanner:
except json.JSONDecodeError: except json.JSONDecodeError:
# 如果单行解析失败尝试将整个块作为一个JSON对象或数组 # 如果单行解析失败尝试将整个块作为一个JSON对象或数组
pass pass
# 如果按行解析没有成功尝试将整个块作为一个JSON对象或数组 # 如果按行解析没有成功尝试将整个块作为一个JSON对象或数组
if not json_objects: if not json_objects:
json_obj = json.loads(repair_json(json_str)) json_obj = json.loads(repair_json(json_str))

View File

@ -134,12 +134,12 @@ class DefaultReplyer:
try: try:
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt) content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
# logger.debug(f"replyer生成内容: {content}") # logger.debug(f"replyer生成内容: {content}")
logger.info(f"replyer生成内容: {content}") logger.info(f"replyer生成内容: {content}")
if global_config.debug.show_replyer_reasoning: if global_config.debug.show_replyer_reasoning:
logger.info(f"replyer生成推理:\n{reasoning_content}") logger.info(f"replyer生成推理:\n{reasoning_content}")
logger.info(f"replyer生成模型: {model_name}") logger.info(f"replyer生成模型: {model_name}")
llm_response.content = content llm_response.content = content
llm_response.reasoning = reasoning_content llm_response.reasoning = reasoning_content
llm_response.model = model_name llm_response.model = model_name
@ -268,14 +268,13 @@ class DefaultReplyer:
expression_habits_block += f"{style_habits_str}\n" expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_mood_state_prompt(self) -> str: async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示""" """构建情绪状态提示"""
if not global_config.mood.enable_mood: if not global_config.mood.enable_mood:
return "" return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}" return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@ -303,7 +302,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") _result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}】: {content}\n" tool_info_str += f"- 【{tool_name}】: {content}\n"
@ -343,45 +342,45 @@ class DefaultReplyer:
def _replace_picids_with_descriptions(self, text: str) -> str: def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述 """将文本中的[picid:xxx]替换为具体的图片描述
Args: Args:
text: 包含picid标记的文本 text: 包含picid标记的文本
Returns: Returns:
替换后的文本 替换后的文本
""" """
# 匹配 [picid:xxxxx] 格式 # 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]" pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str: def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1) pic_id = match.group(1)
description = translate_pid_to_description(pic_id) description = translate_pid_to_description(pic_id)
return f"[图片:{description}]" return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text) return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]: def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式 """分析target内容类型基于原始picid格式
Args: Args:
target: 目标消息内容包含[picid:xxx]格式 target: 目标消息内容包含[picid:xxx]格式
Returns: Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分) Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
""" """
if not target or not target.strip(): if not target or not target.strip():
return False, False, "", "" return False, False, "", ""
# 检查是否只包含picid标记 # 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]" picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target) picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容 # 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip() text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids) has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式) # 提取图片部分(转换为[图片:描述]格式)
pic_part = "" pic_part = ""
if picid_matches: if picid_matches:
@ -396,7 +395,7 @@ class DefaultReplyer:
else: else:
pic_descriptions.append(f"[图片:{description}]") pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions) pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
@ -481,7 +480,7 @@ class DefaultReplyer:
) )
return all_dialogue_prompt return all_dialogue_prompt
def core_background_build_chat_history_prompts( def core_background_build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
@ -603,25 +602,27 @@ class DefaultReplyer:
# 获取基础personality # 获取基础personality
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if (global_config.personality.states and if (
global_config.personality.state_probability > 0 and global_config.personality.states
random.random() < global_config.personality.state_probability): and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality # 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states) selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state prompt_personality = selected_state
prompt_personality = f"{prompt_personality};" prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]: def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
""" """
解析聊天prompt配置字符串并生成对应的 chat_id prompt内容 解析聊天prompt配置字符串并生成对应的 chat_id prompt内容
Args: Args:
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串 chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
Returns: Returns:
tuple: (chat_id, prompt_content)如果解析失败则返回 None tuple: (chat_id, prompt_content)如果解析失败则返回 None
""" """
@ -657,10 +658,10 @@ class DefaultReplyer:
def get_chat_prompt_for_chat(self, chat_id: str) -> str: def get_chat_prompt_for_chat(self, chat_id: str) -> str:
""" """
根据聊天流ID获取匹配的额外prompt仅匹配group类型 根据聊天流ID获取匹配的额外prompt仅匹配group类型
Args: Args:
chat_id: 聊天流ID哈希值 chat_id: 聊天流ID哈希值
Returns: Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串 str: 匹配的额外prompt内容如果没有匹配则返回空字符串
""" """
@ -670,21 +671,21 @@ class DefaultReplyer:
for chat_prompt_str in global_config.experimental.chat_prompts: for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str): if not isinstance(chat_prompt_str, str):
continue continue
# 解析配置字符串检查类型是否为group # 解析配置字符串检查类型是否为group
parts = chat_prompt_str.split(":", 3) parts = chat_prompt_str.split(":", 3)
if len(parts) != 4: if len(parts) != 4:
continue continue
stream_type = parts[2] stream_type = parts[2]
# 只匹配group类型 # 只匹配group类型
if stream_type != "group": if stream_type != "group":
continue continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str) result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
if result is None: if result is None:
continue continue
config_chat_id, prompt_content = result config_chat_id, prompt_content = result
if config_chat_id == chat_id: if config_chat_id == chat_id:
logger.debug(f"匹配到群聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...") logger.debug(f"匹配到群聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
@ -720,7 +721,7 @@ class DefaultReplyer:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) _is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform platform = chat_stream.platform
user_id = "用户ID" user_id = "用户ID"
@ -736,10 +737,10 @@ class DefaultReplyer:
target = reply_message.processed_plain_text target = reply_message.processed_plain_text
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入 # 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述 # 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target) target = self._replace_picids_with_descriptions(target)
@ -911,10 +912,10 @@ class DefaultReplyer:
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入 # 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述 # 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target) target = self._replace_picids_with_descriptions(target)
@ -956,9 +957,7 @@ class DefaultReplyer:
) )
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = ( reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = ( reply_target_block = (
@ -975,7 +974,9 @@ class DefaultReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1132,6 +1133,7 @@ class DefaultReplyer:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
return "" return ""
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:
""" """
加权且不放回地随机抽取k个元素 加权且不放回地随机抽取k个元素

View File

@ -46,6 +46,7 @@ init_memory_retrieval_prompt()
logger = get_logger("replyer") logger = get_logger("replyer")
class PrivateReplyer: class PrivateReplyer:
def __init__( def __init__(
self, self,
@ -277,9 +278,7 @@ class PrivateReplyer:
expression_habits_block = "" expression_habits_block = ""
expression_habits_title = "" expression_habits_title = ""
if style_habits_str.strip(): if style_habits_str.strip():
expression_habits_title = ( expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
)
expression_habits_block += f"{style_habits_str}\n" expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
@ -291,7 +290,6 @@ class PrivateReplyer:
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}" return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@ -358,45 +356,45 @@ class PrivateReplyer:
def _replace_picids_with_descriptions(self, text: str) -> str: def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述 """将文本中的[picid:xxx]替换为具体的图片描述
Args: Args:
text: 包含picid标记的文本 text: 包含picid标记的文本
Returns: Returns:
替换后的文本 替换后的文本
""" """
# 匹配 [picid:xxxxx] 格式 # 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]" pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str: def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1) pic_id = match.group(1)
description = translate_pid_to_description(pic_id) description = translate_pid_to_description(pic_id)
return f"[图片:{description}]" return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text) return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]: def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式 """分析target内容类型基于原始picid格式
Args: Args:
target: 目标消息内容包含[picid:xxx]格式 target: 目标消息内容包含[picid:xxx]格式
Returns: Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分) Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
""" """
if not target or not target.strip(): if not target or not target.strip():
return False, False, "", "" return False, False, "", ""
# 检查是否只包含picid标记 # 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]" picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target) picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容 # 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip() text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids) has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式) # 提取图片部分(转换为[图片:描述]格式)
pic_part = "" pic_part = ""
if picid_matches: if picid_matches:
@ -411,7 +409,7 @@ class PrivateReplyer:
else: else:
pic_descriptions.append(f"[图片:{description}]") pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions) pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
@ -517,25 +515,27 @@ class PrivateReplyer:
# 获取基础personality # 获取基础personality
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if (global_config.personality.states and if (
global_config.personality.state_probability > 0 and global_config.personality.states
random.random() < global_config.personality.state_probability): and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality # 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states) selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state prompt_personality = selected_state
prompt_personality = f"{prompt_personality};" prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]: def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
""" """
解析聊天prompt配置字符串并生成对应的 chat_id prompt内容 解析聊天prompt配置字符串并生成对应的 chat_id prompt内容
Args: Args:
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串 chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
Returns: Returns:
tuple: (chat_id, prompt_content)如果解析失败则返回 None tuple: (chat_id, prompt_content)如果解析失败则返回 None
""" """
@ -571,10 +571,10 @@ class PrivateReplyer:
def get_chat_prompt_for_chat(self, chat_id: str) -> str: def get_chat_prompt_for_chat(self, chat_id: str) -> str:
""" """
根据聊天流ID获取匹配的额外prompt仅匹配private类型 根据聊天流ID获取匹配的额外prompt仅匹配private类型
Args: Args:
chat_id: 聊天流ID哈希值 chat_id: 聊天流ID哈希值
Returns: Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串 str: 匹配的额外prompt内容如果没有匹配则返回空字符串
""" """
@ -584,21 +584,21 @@ class PrivateReplyer:
for chat_prompt_str in global_config.experimental.chat_prompts: for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str): if not isinstance(chat_prompt_str, str):
continue continue
# 解析配置字符串检查类型是否为private # 解析配置字符串检查类型是否为private
parts = chat_prompt_str.split(":", 3) parts = chat_prompt_str.split(":", 3)
if len(parts) != 4: if len(parts) != 4:
continue continue
stream_type = parts[2] stream_type = parts[2]
# 只匹配private类型 # 只匹配private类型
if stream_type != "private": if stream_type != "private":
continue continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str) result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
if result is None: if result is None:
continue continue
config_chat_id, prompt_content = result config_chat_id, prompt_content = result
if config_chat_id == chat_id: if config_chat_id == chat_id:
logger.debug(f"匹配到私聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...") logger.debug(f"匹配到私聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
@ -647,13 +647,11 @@ class PrivateReplyer:
sender = person_name sender = person_name
target = reply_message.processed_plain_text target = reply_message.processed_plain_text
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入 # 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述 # 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target) target = self._replace_picids_with_descriptions(target)
@ -662,7 +660,7 @@ class PrivateReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=global_config.chat.max_context_size, limit=global_config.chat.max_context_size,
) )
dialogue_prompt = build_readable_messages( dialogue_prompt = build_readable_messages(
message_list_before_now_long, message_list_before_now_long,
replace_bot_name=True, replace_bot_name=True,
@ -710,9 +708,7 @@ class PrivateReplyer:
self._time_and_run_task( self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
), ),
self._time_and_run_task( self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
), ),
@ -852,15 +848,13 @@ class PrivateReplyer:
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入 # 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target) has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述 # 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target) target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
@ -900,9 +894,7 @@ class PrivateReplyer:
) )
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = ( reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = ( reply_target_block = (
@ -919,7 +911,9 @@ class PrivateReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1010,7 +1004,7 @@ class PrivateReplyer:
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async( content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt prompt
) )
content = content.strip() content = content.strip()
logger.info(f"使用 {model_name} 生成回复内容: {content}") logger.info(f"使用 {model_name} 生成回复内容: {content}")
@ -1106,6 +1100,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx) pool.pop(idx)
break break
return selected return selected

View File

@ -1,16 +1,13 @@
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator # from src.chat.memory_system.memory_activator import MemoryActivator
def init_replyer_prompt(): def init_replyer_prompt():
Prompt("正在群里聊天", "chat_target_group2") Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2") Prompt("{sender_name}聊天", "chat_target_private2")
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片: 你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片:
@ -27,10 +24,9 @@ def init_replyer_prompt():
现在你说""", 现在你说""",
"replyer_prompt", "replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容: 你正在和{sender_name}聊天这是你们之前聊的内容:
@ -46,10 +42,9 @@ def init_replyer_prompt():
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""", {moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""",
"private_replyer_prompt", "private_replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容: 你正在和{sender_name}聊天这是你们之前聊的内容:
@ -65,4 +60,4 @@ def init_replyer_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 ) {moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )
""", """,
"private_replyer_self_prompt", "private_replyer_self_prompt",
) )

View File

@ -2,6 +2,7 @@
聊天内容概括器 聊天内容概括器
用于累积打包和压缩聊天记录 用于累积打包和压缩聊天记录
""" """
import asyncio import asyncio
import json import json
import time import time
@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer")
@dataclass @dataclass
class MessageBatch: class MessageBatch:
"""消息批次""" """消息批次"""
messages: List[DatabaseMessages] messages: List[DatabaseMessages]
start_time: float start_time: float
end_time: float end_time: float
@ -31,11 +33,11 @@ class MessageBatch:
class ChatHistorySummarizer: class ChatHistorySummarizer:
"""聊天内容概括器""" """聊天内容概括器"""
def __init__(self, chat_id: str, check_interval: int = 60): def __init__(self, chat_id: str, check_interval: int = 60):
""" """
初始化聊天内容概括器 初始化聊天内容概括器
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
check_interval: 定期检查间隔默认60秒 check_interval: 定期检查间隔默认60秒
@ -43,24 +45,23 @@ class ChatHistorySummarizer:
self.chat_id = chat_id self.chat_id = chat_id
self._chat_display_name = self._get_chat_display_name() self._chat_display_name = self._get_chat_display_name()
self.log_prefix = f"[{self._chat_display_name}]" self.log_prefix = f"[{self._chat_display_name}]"
# 记录时间点,用于计算新消息 # 记录时间点,用于计算新消息
self.last_check_time = time.time() self.last_check_time = time.time()
# 当前累积的消息批次 # 当前累积的消息批次
self.current_batch: Optional[MessageBatch] = None self.current_batch: Optional[MessageBatch] = None
# LLM请求器用于压缩聊天内容 # LLM请求器用于压缩聊天内容
self.summarizer_llm = LLMRequest( self.summarizer_llm = LLMRequest(
model_set=model_config.model_task_config.utils, model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
request_type="chat_history_summarizer"
) )
# 后台循环相关 # 后台循环相关
self.check_interval = check_interval # 检查间隔(秒) self.check_interval = check_interval # 检查间隔(秒)
self._periodic_task: Optional[asyncio.Task] = None self._periodic_task: Optional[asyncio.Task] = None
self._running = False self._running = False
def _get_chat_display_name(self) -> str: def _get_chat_display_name(self) -> str:
"""获取聊天显示名称""" """获取聊天显示名称"""
try: try:
@ -76,17 +77,17 @@ class ChatHistorySummarizer:
if len(self.chat_id) > 20: if len(self.chat_id) > 20:
return f"{self.chat_id[:8]}..." return f"{self.chat_id[:8]}..."
return self.chat_id return self.chat_id
async def process(self, current_time: Optional[float] = None): async def process(self, current_time: Optional[float] = None):
""" """
处理聊天内容概括 处理聊天内容概括
Args: Args:
current_time: 当前时间戳如果为None则使用time.time() current_time: 当前时间戳如果为None则使用time.time()
""" """
if current_time is None: if current_time is None:
current_time = time.time() current_time = time.time()
try: try:
logger.info( logger.info(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
@ -101,25 +102,23 @@ class ChatHistorySummarizer:
filter_mai=False, # 不过滤bot消息因为需要检查bot是否发言 filter_mai=False, # 不过滤bot消息因为需要检查bot是否发言
filter_command=False, filter_command=False,
) )
if not new_messages: if not new_messages:
# 没有新消息,检查是否需要打包 # 没有新消息,检查是否需要打包
if self.current_batch and self.current_batch.messages: if self.current_batch and self.current_batch.messages:
await self._check_and_package(current_time) await self._check_and_package(current_time)
self.last_check_time = current_time self.last_check_time = current_time
return return
# 有新消息,更新最后检查时间 # 有新消息,更新最后检查时间
self.last_check_time = current_time self.last_check_time = current_time
# 如果有当前批次,添加新消息 # 如果有当前批次,添加新消息
if self.current_batch: if self.current_batch:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info( logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
else: else:
# 创建新批次 # 创建新批次
self.current_batch = MessageBatch( self.current_batch = MessageBatch(
@ -127,23 +126,22 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time, start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time, end_time=current_time,
) )
logger.info( logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息"
)
# 检查是否需要打包 # 检查是否需要打包
await self._check_and_package(current_time) await self._check_and_package(current_time)
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
async def _check_and_package(self, current_time: float): async def _check_and_package(self, current_time: float):
"""检查是否需要打包""" """检查是否需要打包"""
if not self.current_batch or not self.current_batch.messages: if not self.current_batch or not self.current_batch.messages:
return return
messages = self.current_batch.messages messages = self.current_batch.messages
message_count = len(messages) message_count = len(messages)
last_message_time = messages[-1].time if messages else current_time last_message_time = messages[-1].time if messages else current_time
@ -153,48 +151,48 @@ class ChatHistorySummarizer:
if time_since_last_message < 60: if time_since_last_message < 60:
time_str = f"{time_since_last_message:.1f}" time_str = f"{time_since_last_message:.1f}"
elif time_since_last_message < 3600: elif time_since_last_message < 3600:
time_str = f"{time_since_last_message/60:.1f}分钟" time_str = f"{time_since_last_message / 60:.1f}分钟"
else: else:
time_str = f"{time_since_last_message/3600:.1f}小时" time_str = f"{time_since_last_message / 3600:.1f}小时"
preparing_status = "" if self.current_batch.is_preparing else "" preparing_status = "" if self.current_batch.is_preparing else ""
logger.info( logger.info(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}" f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}"
) )
# 检查打包条件 # 检查打包条件
should_package = False should_package = False
# 条件1: 消息长度超过120直接打包 # 条件1: 消息长度超过120直接打包
if message_count >= 120: if message_count >= 120:
should_package = True should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条") logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条")
# 条件2: 最后一条消息的时间和当前时间差>600秒直接打包 # 条件2: 最后一条消息的时间和当前时间差>600秒直接打包
elif time_since_last_message > 600: elif time_since_last_message > 600:
should_package = True should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟") logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟")
# 条件3: 消息长度超过100进入准备结束模式 # 条件3: 消息长度超过100进入准备结束模式
elif message_count > 100: elif message_count > 100:
if not self.current_batch.is_preparing: if not self.current_batch.is_preparing:
self.current_batch.is_preparing = True self.current_batch.is_preparing = True
logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值100条进入准备结束模式") logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值100条进入准备结束模式")
# 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒就打包 # 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒就打包
if time_since_last_message > 10: if time_since_last_message > 10:
should_package = True should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒") logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒")
if should_package: if should_package:
await self._package_and_store() await self._package_and_store()
async def _package_and_store(self): async def _package_and_store(self):
"""打包并存储聊天记录""" """打包并存储聊天记录"""
if not self.current_batch or not self.current_batch.messages: if not self.current_batch or not self.current_batch.messages:
return return
messages = self.current_batch.messages messages = self.current_batch.messages
start_time = self.current_batch.start_time start_time = self.current_batch.start_time
end_time = self.current_batch.end_time end_time = self.current_batch.end_time
@ -202,12 +200,12 @@ class ChatHistorySummarizer:
logger.info( logger.info(
f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
) )
# 检查是否有bot发言 # 检查是否有bot发言
# 第一条消息前推600s到最后一条消息的时间内 # 第一条消息前推600s到最后一条消息的时间内
check_start_time = max(start_time - 600, 0) check_start_time = max(start_time - 600, 0)
check_end_time = end_time check_end_time = end_time
# 使用包含边界的时间范围查询 # 使用包含边界的时间范围查询
bot_messages = message_api.get_messages_by_time_in_chat_inclusive( bot_messages = message_api.get_messages_by_time_in_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
@ -218,7 +216,7 @@ class ChatHistorySummarizer:
filter_mai=False, filter_mai=False,
filter_command=False, filter_command=False,
) )
# 检查是否有bot的发言 # 检查是否有bot的发言
has_bot_message = False has_bot_message = False
bot_user_id = str(global_config.bot.qq_account) bot_user_id = str(global_config.bot.qq_account)
@ -226,14 +224,14 @@ class ChatHistorySummarizer:
if msg.user_info.user_id == bot_user_id: if msg.user_info.user_id == bot_user_id:
has_bot_message = True has_bot_message = True
break break
if not has_bot_message: if not has_bot_message:
logger.info( logger.info(
f"{self.log_prefix} 批次内无Bot发言丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}" f"{self.log_prefix} 批次内无Bot发言丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}"
) )
self.current_batch = None self.current_batch = None
return return
# 有bot发言进行压缩和存储 # 有bot发言进行压缩和存储
try: try:
# 构建对话原文 # 构建对话原文
@ -245,39 +243,36 @@ class ChatHistorySummarizer:
truncate=False, truncate=False,
show_actions=False, show_actions=False,
) )
# 获取参与的所有人的昵称 # 获取参与的所有人的昵称
participants_set: Set[str] = set() participants_set: Set[str] = set()
for msg in messages: for msg in messages:
# 使用 msg.user_platform扁平化字段或 msg.user_info.platform # 使用 msg.user_platform扁平化字段或 msg.user_info.platform
platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform platform = (
person = Person( getattr(msg, "user_platform", None)
platform=platform, or (msg.user_info.platform if msg.user_info else None)
user_id=msg.user_info.user_id or msg.chat_info.platform
) )
person = Person(platform=platform, user_id=msg.user_info.user_id)
person_name = person.person_name person_name = person.person_name
if person_name: if person_name:
participants_set.add(person_name) participants_set.add(person_name)
participants = list(participants_set) participants = list(participants_set)
logger.info( logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}"
)
# 使用LLM压缩聊天内容 # 使用LLM压缩聊天内容
success, theme, keywords, summary = await self._compress_with_llm(original_text) success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success: if not success:
logger.warning( logger.warning(f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}")
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
# 清空当前批次,避免重复处理 # 清空当前批次,避免重复处理
self.current_batch = None self.current_batch = None
return return
logger.info( logger.info(
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}" f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}"
) )
# 存储到数据库 # 存储到数据库
await self._store_to_database( await self._store_to_database(
start_time=start_time, start_time=start_time,
@ -288,23 +283,24 @@ class ChatHistorySummarizer:
keywords=keywords, keywords=keywords,
summary=summary, summary=summary,
) )
logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}") logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}")
# 清空当前批次 # 清空当前批次
self.current_batch = None self.current_batch = None
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}") logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# 出错时也清空批次,避免重复处理 # 出错时也清空批次,避免重复处理
self.current_batch = None self.current_batch = None
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]: async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
""" """
使用LLM压缩聊天内容 使用LLM压缩聊天内容
Returns: Returns:
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括) tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
""" """
@ -325,37 +321,37 @@ class ChatHistorySummarizer:
{original_text} {original_text}
请直接返回JSON不要包含其他内容""" 请直接返回JSON不要包含其他内容"""
try: try:
response, _ = await self.summarizer_llm.generate_response_async( response, _ = await self.summarizer_llm.generate_response_async(
prompt=prompt, prompt=prompt,
temperature=0.3, temperature=0.3,
max_tokens=500, max_tokens=500,
) )
# 解析JSON响应 # 解析JSON响应
import re import re
# 移除可能的markdown代码块标记 # 移除可能的markdown代码块标记
json_str = response.strip() json_str = response.strip()
json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE) json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE) json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
json_str = json_str.strip() json_str = json_str.strip()
# 尝试找到JSON对象的开始和结束位置 # 尝试找到JSON对象的开始和结束位置
# 查找第一个 { 和最后一个匹配的 } # 查找第一个 { 和最后一个匹配的 }
start_idx = json_str.find('{') start_idx = json_str.find("{")
if start_idx == -1: if start_idx == -1:
raise ValueError("未找到JSON对象开始标记") raise ValueError("未找到JSON对象开始标记")
# 从后往前查找最后一个 } # 从后往前查找最后一个 }
end_idx = json_str.rfind('}') end_idx = json_str.rfind("}")
if end_idx == -1 or end_idx <= start_idx: if end_idx == -1 or end_idx <= start_idx:
raise ValueError("未找到JSON对象结束标记") raise ValueError("未找到JSON对象结束标记")
# 提取JSON字符串 # 提取JSON字符串
json_str = json_str[start_idx:end_idx + 1] json_str = json_str[start_idx : end_idx + 1]
# 尝试解析JSON # 尝试解析JSON
try: try:
result = json.loads(json_str) result = json.loads(json_str)
@ -372,7 +368,7 @@ class ChatHistorySummarizer:
if escape_next: if escape_next:
fixed_chars.append(char) fixed_chars.append(char)
escape_next = False escape_next = False
elif char == '\\': elif char == "\\":
fixed_chars.append(char) fixed_chars.append(char)
escape_next = True escape_next = True
elif char == '"' and not escape_next: elif char == '"' and not escape_next:
@ -384,27 +380,27 @@ class ChatHistorySummarizer:
else: else:
fixed_chars.append(char) fixed_chars.append(char)
i += 1 i += 1
json_str = ''.join(fixed_chars) json_str = "".join(fixed_chars)
# 再次尝试解析 # 再次尝试解析
result = json.loads(json_str) result = json.loads(json_str)
theme = result.get("theme", "未命名对话") theme = result.get("theme", "未命名对话")
keywords = result.get("keywords", []) keywords = result.get("keywords", [])
summary = result.get("summary", "无概括") summary = result.get("summary", "无概括")
# 确保keywords是列表 # 确保keywords是列表
if isinstance(keywords, str): if isinstance(keywords, str):
keywords = [keywords] keywords = [keywords]
return True, theme, keywords, summary return True, theme, keywords, summary
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回失败标志和默认值 # 返回失败标志和默认值
return False, "未命名对话", [], "压缩失败,无法生成概括" return False, "未命名对话", [], "压缩失败,无法生成概括"
async def _store_to_database( async def _store_to_database(
self, self,
start_time: float, start_time: float,
@ -419,7 +415,7 @@ class ChatHistorySummarizer:
try: try:
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api from src.plugin_system.apis import database_api
# 准备数据 # 准备数据
data = { data = {
"chat_id": self.chat_id, "chat_id": self.chat_id,
@ -432,7 +428,7 @@ class ChatHistorySummarizer:
"summary": summary, "summary": summary,
"count": 0, "count": 0,
} }
# 使用db_save存储使用start_time和chat_id作为唯一标识 # 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识 # 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time # 但为了避免冲突我们使用组合键chat_id + start_time
@ -441,28 +437,29 @@ class ChatHistorySummarizer:
ChatHistory, ChatHistory,
data=data, data=data,
) )
if saved_record: if saved_record:
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
else: else:
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
async def start(self): async def start(self):
"""启动后台定期检查循环""" """启动后台定期检查循环"""
if self._running: if self._running:
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
return return
self._running = True self._running = True
self._periodic_task = asyncio.create_task(self._periodic_check_loop()) self._periodic_task = asyncio.create_task(self._periodic_check_loop())
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}") logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}")
async def stop(self): async def stop(self):
"""停止后台定期检查循环""" """停止后台定期检查循环"""
self._running = False self._running = False
@ -474,14 +471,14 @@ class ChatHistorySummarizer:
pass pass
self._periodic_task = None self._periodic_task = None
logger.info(f"{self.log_prefix} 已停止后台定期检查循环") logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
async def _periodic_check_loop(self): async def _periodic_check_loop(self):
"""后台定期检查循环""" """后台定期检查循环"""
try: try:
while self._running: while self._running:
# 执行一次检查 # 执行一次检查
await self.process() await self.process()
# 等待指定间隔后再次检查 # 等待指定间隔后再次检查
await asyncio.sleep(self.check_interval) await asyncio.sleep(self.check_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -490,6 +487,6 @@ class ChatHistorySummarizer:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
self._running = False self._running = False

View File

@ -2,7 +2,7 @@ import time
import random import random
import re import re
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install from rich.traceback import install
from src.config.config import global_config from src.config.config import global_config
@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = [] output_lines = []
current_time = time.time() current_time = time.time()
for action in actions: for action in actions:
action_time = action.time or current_time action_time = action.time or current_time
action_name = action.action_name or "未知动作" action_name = action.action_name or "未知动作"
@ -595,7 +594,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}" line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line) output_lines.append(line)
return "\n".join(output_lines) return "\n".join(output_lines)
@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
""" """
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身) 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)

View File

@ -2,6 +2,7 @@
记忆遗忘任务 记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆 每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
""" """
import time import time
import random import random
from typing import List from typing import List
@ -15,27 +16,27 @@ logger = get_logger("memory_forget_task")
class MemoryForgetTask(AsyncTask): class MemoryForgetTask(AsyncTask):
"""记忆遗忘任务每5分钟执行一次""" """记忆遗忘任务每5分钟执行一次"""
def __init__(self): def __init__(self):
# 每5分钟执行一次300秒 # 每5分钟执行一次300秒
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300) super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
async def run(self): async def run(self):
"""执行遗忘检查""" """执行遗忘检查"""
try: try:
current_time = time.time() current_time = time.time()
logger.info("[记忆遗忘] 开始遗忘检查...") logger.info("[记忆遗忘] 开始遗忘检查...")
# 执行4个阶段的遗忘检查 # 执行4个阶段的遗忘检查
await self._forget_stage_1(current_time) await self._forget_stage_1(current_time)
await self._forget_stage_2(current_time) await self._forget_stage_2(current_time)
await self._forget_stage_3(current_time) await self._forget_stage_3(current_time)
await self._forget_stage_4(current_time) await self._forget_stage_4(current_time)
logger.info("[记忆遗忘] 遗忘检查完成") logger.info("[记忆遗忘] 遗忘检查完成")
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True) logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
async def _forget_stage_1(self, current_time: float): async def _forget_stage_1(self, current_time: float):
""" """
第一次遗忘检查 第一次遗忘检查
@ -45,38 +46,34 @@ class MemoryForgetTask(AsyncTask):
try: try:
# 30分钟 = 1800秒 # 30分钟 = 1800秒
time_threshold = current_time - 1800 time_threshold = current_time - 1800
# 查询符合条件的记忆forget_times=0 且 end_time < time_threshold # 查询符合条件的记忆forget_times=0 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 0) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆") logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
return return
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆") logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序 # 按count排序
candidates.sort(key=lambda x: x.count, reverse=True) candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高25%和最低25% # 计算要删除的数量最高25%和最低25%
total_count = len(candidates) total_count = len(candidates)
delete_count = int(total_count * 0.25) # 25% delete_count = int(total_count * 0.25) # 25%
if delete_count == 0: if delete_count == 0:
logger.debug("[记忆遗忘-阶段1] 删除数量为0跳过") logger.debug("[记忆遗忘-阶段1] 删除数量为0跳过")
return return
# 选择要删除的记录处理count相同的情况随机选择 # 选择要删除的记录处理count相同的情况随机选择
to_delete = [] to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重避免重复删除使用id去重 # 去重避免重复删除使用id去重
seen_ids = set() seen_ids = set()
unique_to_delete = [] unique_to_delete = []
@ -85,7 +82,7 @@ class MemoryForgetTask(AsyncTask):
seen_ids.add(record.id) seen_ids.add(record.id)
unique_to_delete.append(record) unique_to_delete.append(record)
to_delete = unique_to_delete to_delete = unique_to_delete
# 删除记录并更新forget_times # 删除记录并更新forget_times
deleted_count = 0 deleted_count = 0
for record in to_delete: for record in to_delete:
@ -94,22 +91,22 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1 deleted_count += 1
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}") logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
# 更新剩余记录的forget_times为1 # 更新剩余记录的forget_times为1
to_delete_ids = {r.id for r in to_delete} to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
# 批量更新 # 批量更新
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=1).where( ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute() logger.info(
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1") )
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
async def _forget_stage_2(self, current_time: float): async def _forget_stage_2(self, current_time: float):
""" """
第二次遗忘检查 第二次遗忘检查
@ -119,41 +116,37 @@ class MemoryForgetTask(AsyncTask):
try: try:
# 8小时 = 28800秒 # 8小时 = 28800秒
time_threshold = current_time - 28800 time_threshold = current_time - 28800
# 查询符合条件的记忆forget_times=1 且 end_time < time_threshold # 查询符合条件的记忆forget_times=1 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 1) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆") logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
return return
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆") logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序 # 按count排序
candidates.sort(key=lambda x: x.count, reverse=True) candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高7%和最低7% # 计算要删除的数量最高7%和最低7%
total_count = len(candidates) total_count = len(candidates)
delete_count = int(total_count * 0.07) # 7% delete_count = int(total_count * 0.07) # 7%
if delete_count == 0: if delete_count == 0:
logger.debug("[记忆遗忘-阶段2] 删除数量为0跳过") logger.debug("[记忆遗忘-阶段2] 删除数量为0跳过")
return return
# 选择要删除的记录 # 选择要删除的记录
to_delete = [] to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重 # 去重
to_delete = list(set(to_delete)) to_delete = list(set(to_delete))
# 删除记录 # 删除记录
deleted_count = 0 deleted_count = 0
for record in to_delete: for record in to_delete:
@ -162,21 +155,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1 deleted_count += 1
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}") logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
# 更新剩余记录的forget_times为2 # 更新剩余记录的forget_times为2
to_delete_ids = {r.id for r in to_delete} to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=2).where( ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute() logger.info(
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2") )
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
async def _forget_stage_3(self, current_time: float): async def _forget_stage_3(self, current_time: float):
""" """
第三次遗忘检查 第三次遗忘检查
@ -186,41 +179,37 @@ class MemoryForgetTask(AsyncTask):
try: try:
# 48小时 = 172800秒 # 48小时 = 172800秒
time_threshold = current_time - 172800 time_threshold = current_time - 172800
# 查询符合条件的记忆forget_times=2 且 end_time < time_threshold # 查询符合条件的记忆forget_times=2 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 2) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆") logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
return return
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆") logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序 # 按count排序
candidates.sort(key=lambda x: x.count, reverse=True) candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高5%和最低5% # 计算要删除的数量最高5%和最低5%
total_count = len(candidates) total_count = len(candidates)
delete_count = int(total_count * 0.05) # 5% delete_count = int(total_count * 0.05) # 5%
if delete_count == 0: if delete_count == 0:
logger.debug("[记忆遗忘-阶段3] 删除数量为0跳过") logger.debug("[记忆遗忘-阶段3] 删除数量为0跳过")
return return
# 选择要删除的记录 # 选择要删除的记录
to_delete = [] to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重 # 去重
to_delete = list(set(to_delete)) to_delete = list(set(to_delete))
# 删除记录 # 删除记录
deleted_count = 0 deleted_count = 0
for record in to_delete: for record in to_delete:
@ -229,21 +218,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1 deleted_count += 1
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}") logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
# 更新剩余记录的forget_times为3 # 更新剩余记录的forget_times为3
to_delete_ids = {r.id for r in to_delete} to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=3).where( ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute() logger.info(
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3") )
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
async def _forget_stage_4(self, current_time: float): async def _forget_stage_4(self, current_time: float):
""" """
第四次遗忘检查 第四次遗忘检查
@ -253,41 +242,37 @@ class MemoryForgetTask(AsyncTask):
try: try:
# 7天 = 604800秒 # 7天 = 604800秒
time_threshold = current_time - 604800 time_threshold = current_time - 604800
# 查询符合条件的记忆forget_times=3 且 end_time < time_threshold # 查询符合条件的记忆forget_times=3 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 3) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆") logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
return return
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆") logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序 # 按count排序
candidates.sort(key=lambda x: x.count, reverse=True) candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高2%和最低2% # 计算要删除的数量最高2%和最低2%
total_count = len(candidates) total_count = len(candidates)
delete_count = int(total_count * 0.02) # 2% delete_count = int(total_count * 0.02) # 2%
if delete_count == 0: if delete_count == 0:
logger.debug("[记忆遗忘-阶段4] 删除数量为0跳过") logger.debug("[记忆遗忘-阶段4] 删除数量为0跳过")
return return
# 选择要删除的记录 # 选择要删除的记录
to_delete = [] to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low")) to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重 # 去重
to_delete = list(set(to_delete)) to_delete = list(set(to_delete))
# 删除记录 # 删除记录
deleted_count = 0 deleted_count = 0
for record in to_delete: for record in to_delete:
@ -296,38 +281,40 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1 deleted_count += 1
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}") logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
# 更新剩余记录的forget_times为4 # 更新剩余记录的forget_times为4
to_delete_ids = {r.id for r in to_delete} to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=4).where( ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute() logger.info(
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4") )
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]: def _handle_same_count_random(
self, candidates: List[ChatHistory], delete_count: int, mode: str
) -> List[ChatHistory]:
""" """
处理count相同的情况随机选择要删除的记录 处理count相同的情况随机选择要删除的记录
Args: Args:
candidates: 候选记录列表已按count排序 candidates: 候选记录列表已按count排序
delete_count: 要删除的数量 delete_count: 要删除的数量
mode: "high" 表示选择最高count的记录"low" 表示选择最低count的记录 mode: "high" 表示选择最高count的记录"low" 表示选择最低count的记录
Returns: Returns:
要删除的记录列表 要删除的记录列表
""" """
if not candidates or delete_count == 0: if not candidates or delete_count == 0:
return [] return []
to_delete = [] to_delete = []
if mode == "high": if mode == "high":
# 从最高count开始选择 # 从最高count开始选择
start_idx = 0 start_idx = 0
@ -339,7 +326,7 @@ class MemoryForgetTask(AsyncTask):
while idx < len(candidates) and candidates[idx].count == current_count: while idx < len(candidates) and candidates[idx].count == current_count:
same_count_records.append(candidates[idx]) same_count_records.append(candidates[idx])
idx += 1 idx += 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择 # 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete) needed = delete_count - len(to_delete)
if len(same_count_records) <= needed: if len(same_count_records) <= needed:
@ -347,9 +334,9 @@ class MemoryForgetTask(AsyncTask):
else: else:
# 随机选择需要的数量 # 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed)) to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx start_idx = idx
else: # mode == "low" else: # mode == "low"
# 从最低count开始选择 # 从最低count开始选择
start_idx = len(candidates) - 1 start_idx = len(candidates) - 1
@ -361,7 +348,7 @@ class MemoryForgetTask(AsyncTask):
while idx >= 0 and candidates[idx].count == current_count: while idx >= 0 and candidates[idx].count == current_count:
same_count_records.append(candidates[idx]) same_count_records.append(candidates[idx])
idx -= 1 idx -= 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择 # 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete) needed = delete_count - len(to_delete)
if len(same_count_records) <= needed: if len(same_count_records) <= needed:
@ -369,8 +356,7 @@ class MemoryForgetTask(AsyncTask):
else: else:
# 随机选择需要的数量 # 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed)) to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
return to_delete
start_idx = idx
return to_delete

View File

@ -153,7 +153,7 @@ def _format_large_number(num: float | int, html: bool = False) -> str:
else: else:
number_part = f"{value:.1f}" number_part = f"{value:.1f}"
k_suffix = "K" k_suffix = "K"
if html: if html:
# HTML输出K着色为主题色并加粗大写 # HTML输出K着色为主题色并加粗大写
return f"{number_part}<span style='color: #8b5cf6; font-weight: bold;'>K</span>" return f"{number_part}<span style='color: #8b5cf6; font-weight: bold;'>K</span>"
@ -502,9 +502,13 @@ class StatisticOutputTask(AsyncTask):
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 获取bot的QQ账号 # 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
@ -547,7 +551,7 @@ class StatisticOutputTask(AsyncTask):
is_bot_reply = False is_bot_reply = False
if bot_qq_account and message.user_id == bot_qq_account: if bot_qq_account and message.user_id == bot_qq_account:
is_bot_reply = True is_bot_reply = True
for idx, (_, period_start_dt) in enumerate(collect_period): for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp(): if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
continue continue
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳 last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段 self.stat_period = [
item for item in self.stat_period if item[0] != "all_time"
] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的")) self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
except Exception as e: except Exception as e:
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}") logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
@ -640,12 +646,12 @@ class StatisticOutputTask(AsyncTask):
# 更新上次完整统计数据的时间戳 # 更新上次完整统计数据的时间戳
# 将所有defaultdict转换为普通dict以避免类型冲突 # 将所有defaultdict转换为普通dict以避免类型冲突
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"]) clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
# 将 name_mapping 中的元组转换为列表因为JSON不支持元组 # 将 name_mapping 中的元组转换为列表因为JSON不支持元组
json_safe_name_mapping = {} json_safe_name_mapping = {}
for chat_id, (chat_name, timestamp) in self.name_mapping.items(): for chat_id, (chat_name, timestamp) in self.name_mapping.items():
json_safe_name_mapping[chat_id] = [chat_name, timestamp] json_safe_name_mapping[chat_id] = [chat_name, timestamp]
local_storage["last_full_statistics"] = { local_storage["last_full_statistics"] = {
"name_mapping": json_safe_name_mapping, "name_mapping": json_safe_name_mapping,
"stat_data": clean_stat_data, "stat_data": clean_stat_data,
@ -682,24 +688,28 @@ class StatisticOutputTask(AsyncTask):
""" """
# 计算总token数从所有模型的token数中累加 # 计算总token数从所有模型的token数中累加
total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0 total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0
# 计算花费/消息数量指标每100条 # 计算花费/消息数量指标每100条
cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0 cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0
# 计算花费/时间指标(花费/小时) # 计算花费/时间指标(花费/小时)
online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0 online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0
cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0 cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0
# 计算token/时间指标token/小时) # 计算token/时间指标token/小时)
tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0 tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0
# 计算花费/回复数量指标每100条 # 计算花费/回复数量指标每100条
total_replies = stats.get(TOTAL_REPLY_CNT, 0) total_replies = stats.get(TOTAL_REPLY_CNT, 0)
cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0 cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0
# 计算花费/消息数量排除自己回复指标每100条 # 计算花费/消息数量排除自己回复指标每100条
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0 cost_per_100_messages_excluding_replies = (
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
if total_messages_excluding_replies > 0
else 0.0
)
output = [ output = [
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}", f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask):
f"总Token数: {_format_large_number(total_tokens)}", f"总Token数: {_format_large_number(total_tokens)}",
f"总花费: {stats[TOTAL_COST]:.2f}¥", f"总花费: {stats[TOTAL_COST]:.2f}¥",
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A", f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A", f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
if total_messages_excluding_replies > 0
else "花费/消息数量(排除回复): N/A",
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A", f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A", f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A", f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask):
formatted_out_tokens = _format_large_number(out_tokens) formatted_out_tokens = _format_large_number(out_tokens)
formatted_tokens = _format_large_number(tokens) formatted_tokens = _format_large_number(tokens)
output.append( output.append(
data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost) data_fmt.format(
name,
formatted_count,
formatted_in_tokens,
formatted_out_tokens,
formatted_tokens,
cost,
avg_time_cost,
std_time_cost,
)
) )
output.append("") output.append("")
@ -891,8 +912,12 @@ class StatisticOutputTask(AsyncTask):
except (IndexError, TypeError) as e: except (IndexError, TypeError) as e:
logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}") logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>") chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>" chat_rows_html = (
"\n".join(chat_rows)
if chat_rows
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
)
# 生成HTML # 生成HTML
return f""" return f"""
<div id=\"{div_id}\" class=\"tab-content\"> <div id=\"{div_id}\" class=\"tab-content\">
@ -1197,7 +1222,7 @@ class StatisticOutputTask(AsyncTask):
# 添加图表内容 # 添加图表内容
chart_data = self._generate_chart_data(stat) chart_data = self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data)) tab_content_list.append(self._generate_chart_tab(chart_data))
# 添加指标趋势图表 # 添加指标趋势图表
metrics_data = self._generate_metrics_data(now) metrics_data = self._generate_metrics_data(now)
tab_content_list.append(self._generate_metrics_tab(metrics_data)) tab_content_list.append(self._generate_metrics_tab(metrics_data))
@ -1772,121 +1797,125 @@ class StatisticOutputTask(AsyncTask):
def _generate_metrics_data(self, now: datetime) -> dict: def _generate_metrics_data(self, now: datetime) -> dict:
"""生成指标趋势数据""" """生成指标趋势数据"""
metrics_data = {} metrics_data = {}
# 24小时尺度1小时为单位 # 24小时尺度1小时为单位
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1) metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
# 7天尺度1天为单位 # 7天尺度1天为单位
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24) metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
# 30天尺度1天为单位 # 30天尺度1天为单位
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24) metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
return metrics_data return metrics_data
def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict: def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict:
"""收集指定时间范围内每个间隔的指标数据""" """收集指定时间范围内每个间隔的指标数据"""
start_time = now - timedelta(hours=hours) start_time = now - timedelta(hours=hours)
time_points = [] time_points = []
current_time = start_time current_time = start_time
# 生成时间点 # 生成时间点
while current_time <= now: while current_time <= now:
time_points.append(current_time) time_points.append(current_time)
current_time += timedelta(hours=interval_hours) current_time += timedelta(hours=interval_hours)
# 初始化数据结构 # 初始化数据结构
cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量每100条 cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量每100条
cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时) cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时)
tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时) tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时)
cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量每100条 cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量每100条
# 每个时间点的累计数据 # 每个时间点的累计数据
total_costs = [0.0] * len(time_points) total_costs = [0.0] * len(time_points)
total_tokens = [0] * len(time_points) total_tokens = [0] * len(time_points)
total_messages = [0] * len(time_points) total_messages = [0] * len(time_points)
total_replies = [0] * len(time_points) total_replies = [0] * len(time_points)
total_online_hours = [0.0] * len(time_points) total_online_hours = [0.0] * len(time_points)
# 获取bot的QQ账号 # 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
interval_seconds = interval_hours * 3600 interval_seconds = interval_hours * 3600
# 查询LLM使用记录 # 查询LLM使用记录
query_start_time = start_time query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp record_time = record.timestamp
# 找到对应的时间间隔索引 # 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds() time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds) interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points): if 0 <= interval_index < len(time_points):
cost = record.cost or 0.0 cost = record.cost or 0.0
prompt_tokens = record.prompt_tokens or 0 prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0 completion_tokens = record.completion_tokens or 0
total_token = prompt_tokens + completion_tokens total_token = prompt_tokens + completion_tokens
total_costs[interval_index] += cost total_costs[interval_index] += cost
total_tokens[interval_index] += total_token total_tokens[interval_index] += total_token
# 查询消息记录 # 查询消息记录
query_start_timestamp = start_time.timestamp() query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time message_time_ts = message.time
time_diff = message_time_ts - query_start_timestamp time_diff = message_time_ts - query_start_timestamp
interval_index = int(time_diff // interval_seconds) interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points): if 0 <= interval_index < len(time_points):
total_messages[interval_index] += 1 total_messages[interval_index] += 1
# 检查是否是bot发送的消息回复 # 检查是否是bot发送的消息回复
if bot_qq_account and message.user_id == bot_qq_account: if bot_qq_account and message.user_id == bot_qq_account:
total_replies[interval_index] += 1 total_replies[interval_index] += 1
# 查询在线时间记录 # 查询在线时间记录
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore
record_start = record.start_timestamp record_start = record.start_timestamp
record_end = record.end_timestamp record_end = record.end_timestamp
# 找到记录覆盖的所有时间间隔 # 找到记录覆盖的所有时间间隔
for idx, time_point in enumerate(time_points): for idx, time_point in enumerate(time_points):
interval_start = time_point interval_start = time_point
interval_end = time_point + timedelta(hours=interval_hours) interval_end = time_point + timedelta(hours=interval_hours)
# 计算重叠部分 # 计算重叠部分
overlap_start = max(record_start, interval_start) overlap_start = max(record_start, interval_start)
overlap_end = min(record_end, interval_end) overlap_end = min(record_end, interval_end)
if overlap_end > overlap_start: if overlap_end > overlap_start:
overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0 overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0
total_online_hours[idx] += overlap_hours total_online_hours[idx] += overlap_hours
# 计算指标 # 计算指标
for idx in range(len(time_points)): for idx in range(len(time_points)):
# 花费/消息数量每100条 # 花费/消息数量每100条
if total_messages[idx] > 0: if total_messages[idx] > 0:
cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100) cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
# 花费/时间(每小时) # 花费/时间(每小时)
if total_online_hours[idx] > 0: if total_online_hours[idx] > 0:
cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx]) cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
# Token/时间(每小时) # Token/时间(每小时)
if total_online_hours[idx] > 0: if total_online_hours[idx] > 0:
tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx]) tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
# 花费/回复数量每100条 # 花费/回复数量每100条
if total_replies[idx] > 0: if total_replies[idx] > 0:
cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100) cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
# 生成时间标签 # 生成时间标签
if interval_hours == 1: if interval_hours == 1:
time_labels = [t.strftime("%H:%M") for t in time_points] time_labels = [t.strftime("%H:%M") for t in time_points]
else: else:
time_labels = [t.strftime("%m-%d") for t in time_points] time_labels = [t.strftime("%m-%d") for t in time_points]
return { return {
"time_labels": time_labels, "time_labels": time_labels,
"cost_per_100_messages": cost_per_100_messages, "cost_per_100_messages": cost_per_100_messages,
@ -1894,7 +1923,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": tokens_per_hour, "tokens_per_hour": tokens_per_hour,
"cost_per_100_replies": cost_per_100_replies, "cost_per_100_replies": cost_per_100_replies,
} }
def _generate_metrics_tab(self, metrics_data: dict) -> str: def _generate_metrics_tab(self, metrics_data: dict) -> str:
"""生成指标趋势图表选项卡HTML内容""" """生成指标趋势图表选项卡HTML内容"""
colors = { colors = {
@ -1903,7 +1932,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": "#c7bbff", "tokens_per_hour": "#c7bbff",
"cost_per_100_replies": "#d9ceff", "cost_per_100_replies": "#d9ceff",
} }
return f""" return f"""
<div id="metrics" class="tab-content"> <div id="metrics" class="tab-content">
<h2>指标趋势图表</h2> <h2>指标趋势图表</h2>

View File

@ -4,14 +4,11 @@ import time
import jieba import jieba
import json import json
import ast import ast
import numpy as np
from collections import Counter
from typing import Optional, Tuple, List, TYPE_CHECKING from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool:
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]: def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
"""解析 platforms 列表,返回平台到账号的映射 """解析 platforms 列表,返回平台到账号的映射
Args: Args:
platforms: 格式为 ["platform:account"] 的列表 ["tg:123456789", "wx:wxid123"] platforms: 格式为 ["platform:account"] 的列表 ["tg:123456789", "wx:wxid123"]
Returns: Returns:
字典键为平台名值为账号 字典键为平台名值为账号
""" """
@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str: def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
"""根据当前平台获取对应的账号 """根据当前平台获取对应的账号
Args: Args:
platform: 当前消息的平台 platform: 当前消息的平台
platform_accounts: platforms 列表解析的平台账号映射 platform_accounts: platforms 列表解析的平台账号映射
qq_account: QQ 账号兼容旧配置 qq_account: QQ 账号兼容旧配置
Returns: Returns:
当前平台对应的账号 当前平台对应的账号
""" """
@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
"""检查消息是否提到了机器人(统一多平台实现)""" """检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or "" text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or "" platform = getattr(message.message_info, "platform", "") or ""
# 获取各平台账号 # 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or [] platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list) platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "") qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号 # 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account) current_account = get_current_platform_account(platform, platform_accounts, qq_account)
@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
elif current_account: elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text): if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text): elif re.search(
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text
):
is_mentioned = True is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配) # 6) 名称/别名 提及(去除 @/回复标记后再匹配)
@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding return embedding
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并 """将文本分割成句子,并根据概率合并
1. 识别分割点, ; 空格但如果分割点左右都是英文字母则不分割 1. 识别分割点, ; 空格但如果分割点左右都是英文字母则不分割
@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
prev_char = text[i - 1] prev_char = text[i - 1]
next_char = text[i + 1] next_char = text[i + 1]
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则 # 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ': if char == " ":
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char) prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char) next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum: if prev_is_alnum and next_is_alnum:
@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
"不知道", "不知道",
"不晓得", "不晓得",
"懒得说", "懒得说",
"()" "()",
] ]
return random.choice(default_replies) return random.choice(default_replies)
@ -469,7 +467,6 @@ def calculate_typing_time(
return total_time # 加上回车时间 return total_time # 加上回车时间
def truncate_message(message: str, max_length=20) -> str: def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度""" """截断消息,使其不超过指定长度"""
return f"{message[:max_length]}..." if len(message) > max_length else message return f"{message[:max_length]}..." if len(message) > max_length else message
@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars) return western_count / len(alnum_chars)
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式 """将时间戳转换为人类可读的时间格式

View File

@ -103,14 +103,16 @@ class ImageManager:
invalid_values = ["", "None"] invalid_values = ["", "None"]
# 清理 Images 表 # 清理 Images 表
deleted_images = Images.delete().where( deleted_images = (
(Images.description >> None) | (Images.description << invalid_values) Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
).execute() )
# 清理 ImageDescriptions 表 # 清理 ImageDescriptions 表
deleted_descriptions = ImageDescriptions.delete().where( deleted_descriptions = (
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values) ImageDescriptions.delete()
).execute() .where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions: if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}") logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")

View File

@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
chat_id: str, chat_id: str,
chat_info_stream_id: str, chat_info_stream_id: str,
chat_info_platform: str, chat_info_platform: str,
action_reasoning:str action_reasoning: str,
): ):
self.action_id = action_id self.action_id = action_id
self.time = time self.time = time
@ -235,4 +235,4 @@ class DatabaseActionRecords(BaseDataModel):
self.chat_id = chat_id self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform self.chat_info_platform = chat_info_platform
self.action_reasoning = action_reasoning self.action_reasoning = action_reasoning

View File

@ -317,10 +317,12 @@ class Expression(BaseModel):
class Meta: class Meta:
table_name = "expression" table_name = "expression"
class Jargon(BaseModel): class Jargon(BaseModel):
""" """
用于存储俚语的模型 用于存储俚语的模型
""" """
content = TextField() content = TextField()
raw_content = TextField(null=True) raw_content = TextField(null=True)
type = TextField(null=True) type = TextField(null=True)
@ -332,14 +334,16 @@ class Jargon(BaseModel):
is_jargon = BooleanField(null=True) # None表示未判定True表示是黑话False表示不是黑话 is_jargon = BooleanField(null=True) # None表示未判定True表示是黑话False表示不是黑话
last_inference_count = IntegerField(null=True) # 最后一次判定的count值用于避免重启后重复判定 last_inference_count = IntegerField(null=True) # 最后一次判定的count值用于避免重启后重复判定
is_complete = BooleanField(default=False) # 是否已完成所有推断count>=100后不再推断 is_complete = BooleanField(default=False) # 是否已完成所有推断count>=100后不再推断
class Meta: class Meta:
table_name = "jargon" table_name = "jargon"
class ChatHistory(BaseModel): class ChatHistory(BaseModel):
""" """
用于存储聊天历史概括的模型 用于存储聊天历史概括的模型
""" """
chat_id = TextField(index=True) # 聊天ID chat_id = TextField(index=True) # 聊天ID
start_time = DoubleField() # 起始时间 start_time = DoubleField() # 起始时间
end_time = DoubleField() # 结束时间 end_time = DoubleField() # 结束时间
@ -350,7 +354,7 @@ class ChatHistory(BaseModel):
summary = TextField() # 概括:对这段话的平文本概括 summary = TextField() # 概括:对这段话的平文本概括
count = IntegerField(default=0) # 被检索次数 count = IntegerField(default=0) # 被检索次数
forget_times = IntegerField(default=0) # 被遗忘检查的次数 forget_times = IntegerField(default=0) # 被遗忘检查的次数
class Meta: class Meta:
table_name = "chat_history" table_name = "chat_history"
@ -359,6 +363,7 @@ class ThinkingBack(BaseModel):
""" """
用于存储记忆检索思考过程的模型 用于存储记忆检索思考过程的模型
""" """
chat_id = TextField(index=True) # 聊天ID chat_id = TextField(index=True) # 聊天ID
question = TextField() # 提出的问题 question = TextField() # 提出的问题
context = TextField(null=True) # 上下文信息 context = TextField(null=True) # 上下文信息
@ -367,10 +372,11 @@ class ThinkingBack(BaseModel):
thinking_steps = TextField(null=True) # 思考步骤JSON格式 thinking_steps = TextField(null=True) # 思考步骤JSON格式
create_time = DoubleField() # 创建时间 create_time = DoubleField() # 创建时间
update_time = DoubleField() # 更新时间 update_time = DoubleField() # 更新时间
class Meta: class Meta:
table_name = "thinking_back" table_name = "thinking_back"
MODELS = [ MODELS = [
ChatStreams, ChatStreams,
LLMUsage, LLMUsage,
@ -387,6 +393,7 @@ MODELS = [
ThinkingBack, ThinkingBack,
] ]
def create_tables(): def create_tables():
""" """
创建所有在模型中定义的数据库表 创建所有在模型中定义的数据库表

View File

@ -27,7 +27,7 @@ class BotConfig(ConfigBase):
nickname: str nickname: str
"""昵称""" """昵称"""
platforms: list[str] = field(default_factory=lambda: []) platforms: list[str] = field(default_factory=lambda: [])
"""其他平台列表""" """其他平台列表"""
@ -311,16 +311,18 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set()) ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表""" """过滤正则表达式列表"""
@dataclass @dataclass
class MemoryConfig(ConfigBase): class MemoryConfig(ConfigBase):
"""记忆配置类""" """记忆配置类"""
max_memory_number: int = 100 max_memory_number: int = 100
"""记忆最大数量""" """记忆最大数量"""
memory_build_frequency: int = 1 memory_build_frequency: int = 1
"""记忆构建频率""" """记忆构建频率"""
@dataclass @dataclass
class ExpressionConfig(ConfigBase): class ExpressionConfig(ConfigBase):
"""表达配置类""" """表达配置类"""
@ -494,13 +496,14 @@ class MoodConfig(ConfigBase):
enable_mood: bool = True enable_mood: bool = True
"""是否启用情绪系统""" """是否启用情绪系统"""
mood_update_threshold: float = 1 mood_update_threshold: float = 1
"""情绪更新阈值,越高,更新越慢""" """情绪更新阈值,越高,更新越慢"""
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大" emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况""" """情感特征,影响情绪的变化情况"""
@dataclass @dataclass
class VoiceConfig(ConfigBase): class VoiceConfig(ConfigBase):
"""语音识别配置类""" """语音识别配置类"""
@ -644,16 +647,16 @@ class DebugConfig(ConfigBase):
show_prompt: bool = False show_prompt: bool = False
"""是否显示prompt""" """是否显示prompt"""
show_replyer_prompt: bool = True show_replyer_prompt: bool = True
"""是否显示回复器prompt""" """是否显示回复器prompt"""
show_replyer_reasoning: bool = True show_replyer_reasoning: bool = True
"""是否显示回复器推理""" """是否显示回复器推理"""
show_jargon_prompt: bool = False show_jargon_prompt: bool = False
"""是否显示jargon相关提示词""" """是否显示jargon相关提示词"""
show_planner_prompt: bool = False show_planner_prompt: bool = False
"""是否显示planner相关提示词""" """是否显示planner相关提示词"""

View File

@ -3,31 +3,30 @@ import difflib
import random import random
from datetime import datetime from datetime import datetime
from typing import Optional, List, Dict from typing import Optional, List, Dict
from collections import defaultdict
def filter_message_content(content: Optional[str]) -> str: def filter_message_content(content: Optional[str]) -> str:
""" """
过滤消息内容移除回复@图片等格式 过滤消息内容移除回复@图片等格式
Args: Args:
content: 原始消息内容 content: 原始消息内容
Returns: Returns:
str: 过滤后的内容 str: 过滤后的内容
""" """
if not content: if not content:
return "" return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r'\[回复.*?\],说:\s*', '', content) content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容 # 移除@<...>格式的内容
content = re.sub(r'@<[^>]*>', '', content) content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID # 移除[picid:...]格式的图片ID
content = re.sub(r'\[picid:[^\]]*\]', '', content) content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容 # 移除[表情包:...]格式的内容
content = re.sub(r'\[表情包:[^\]]*\]', '', content) content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip() return content.strip()
@ -35,11 +34,11 @@ def calculate_similarity(text1: str, text2: str) -> float:
""" """
计算两个文本的相似度返回0-1之间的值 计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度 使用SequenceMatcher计算相似度
Args: Args:
text1: 第一个文本 text1: 第一个文本
text2: 第二个文本 text2: 第二个文本
Returns: Returns:
float: 相似度值范围0-1 float: 相似度值范围0-1
""" """
@ -49,10 +48,10 @@ def calculate_similarity(text1: str, text2: str) -> float:
def format_create_date(timestamp: float) -> str: def format_create_date(timestamp: float) -> str:
""" """
将时间戳格式化为可读的日期字符串 将时间戳格式化为可读的日期字符串
Args: Args:
timestamp: 时间戳 timestamp: 时间戳
Returns: Returns:
str: 格式化后的日期字符串 str: 格式化后的日期字符串
""" """
@ -65,11 +64,11 @@ def format_create_date(timestamp: float) -> str:
def weighted_sample(population: List[Dict], k: int) -> List[Dict]: def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
""" """
随机抽样函数 随机抽样函数
Args: Args:
population: 总体数据列表 population: 总体数据列表
k: 需要抽取的数量 k: 需要抽取的数量
Returns: Returns:
List[Dict]: 抽取的数据列表 List[Dict]: 抽取的数据列表
""" """

View File

@ -1,7 +1,6 @@
import time import time
import json import json
import os import os
from datetime import datetime
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import traceback import traceback
from src.common.logger import get_logger from src.common.logger import get_logger
@ -158,8 +157,6 @@ class ExpressionLearner:
traceback.print_exc() traceback.print_exc()
return return
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]: async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
""" """
学习并存储表达方式 学习并存储表达方式
@ -169,7 +166,7 @@ class ExpressionLearner:
if learnt_expressions is None: if learnt_expressions is None:
logger.info("没有学习到表达风格") logger.info("没有学习到表达风格")
return [] return []
# 展示学到的表达方式 # 展示学到的表达方式
learnt_expressions_str = "" learnt_expressions_str = ""
for ( for (
@ -186,7 +183,7 @@ class ExpressionLearner:
# 存储到数据库 Expression 表并训练 style_learner # 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式 has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例 learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for ( for (
situation, situation,
style, style,
@ -195,9 +192,7 @@ class ExpressionLearner:
) in learnt_expressions: ) in learnt_expressions:
# 查找是否已存在相似表达方式 # 查找是否已存在相似表达方式
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == self.chat_id) (Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
& (Expression.situation == situation)
& (Expression.style == style)
) )
if query.exists(): if query.exists():
# 表达方式完全相同,只更新时间戳 # 表达方式完全相同,只更新时间戳
@ -216,39 +211,37 @@ class ExpressionLearner:
up_content=up_content, up_content=up_content,
) )
has_new_expressions = True has_new_expressions = True
# 训练 style_learnerup_content 和 style 必定存在) # 训练 style_learnerup_content 和 style 必定存在)
try: try:
learner.add_style(style, situation) learner.add_style(style, situation)
# 学习映射关系 # 学习映射关系
success = style_learner_manager.learn_mapping( success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
self.chat_id,
up_content,
style
)
if success: if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else "")) logger.debug(
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
+ (f" (situation: {situation})" if situation else "")
)
else: else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}") logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e: except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}") logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型 # 保存当前聊天室的 style_learner 模型
if has_new_expressions: if has_new_expressions:
try: try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...") logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path) save_success = learner.save(style_learner_manager.model_save_path)
if save_success: if save_success:
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}") logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else: else:
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}") logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e: except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}") logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions return learnt_expressions
async def match_expression_context( async def match_expression_context(
@ -334,7 +327,7 @@ class ExpressionLearner:
matched_expressions = [] matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引 used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}") logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"match_responses 内容: {match_responses}") logger.debug(f"match_responses 内容: {match_responses}")
@ -344,12 +337,12 @@ class ExpressionLearner:
if not isinstance(match_response, dict): if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}") logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue continue
# 获取表达方式序号 # 获取表达方式序号
if "expression_pair" not in match_response: if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}") logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引 pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过 # 检查索引是否有效且未被使用过
@ -367,9 +360,7 @@ class ExpressionLearner:
return matched_expressions return matched_expressions
async def learn_expression( async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式 """从指定聊天流学习表达方式
Args: Args:
@ -409,7 +400,6 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response) expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}") # logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源 # 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str expressions, random_msg_match_str
@ -426,17 +416,17 @@ class ExpressionLearner:
if similarity >= 0.85: # 85%相似度阈值 if similarity >= 0.85: # 85%相似度阈值
pos = i pos = i
break break
if pos is None or pos == 0: if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达 # 没有匹配到目标句或没有上一句,跳过该表达
continue continue
# 检查目标句是否为空 # 检查目标句是否为空
target_content = bare_lines[pos][1] target_content = bare_lines[pos][1]
if not target_content: if not target_content:
# 目标句为空,跳过该表达 # 目标句为空,跳过该表达
continue continue
prev_original_idx = bare_lines[pos - 1][0] prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "") up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content: if not up_content:
@ -449,7 +439,6 @@ class ExpressionLearner:
return filtered_with_up return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
""" """
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组 解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组
@ -483,21 +472,21 @@ class ExpressionLearner:
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]: def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
""" """
为每条消息构建精简文本列表保留到原消息索引的映射 为每条消息构建精简文本列表保留到原消息索引的映射
Args: Args:
messages: 消息列表 messages: 消息列表
Returns: Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表 List[Tuple[int, str]]: (original_index, bare_content) 元组列表
""" """
bare_lines: List[Tuple[int, str]] = [] bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
content = msg.processed_plain_text or "" content = msg.processed_plain_text or ""
content = filter_message_content(content) content = filter_message_content(content)
# 即使content为空也要记录防止错位 # 即使content为空也要记录防止错位
bare_lines.append((idx, content)) bare_lines.append((idx, content))
return bare_lines return bare_lines

View File

@ -1,8 +1,6 @@
import json import json
import time import time
import random
import hashlib import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json from json_repair import repair_json
@ -115,30 +113,31 @@ class ExpressionSelector:
return group_chat_ids return group_chat_ids
return [chat_id] return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]: def get_model_predicted_expressions(
self, chat_id: str, target_message: str, total_num: int = 10
) -> List[Dict[str, Any]]:
""" """
使用 style_learner 模型预测最合适的表达方式 使用 style_learner 模型预测最合适的表达方式
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
target_message: 目标消息内容 target_message: 目标消息内容
total_num: 需要预测的数量 total_num: 需要预测的数量
Returns: Returns:
List[Dict[str, Any]]: 预测的表达方式列表 List[Dict[str, Any]]: 预测的表达方式列表
""" """
try: try:
# 过滤目标消息内容,移除回复、表情包等特殊格式 # 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message) filtered_target_message = filter_message_content(target_message)
logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}") logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
# 支持多chat_id合并预测 # 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
predicted_expressions = [] predicted_expressions = []
# 为每个相关的chat_id进行预测 # 为每个相关的chat_id进行预测
for related_chat_id in related_chat_ids: for related_chat_id in related_chat_ids:
try: try:
@ -146,59 +145,65 @@ class ExpressionSelector:
best_style, scores = style_learner_manager.predict_style( best_style, scores = style_learner_manager.predict_style(
related_chat_id, filtered_target_message, top_k=total_num related_chat_id, filtered_target_message, top_k=total_num
) )
if best_style and scores: if best_style and scores:
# 获取预测风格的完整信息 # 获取预测风格的完整信息
learner = style_learner_manager.get_learner(related_chat_id) learner = style_learner_manager.get_learner(related_chat_id)
style_id, situation = learner.get_style_info(best_style) style_id, situation = learner.get_style_info(best_style)
if style_id and situation: if style_id and situation:
# 从数据库查找对应的表达记录 # 从数据库查找对应的表达记录
expr_query = Expression.select().where( expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) & (Expression.chat_id == related_chat_id)
(Expression.situation == situation) & & (Expression.situation == situation)
(Expression.style == best_style) & (Expression.style == best_style)
) )
if expr_query.exists(): if expr_query.exists():
expr = expr_query.get() expr = expr_query.get()
predicted_expressions.append({ predicted_expressions.append(
"id": expr.id, {
"situation": expr.situation, "id": expr.id,
"style": expr.style, "situation": expr.situation,
"last_active_time": expr.last_active_time, "style": expr.style,
"source_id": expr.chat_id, "last_active_time": expr.last_active_time,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "source_id": expr.chat_id,
"prediction_score": scores.get(best_style, 0.0), "create_date": expr.create_date
"prediction_input": filtered_target_message if expr.create_date is not None
}) else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message,
}
)
else: else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式") logger.warning(
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
)
except Exception as e: except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
continue continue
# 按预测分数排序,取前 total_num 个 # 按预测分数排序,取前 total_num 个
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True) predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
selected_expressions = predicted_expressions[:total_num] selected_expressions = predicted_expressions[:total_num]
logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式") logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
return selected_expressions return selected_expressions
except Exception as e: except Exception as e:
logger.error(f"模型预测表达方式失败: {e}") logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择 # 如果预测失败,回退到随机选择
return self._random_expressions(chat_id, total_num) return self._random_expressions(chat_id, total_num)
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
""" """
随机选择表达方式 随机选择表达方式
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
total_num: 需要选择的数量 total_num: 需要选择的数量
Returns: Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表 List[Dict[str, Any]]: 随机选择的表达方式列表
""" """
@ -207,9 +212,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式 # 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
(Expression.chat_id.in_(related_chat_ids))
)
style_exprs = [ style_exprs = [
{ {
@ -228,15 +231,14 @@ class ExpressionSelector:
selected_style = weighted_sample(style_exprs, total_num) selected_style = weighted_sample(style_exprs, total_num)
else: else:
selected_style = [] selected_style = []
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式") logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style return selected_style
except Exception as e: except Exception as e:
logger.error(f"随机选择表达方式失败: {e}") logger.error(f"随机选择表达方式失败: {e}")
return [] return []
async def select_suitable_expressions( async def select_suitable_expressions(
self, self,
chat_id: str, chat_id: str,
@ -246,13 +248,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]: ) -> Tuple[List[Dict[str, Any]], List[int]]:
""" """
根据配置模式选择适合的表达方式 根据配置模式选择适合的表达方式
Args: Args:
chat_id: 聊天流ID chat_id: 聊天流ID
chat_info: 聊天内容信息 chat_info: 聊天内容信息
max_num: 最大选择数量 max_num: 最大选择数量
target_message: 目标消息内容 target_message: 目标消息内容
Returns: Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
""" """
@ -263,7 +265,7 @@ class ExpressionSelector:
# 获取配置模式 # 获取配置模式
expression_mode = global_config.expression.mode expression_mode = global_config.expression.mode
if expression_mode == "exp_model": if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM # exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式") logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
@ -284,12 +286,12 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]: ) -> Tuple[List[Dict[str, Any]], List[int]]:
""" """
exp_model模式直接使用模型预测不经过LLM exp_model模式直接使用模型预测不经过LLM
Args: Args:
chat_id: 聊天流ID chat_id: 聊天流ID
target_message: 目标消息内容 target_message: 目标消息内容
max_num: 最大选择数量 max_num: 最大选择数量
Returns: Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
""" """
@ -297,14 +299,14 @@ class ExpressionSelector:
# 使用模型预测最合适的表达方式 # 使用模型预测最合适的表达方式
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num) selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions] selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time # 更新last_active_time
if selected_expressions: if selected_expressions:
self.update_expressions_last_active_time(selected_expressions) self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式") logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids return selected_expressions, selected_ids
except Exception as e: except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}") logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], [] return [], []
@ -318,13 +320,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]: ) -> Tuple[List[Dict[str, Any]], List[int]]:
""" """
classic模式随机选择+LLM选择 classic模式随机选择+LLM选择
Args: Args:
chat_id: 聊天流ID chat_id: 聊天流ID
chat_info: 聊天内容信息 chat_info: 聊天内容信息
max_num: 最大选择数量 max_num: 最大选择数量
target_message: 目标消息内容 target_message: 目标消息内容
Returns: Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
""" """
@ -425,17 +427,13 @@ class ExpressionSelector:
updates_by_key[key] = expr updates_by_key[key] = expr
for chat_id, situation, style in updates_by_key: for chat_id, situation, style in updates_by_key:
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) (Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
& (Expression.situation == situation)
& (Expression.style == style)
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
expr_obj.last_active_time = time.time() expr_obj.last_active_time = time.time()
expr_obj.save() expr_obj.save()
logger.debug( logger.debug("表达方式激活: 更新last_active_time in db")
"表达方式激活: 更新last_active_time in db"
)
init_prompt() init_prompt()

View File

@ -6,18 +6,21 @@ import os
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes from .online_nb import OnlineNaiveBayes
class ExpressorModel: class ExpressorModel:
""" """
直接使用朴素贝叶斯精排可在线学习 直接使用朴素贝叶斯精排可在线学习
支持存储situation字段不参与计算仅与style对应 支持存储situation字段不参与计算仅与style对应
""" """
def __init__(self, def __init__(
alpha: float = 0.5, self,
beta: float = 0.5, alpha: float = 0.5,
gamma: float = 1.0, beta: float = 0.5,
vocab_size: int = 200000, gamma: float = 1.0,
use_jieba: bool = True): vocab_size: int = 200000,
use_jieba: bool = True,
):
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
self._candidates: Dict[str, str] = {} # cid -> text (style) self._candidates: Dict[str, str] = {} # cid -> text (style)
@ -28,7 +31,7 @@ class ExpressorModel:
self._candidates[cid] = text self._candidates[cid] = text
if situation is not None: if situation is not None:
self._situations[cid] = situation self._situations[cid] = situation
# 确保在nb模型中初始化该候选的计数 # 确保在nb模型中初始化该候选的计数
if cid not in self.nb.cls_counts: if cid not in self.nb.cls_counts:
self.nb.cls_counts[cid] = 0.0 self.nb.cls_counts[cid] = 0.0
@ -46,7 +49,7 @@ class ExpressorModel:
toks = self.tokenizer.tokenize(text) toks = self.tokenizer.tokenize(text)
if not toks: if not toks:
return None, {} return None, {}
if not self._candidates: if not self._candidates:
return None, {} return None, {}
@ -58,7 +61,7 @@ class ExpressorModel:
# 取最高分 # 取最高分
if not scores: if not scores:
return None, {} return None, {}
# 根据k参数限制返回的候选数量 # 根据k参数限制返回的候选数量
if k is not None and k > 0: if k is not None and k > 0:
# 按分数降序排序取前k个 # 按分数降序排序取前k个
@ -81,40 +84,42 @@ class ExpressorModel:
def decay(self, factor: float): def decay(self, factor: float):
self.nb.decay(factor=factor) self.nb.decay(factor=factor)
def get_situation(self, cid: str) -> Optional[str]: def get_situation(self, cid: str) -> Optional[str]:
"""获取候选对应的situation""" """获取候选对应的situation"""
return self._situations.get(cid) return self._situations.get(cid)
def get_style(self, cid: str) -> Optional[str]: def get_style(self, cid: str) -> Optional[str]:
"""获取候选对应的style""" """获取候选对应的style"""
return self._candidates.get(cid) return self._candidates.get(cid)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
"""获取候选的style和situation信息""" """获取候选的style和situation信息"""
return self._candidates.get(cid), self._situations.get(cid) return self._candidates.get(cid), self._situations.get(cid)
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]: def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""获取所有候选的style和situation信息""" """获取所有候选的style和situation信息"""
return {cid: (style, self._situations.get(cid)) return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()}
for cid, style in self._candidates.items()}
def save(self, path: str): def save(self, path: str):
"""保存模型""" """保存模型"""
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump({ pickle.dump(
"candidates": self._candidates, {
"situations": self._situations, "candidates": self._candidates,
"nb": { "situations": self._situations,
"cls_counts": dict(self.nb.cls_counts), "nb": {
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, "cls_counts": dict(self.nb.cls_counts),
"alpha": self.nb.alpha, "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"beta": self.nb.beta, "alpha": self.nb.alpha,
"gamma": self.nb.gamma, "beta": self.nb.beta,
"V": self.nb.V, "gamma": self.nb.gamma,
} "V": self.nb.V,
}, f) },
},
f,
)
def load(self, path: str): def load(self, path: str):
"""加载模型""" """加载模型"""
@ -133,9 +138,11 @@ class ExpressorModel:
self.nb.V = obj["nb"]["V"] self.nb.V = obj["nb"]["V"]
self.nb._logZ.clear() self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]): def defaultdict_dict(d: Dict[str, Dict[str, float]]):
from collections import defaultdict from collections import defaultdict
outer = defaultdict(lambda: defaultdict(float)) outer = defaultdict(lambda: defaultdict(float))
for k, inner in d.items(): for k, inner in d.items():
outer[k].update(inner) outer[k].update(inner)
return outer return outer

View File

@ -2,6 +2,7 @@ import math
from typing import Dict, List from typing import Dict, List
from collections import defaultdict, Counter from collections import defaultdict, Counter
class OnlineNaiveBayes: class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha self.alpha = alpha
@ -9,9 +10,9 @@ class OnlineNaiveBayes:
self.gamma = gamma self.gamma = gamma
self.V = vocab_size self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str): def _invalidate(self, cid: str):
if cid in self._logZ: if cid in self._logZ:
@ -57,4 +58,4 @@ class OnlineNaiveBayes:
self.cls_counts[cid] *= g self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()): for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g self.token_counts[cid][term] *= g
self._invalidate(cid) self._invalidate(cid)

View File

@ -3,17 +3,20 @@ from typing import List, Optional, Set
try: try:
import jieba import jieba
_HAS_JIEBA = True _HAS_JIEBA = True
except Exception: except Exception:
_HAS_JIEBA = False _HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+") _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# 匹配纯符号的正则表达式 # 匹配纯符号的正则表达式
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$') _SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$")
def simple_en_tokenize(text: str) -> List[str]: def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower()) return _WORD_RE.findall(text.lower())
class Tokenizer: class Tokenizer:
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True): def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
self.stopwords = stopwords or set() self.stopwords = stopwords or set()
@ -28,4 +31,4 @@ class Tokenizer:
else: else:
toks = simple_en_tokenize(text) toks = simple_en_tokenize(text)
# 过滤掉纯符号和停用词 # 过滤掉纯符号和停用词
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)] return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]

View File

@ -22,42 +22,42 @@ class StyleLearner:
学习从up_content到style的映射关系 学习从up_content到style的映射关系
支持动态管理风格集合无数量上限 支持动态管理风格集合无数量上限
""" """
def __init__(self, chat_id: str, model_config: Optional[Dict] = None): def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
self.chat_id = chat_id self.chat_id = chat_id
self.model_config = model_config or { self.model_config = model_config or {
"alpha": 0.5, "alpha": 0.5,
"beta": 0.5, "beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘 "gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000, "vocab_size": 200000,
"use_jieba": True "use_jieba": True,
} }
# 初始化表达模型 # 初始化表达模型
self.expressor = ExpressorModel(**self.model_config) self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理 # 动态风格管理
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本 self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0 # 下一个可用的style_id self.next_style_id = 0 # 下一个可用的style_id
# 学习统计 # 学习统计
self.learning_stats = { self.learning_stats = {
"total_samples": 0, "total_samples": 0,
"style_counts": defaultdict(int), "style_counts": defaultdict(int),
"last_update": None, "last_update": None,
"style_usage_frequency": defaultdict(int) # 风格使用频率 "style_usage_frequency": defaultdict(int), # 风格使用频率
} }
def add_style(self, style: str, situation: str = None) -> bool: def add_style(self, style: str, situation: str = None) -> bool:
""" """
动态添加一个新的风格 动态添加一个新的风格
Args: Args:
style: 风格文本 style: 风格文本
situation: 对应的situation文本可选 situation: 对应的situation文本可选
Returns: Returns:
bool: 添加是否成功 bool: 添加是否成功
""" """
@ -66,35 +66,37 @@ class StyleLearner:
if style in self.style_to_id: if style in self.style_to_id:
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在") logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
return True return True
# 生成新的style_id # 生成新的style_id
style_id = f"style_{self.next_style_id}" style_id = f"style_{self.next_style_id}"
self.next_style_id += 1 self.next_style_id += 1
# 添加到映射 # 添加到映射
self.style_to_id[style] = style_id self.style_to_id[style] = style_id
self.id_to_style[style_id] = style self.id_to_style[style_id] = style
if situation: if situation:
self.id_to_situation[style_id] = situation self.id_to_situation[style_id] = situation
# 添加到expressor模型 # 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation) self.expressor.add_candidate(style_id, style, situation)
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" + logger.info(
(f", situation: '{situation}'" if situation else "")) f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
+ (f", situation: '{situation}'" if situation else "")
)
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 添加风格失败: {e}") logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
return False return False
def remove_style(self, style: str) -> bool: def remove_style(self, style: str) -> bool:
""" """
删除一个风格 删除一个风格
Args: Args:
style: 要删除的风格文本 style: 要删除的风格文本
Returns: Returns:
bool: 删除是否成功 bool: 删除是否成功
""" """
@ -102,33 +104,33 @@ class StyleLearner:
if style not in self.style_to_id: if style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在") logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
return False return False
style_id = self.style_to_id[style] style_id = self.style_to_id[style]
# 从映射中删除 # 从映射中删除
del self.style_to_id[style] del self.style_to_id[style]
del self.id_to_style[style_id] del self.id_to_style[style_id]
if style_id in self.id_to_situation: if style_id in self.id_to_situation:
del self.id_to_situation[style_id] del self.id_to_situation[style_id]
# 从expressor模型中删除通过重新构建 # 从expressor模型中删除通过重新构建
self._rebuild_expressor() self._rebuild_expressor()
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})") logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 删除风格失败: {e}") logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
return False return False
def update_style(self, old_style: str, new_style: str) -> bool: def update_style(self, old_style: str, new_style: str) -> bool:
""" """
更新一个风格 更新一个风格
Args: Args:
old_style: 原风格文本 old_style: 原风格文本
new_style: 新风格文本 new_style: 新风格文本
Returns: Returns:
bool: 更新是否成功 bool: 更新是否成功
""" """
@ -136,37 +138,37 @@ class StyleLearner:
if old_style not in self.style_to_id: if old_style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在") logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
return False return False
if new_style in self.style_to_id and new_style != old_style: if new_style in self.style_to_id and new_style != old_style:
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在") logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
return False return False
style_id = self.style_to_id[old_style] style_id = self.style_to_id[old_style]
# 更新映射 # 更新映射
del self.style_to_id[old_style] del self.style_to_id[old_style]
self.style_to_id[new_style] = style_id self.style_to_id[new_style] = style_id
self.id_to_style[style_id] = new_style self.id_to_style[style_id] = new_style
# 更新expressor模型保留原有的situation # 更新expressor模型保留原有的situation
situation = self.id_to_situation.get(style_id) situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, new_style, situation) self.expressor.add_candidate(style_id, new_style, situation)
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'") logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 更新风格失败: {e}") logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
return False return False
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int: def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
""" """
批量添加风格 批量添加风格
Args: Args:
styles: 风格文本列表 styles: 风格文本列表
situations: 对应的situation文本列表可选 situations: 对应的situation文本列表可选
Returns: Returns:
int: 成功添加的数量 int: 成功添加的数量
""" """
@ -175,55 +177,55 @@ class StyleLearner:
situation = situations[i] if situations and i < len(situations) else None situation = situations[i] if situations and i < len(situations) else None
if self.add_style(style, situation): if self.add_style(style, situation):
success_count += 1 success_count += 1
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功") logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
return success_count return success_count
def get_all_styles(self) -> List[str]: def get_all_styles(self) -> List[str]:
"""获取所有已注册的风格""" """获取所有已注册的风格"""
return list(self.style_to_id.keys()) return list(self.style_to_id.keys())
def get_style_count(self) -> int: def get_style_count(self) -> int:
"""获取当前风格数量""" """获取当前风格数量"""
return len(self.style_to_id) return len(self.style_to_id)
def get_situation(self, style: str) -> Optional[str]: def get_situation(self, style: str) -> Optional[str]:
""" """
获取风格对应的situation 获取风格对应的situation
Args: Args:
style: 风格文本 style: 风格文本
Returns: Returns:
Optional[str]: 对应的situation如果不存在则返回None Optional[str]: 对应的situation如果不存在则返回None
""" """
if style not in self.style_to_id: if style not in self.style_to_id:
return None return None
style_id = self.style_to_id[style] style_id = self.style_to_id[style]
return self.id_to_situation.get(style_id) return self.id_to_situation.get(style_id)
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
""" """
获取风格的完整信息 获取风格的完整信息
Args: Args:
style: 风格文本 style: 风格文本
Returns: Returns:
Tuple[Optional[str], Optional[str]]: (style_id, situation) Tuple[Optional[str], Optional[str]]: (style_id, situation)
""" """
if style not in self.style_to_id: if style not in self.style_to_id:
return None, None return None, None
style_id = self.style_to_id[style] style_id = self.style_to_id[style]
situation = self.id_to_situation.get(style_id) situation = self.id_to_situation.get(style_id)
return style_id, situation return style_id, situation
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]: def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
""" """
获取所有风格的完整信息 获取所有风格的完整信息
Returns: Returns:
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)} Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
""" """
@ -232,32 +234,32 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id) situation = self.id_to_situation.get(style_id)
result[style] = (style_id, situation) result[style] = (style_id, situation)
return result return result
def _rebuild_expressor(self): def _rebuild_expressor(self):
"""重新构建expressor模型删除风格后使用""" """重新构建expressor模型删除风格后使用"""
try: try:
# 重新创建expressor # 重新创建expressor
self.expressor = ExpressorModel(**self.model_config) self.expressor = ExpressorModel(**self.model_config)
# 重新添加所有风格和situation # 重新添加所有风格和situation
for style_id, style_text in self.id_to_style.items(): for style_id, style_text in self.id_to_style.items():
situation = self.id_to_situation.get(style_id) situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, style_text, situation) self.expressor.add_candidate(style_id, style_text, situation)
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型") logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}") logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool: def learn_mapping(self, up_content: str, style: str) -> bool:
""" """
学习一个up_content到style的映射 学习一个up_content到style的映射
如果style不存在会自动添加 如果style不存在会自动添加
Args: Args:
up_content: 输入内容 up_content: 输入内容
style: 对应的style文本 style: 对应的style文本
Returns: Returns:
bool: 学习是否成功 bool: 学习是否成功
""" """
@ -267,71 +269,71 @@ class StyleLearner:
if not self.add_style(style): if not self.add_style(style):
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败") logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
return False return False
# 获取style_id # 获取style_id
style_id = self.style_to_id[style] style_id = self.style_to_id[style]
# 使用正反馈学习 # 使用正反馈学习
self.expressor.update_positive(up_content, style_id) self.expressor.update_positive(up_content, style_id)
# 更新统计 # 更新统计
self.learning_stats["total_samples"] += 1 self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1 self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["style_usage_frequency"][style] += 1 self.learning_stats["style_usage_frequency"][style] += 1
self.learning_stats["last_update"] = asyncio.get_event_loop().time() self.learning_stats["last_update"] = asyncio.get_event_loop().time()
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'") logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 学习映射失败: {e}") logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
traceback.print_exc() traceback.print_exc()
return False return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
""" """
根据up_content预测最合适的style 根据up_content预测最合适的style
Args: Args:
up_content: 输入内容 up_content: 输入内容
top_k: 返回前k个候选 top_k: 返回前k个候选
Returns: Returns:
Tuple[最佳style文本, 所有候选的分数] Tuple[最佳style文本, 所有候选的分数]
""" """
try: try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k) best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None: if best_style_id is None:
return None, {} return None, {}
# 将style_id转换为style文本 # 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id) best_style = self.id_to_style.get(best_style_id)
# 转换所有分数 # 转换所有分数
style_scores = {} style_scores = {}
for sid, score in scores.items(): for sid, score in scores.items():
style_text = self.id_to_style.get(sid) style_text = self.id_to_style.get(sid)
if style_text: if style_text:
style_scores[style_text] = score style_scores[style_text] = score
return best_style, style_scores return best_style, style_scores
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 预测style失败: {e}") logger.error(f"[{self.chat_id}] 预测style失败: {e}")
traceback.print_exc() traceback.print_exc()
return None, {} return None, {}
def decay_learning(self, factor: Optional[float] = None) -> None: def decay_learning(self, factor: Optional[float] = None) -> None:
""" """
对学习到的知识进行衰减遗忘 对学习到的知识进行衰减遗忘
Args: Args:
factor: 衰减因子None则使用配置中的gamma factor: 衰减因子None则使用配置中的gamma
""" """
self.expressor.decay(factor) self.expressor.decay(factor)
logger.debug(f"[{self.chat_id}] 执行知识衰减") logger.debug(f"[{self.chat_id}] 执行知识衰减")
def get_stats(self) -> Dict: def get_stats(self) -> Dict:
"""获取学习统计信息""" """获取学习统计信息"""
return { return {
@ -341,20 +343,20 @@ class StyleLearner:
"style_counts": dict(self.learning_stats["style_counts"]), "style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]), "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"], "last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys()) "all_styles": list(self.style_to_id.keys()),
} }
def save(self, base_path: str) -> bool: def save(self, base_path: str) -> bool:
""" """
保存模型到文件 保存模型到文件
Args: Args:
base_path: 基础路径实际文件为 {base_path}/{chat_id}_style_model.pkl base_path: 基础路径实际文件为 {base_path}/{chat_id}_style_model.pkl
""" """
try: try:
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
# 保存模型和统计信息 # 保存模型和统计信息
save_data = { save_data = {
"model_config": self.model_config, "model_config": self.model_config,
@ -362,43 +364,43 @@ class StyleLearner:
"id_to_style": self.id_to_style, "id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation, "id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id, "next_style_id": self.next_style_id,
"learning_stats": self.learning_stats "learning_stats": self.learning_stats,
} }
# 先保存expressor模型 # 先保存expressor模型
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
self.expressor.save(expressor_path) self.expressor.save(expressor_path)
# 保存其他数据 # 保存其他数据
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
pickle.dump(save_data, f) pickle.dump(save_data, f)
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}") logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 保存模型失败: {e}") logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
return False return False
def load(self, base_path: str) -> bool: def load(self, base_path: str) -> bool:
""" """
从文件加载模型 从文件加载模型
Args: Args:
base_path: 基础路径 base_path: 基础路径
""" """
try: try:
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
if not os.path.exists(file_path) or not os.path.exists(expressor_path): if not os.path.exists(file_path) or not os.path.exists(expressor_path):
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置") logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
return False return False
# 加载其他数据 # 加载其他数据
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
save_data = pickle.load(f) save_data = pickle.load(f)
# 恢复配置和状态 # 恢复配置和状态
self.model_config = save_data["model_config"] self.model_config = save_data["model_config"]
self.style_to_id = save_data["style_to_id"] self.style_to_id = save_data["style_to_id"]
@ -406,14 +408,14 @@ class StyleLearner:
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本 self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
self.next_style_id = save_data["next_style_id"] self.next_style_id = save_data["next_style_id"]
self.learning_stats = save_data["learning_stats"] self.learning_stats = save_data["learning_stats"]
# 重新创建expressor并加载 # 重新创建expressor并加载
self.expressor = ExpressorModel(**self.model_config) self.expressor = ExpressorModel(**self.model_config)
self.expressor.load(expressor_path) self.expressor.load(expressor_path)
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载") logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_id}] 加载模型失败: {e}") logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
return False return False
@ -425,156 +427,156 @@ class StyleLearnerManager:
为每个chat_id维护独立的StyleLearner实例 为每个chat_id维护独立的StyleLearner实例
每个chat_id可以动态管理自己的风格集合无数量上限 每个chat_id可以动态管理自己的风格集合无数量上限
""" """
def __init__(self, model_save_path: str = "data/style_models"): def __init__(self, model_save_path: str = "data/style_models"):
self.model_save_path = model_save_path self.model_save_path = model_save_path
self.learners: Dict[str, StyleLearner] = {} self.learners: Dict[str, StyleLearner] = {}
# 自动保存配置 # 自动保存配置
self.auto_save_interval = 300 # 5分钟 self.auto_save_interval = 300 # 5分钟
self._auto_save_task: Optional[asyncio.Task] = None self._auto_save_task: Optional[asyncio.Task] = None
logger.info("StyleLearnerManager 已初始化") logger.info("StyleLearnerManager 已初始化")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
""" """
获取或创建指定chat_id的学习器 获取或创建指定chat_id的学习器
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
model_config: 模型配置None则使用默认配置 model_config: 模型配置None则使用默认配置
Returns: Returns:
StyleLearner实例 StyleLearner实例
""" """
if chat_id not in self.learners: if chat_id not in self.learners:
# 创建新的学习器 # 创建新的学习器
learner = StyleLearner(chat_id, model_config) learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型 # 尝试加载已保存的模型
learner.load(self.model_save_path) learner.load(self.model_save_path)
self.learners[chat_id] = learner self.learners[chat_id] = learner
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner") logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
return self.learners[chat_id] return self.learners[chat_id]
def add_style(self, chat_id: str, style: str) -> bool: def add_style(self, chat_id: str, style: str) -> bool:
""" """
为指定chat_id添加风格 为指定chat_id添加风格
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
style: 风格文本 style: 风格文本
Returns: Returns:
bool: 添加是否成功 bool: 添加是否成功
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.add_style(style) return learner.add_style(style)
def remove_style(self, chat_id: str, style: str) -> bool: def remove_style(self, chat_id: str, style: str) -> bool:
""" """
为指定chat_id删除风格 为指定chat_id删除风格
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
style: 风格文本 style: 风格文本
Returns: Returns:
bool: 删除是否成功 bool: 删除是否成功
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.remove_style(style) return learner.remove_style(style)
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool: def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
""" """
为指定chat_id更新风格 为指定chat_id更新风格
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
old_style: 原风格文本 old_style: 原风格文本
new_style: 新风格文本 new_style: 新风格文本
Returns: Returns:
bool: 更新是否成功 bool: 更新是否成功
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.update_style(old_style, new_style) return learner.update_style(old_style, new_style)
def get_chat_styles(self, chat_id: str) -> List[str]: def get_chat_styles(self, chat_id: str) -> List[str]:
""" """
获取指定chat_id的所有风格 获取指定chat_id的所有风格
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
Returns: Returns:
List[str]: 风格列表 List[str]: 风格列表
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.get_all_styles() return learner.get_all_styles()
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
""" """
学习一个映射关系 学习一个映射关系
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
up_content: 输入内容 up_content: 输入内容
style: 对应的style style: 对应的style
Returns: Returns:
bool: 学习是否成功 bool: 学习是否成功
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style) return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
""" """
预测最合适的style 预测最合适的style
Args: Args:
chat_id: 聊天室ID chat_id: 聊天室ID
up_content: 输入内容 up_content: 输入内容
top_k: 返回前k个候选 top_k: 返回前k个候选
Returns: Returns:
Tuple[最佳style, 所有候选分数] Tuple[最佳style, 所有候选分数]
""" """
learner = self.get_learner(chat_id) learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k) return learner.predict_style(up_content, top_k)
def decay_all_learners(self, factor: Optional[float] = None) -> None: def decay_all_learners(self, factor: Optional[float] = None) -> None:
""" """
对所有学习器执行衰减 对所有学习器执行衰减
Args: Args:
factor: 衰减因子 factor: 衰减因子
""" """
for learner in self.learners.values(): for learner in self.learners.values():
learner.decay_learning(factor) learner.decay_learning(factor)
logger.info("已对所有学习器执行衰减") logger.info("已对所有学习器执行衰减")
def get_all_stats(self) -> Dict[str, Dict]: def get_all_stats(self) -> Dict[str, Dict]:
"""获取所有学习器的统计信息""" """获取所有学习器的统计信息"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
def save_all_models(self) -> bool: def save_all_models(self) -> bool:
"""保存所有模型""" """保存所有模型"""
success_count = 0 success_count = 0
for learner in self.learners.values(): for learner in self.learners.values():
if learner.save(self.model_save_path): if learner.save(self.model_save_path):
success_count += 1 success_count += 1
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型") logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
return success_count == len(self.learners) return success_count == len(self.learners)
def load_all_models(self) -> int: def load_all_models(self) -> int:
"""加载所有已保存的模型""" """加载所有已保存的模型"""
if not os.path.exists(self.model_save_path): if not os.path.exists(self.model_save_path):
return 0 return 0
loaded_count = 0 loaded_count = 0
for filename in os.listdir(self.model_save_path): for filename in os.listdir(self.model_save_path):
if filename.endswith("_style_model.pkl"): if filename.endswith("_style_model.pkl"):
@ -583,16 +585,16 @@ class StyleLearnerManager:
if learner.load(self.model_save_path): if learner.load(self.model_save_path):
self.learners[chat_id] = learner self.learners[chat_id] = learner
loaded_count += 1 loaded_count += 1
logger.info(f"已加载 {loaded_count} 个模型") logger.info(f"已加载 {loaded_count} 个模型")
return loaded_count return loaded_count
async def start_auto_save(self) -> None: async def start_auto_save(self) -> None:
"""启动自动保存任务""" """启动自动保存任务"""
if self._auto_save_task is None or self._auto_save_task.done(): if self._auto_save_task is None or self._auto_save_task.done():
self._auto_save_task = asyncio.create_task(self._auto_save_loop()) self._auto_save_task = asyncio.create_task(self._auto_save_loop())
logger.info("已启动自动保存任务") logger.info("已启动自动保存任务")
async def stop_auto_save(self) -> None: async def stop_auto_save(self) -> None:
"""停止自动保存任务""" """停止自动保存任务"""
if self._auto_save_task and not self._auto_save_task.done(): if self._auto_save_task and not self._auto_save_task.done():
@ -602,7 +604,7 @@ class StyleLearnerManager:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
logger.info("已停止自动保存任务") logger.info("已停止自动保存任务")
async def _auto_save_loop(self) -> None: async def _auto_save_loop(self) -> None:
"""自动保存循环""" """自动保存循环"""
while True: while True:

View File

@ -3,5 +3,3 @@ from .jargon_miner import extract_and_store_jargon
__all__ = [ __all__ = [
"extract_and_store_jargon", "extract_and_store_jargon",
] ]

View File

@ -250,7 +250,7 @@ def _build_stream_api_resp(
if fr: if fr:
reason = str(fr) reason = str(fr)
break break
if str(reason).endswith("MAX_TOKENS"): if str(reason).endswith("MAX_TOKENS"):
has_visible_output = bool(resp.content and resp.content.strip()) has_visible_output = bool(resp.content and resp.content.strip())
if has_visible_output: if has_visible_output:
@ -281,8 +281,8 @@ async def _default_stream_response_handler(
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录 _usage_record = None # 使用情况记录
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
resp = APIResponse() resp = APIResponse()
def _insure_buffer_closed(): def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed: if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close() _fc_delta_buffer.close()
@ -298,7 +298,7 @@ async def _default_stream_response_handler(
chunk, chunk,
_fc_delta_buffer, _fc_delta_buffer,
_tool_calls_buffer, _tool_calls_buffer,
resp=resp, resp=resp,
) )
if chunk.usage_metadata: if chunk.usage_metadata:
@ -314,7 +314,7 @@ async def _default_stream_response_handler(
_fc_delta_buffer, _fc_delta_buffer,
_tool_calls_buffer, _tool_calls_buffer,
last_resp=last_resp, last_resp=last_resp,
resp=resp, resp=resp,
), _usage_record ), _usage_record
except Exception: except Exception:
# 确保缓冲区被关闭 # 确保缓冲区被关闭

View File

@ -256,7 +256,7 @@ def _build_stream_api_resp(
# 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出) # 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出)
# 保留 finish_reason 仅用于上层判断 # 保留 finish_reason 仅用于上层判断
if not resp.content and not resp.tool_calls: if not resp.content and not resp.tool_calls:
raise EmptyResponseException() raise EmptyResponseException()
@ -310,7 +310,7 @@ async def _default_stream_response_handler(
if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason: if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
finish_reason = event.choices[0].finish_reason finish_reason = event.choices[0].finish_reason
if hasattr(event, "model") and event.model and not _model_name: if hasattr(event, "model") and event.model and not _model_name:
_model_name = event.model # 记录模型名 _model_name = event.model # 记录模型名
@ -358,10 +358,7 @@ async def _default_stream_response_handler(
model_dbg = None model_dbg = None
# 统一日志格式 # 统一日志格式
logger.info( logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (model_dbg or ""))
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (model_dbg or "")
)
return resp, _usage_record return resp, _usage_record
except Exception: except Exception:
@ -404,9 +401,7 @@ def _default_normal_response_parser(
raw_snippet = str(resp)[:300] raw_snippet = str(resp)[:300]
except Exception: except Exception:
raw_snippet = "<unserializable>" raw_snippet = "<unserializable>"
logger.debug( logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}")
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
)
except Exception: except Exception:
# 日志采集失败不应影响控制流 # 日志采集失败不应影响控制流
pass pass
@ -464,14 +459,11 @@ def _default_normal_response_parser(
# print(resp) # print(resp)
_model_name = resp.model _model_name = resp.model
# 统一日志格式 # 统一日志格式
logger.info( logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (_model_name or ""))
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (_model_name or "")
)
return api_response, _usage_record return api_response, _usage_record
except Exception as e: except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
if not api_response.content and not api_response.tool_calls: if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException() raise EmptyResponseException()

View File

@ -328,9 +328,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。") logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e: except NetworkConnectionError as e:
@ -340,9 +338,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。") logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except RespNotOkException as e: except RespNotOkException as e:

View File

@ -5,6 +5,7 @@ from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask from src.common.remote import TelemetryHeartBeatTask
from src.manager.async_task_manager import async_task_manager from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask # from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@ -70,9 +71,10 @@ class MainSystem:
# 添加遥测心跳任务 # 添加遥测心跳任务
await async_task_manager.add_task(TelemetryHeartBeatTask()) await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加记忆遗忘任务 # 添加记忆遗忘任务
from src.chat.utils.memory_forget_task import MemoryForgetTask from src.chat.utils.memory_forget_task import MemoryForgetTask
await async_task_manager.add_task(MemoryForgetTask()) await async_task_manager.add_task(MemoryForgetTask())
# 启动API服务器 # 启动API服务器
@ -106,7 +108,6 @@ class MainSystem:
self.app.register_message_handler(chat_bot.message_process) self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process) self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
# 触发 ON_START 事件 # 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType

View File

@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def query_jargon( async def query_jargon(keyword: str, chat_id: str) -> str:
keyword: str,
chat_id: str
) -> str:
"""根据关键词在jargon库中查询 """根据关键词在jargon库中查询
Args: Args:
keyword: 关键词黑话/俚语/缩写 keyword: 关键词黑话/俚语/缩写
chat_id: 聊天ID chat_id: 聊天ID
Returns: Returns:
str: 查询结果 str: 查询结果
""" """
@ -26,29 +23,17 @@ async def query_jargon(
content = str(keyword).strip() content = str(keyword).strip()
if not content: if not content:
return "关键词为空" return "关键词为空"
# 先尝试精确匹配 # 先尝试精确匹配
results = search_jargon( results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
is_fuzzy_match = False is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索 # 如果精确匹配未找到,尝试模糊搜索
if not results: if not results:
results = search_jargon( results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
is_fuzzy_match = True is_fuzzy_match = True
if results: if results:
# 如果是模糊匹配显示找到的实际jargon内容 # 如果是模糊匹配显示找到的实际jargon内容
if is_fuzzy_match: if is_fuzzy_match:
@ -71,11 +56,11 @@ async def query_jargon(
output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0] output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0]
logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果") logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果")
return output return output
# 未命中 # 未命中
logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}") logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}")
return f"未在jargon库中找到'{content}'的解释" return f"未在jargon库中找到'{content}'的解释"
except Exception as e: except Exception as e:
logger.error(f"查询jargon失败: {e}") logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}" return f"查询失败: {str(e)}"
@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool( register_memory_retrieval_tool(
name="query_jargon", name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。", description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[ parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
{ execute_func=query_jargon,
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
) )

View File

@ -12,17 +12,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool: class MemoryRetrievalTool:
"""记忆检索工具基类""" """记忆检索工具基类"""
def __init__( def __init__(
self, self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
): ):
""" """
初始化工具 初始化工具
Args: Args:
name: 工具名称 name: 工具名称
description: 工具描述 description: 工具描述
@ -33,7 +29,7 @@ class MemoryRetrievalTool:
self.description = description self.description = description
self.parameters = parameters self.parameters = parameters
self.execute_func = execute_func self.execute_func = execute_func
def get_tool_description(self) -> str: def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt""" """获取工具的文本描述用于prompt"""
param_descriptions = [] param_descriptions = []
@ -44,10 +40,10 @@ class MemoryRetrievalTool:
required = param.get("required", True) required = param.get("required", True)
required_str = "必填" if required else "可选" required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}") param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str: async def execute(self, **kwargs) -> str:
"""执行工具""" """执行工具"""
return await self.execute_func(**kwargs) return await self.execute_func(**kwargs)
@ -97,10 +93,10 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry: class MemoryRetrievalToolRegistry:
"""工具注册器""" """工具注册器"""
def __init__(self): def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {} self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None: def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具""" """注册工具"""
if tool.name in self.tools: if tool.name in self.tools:
@ -108,22 +104,22 @@ class MemoryRetrievalToolRegistry:
return return
self.tools[tool.name] = tool self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}") logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]: def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具""" """获取工具"""
return self.tools.get(name) return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]: def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具""" """获取所有工具"""
return self.tools.copy() return self.tools.copy()
def get_tools_description(self) -> str: def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt""" """获取所有工具的描述用于prompt"""
descriptions = [] descriptions = []
for i, tool in enumerate(self.tools.values(), 1): for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}") descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions) return "\n".join(descriptions)
def get_action_types_list(self) -> str: def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt已废弃保留用于兼容""" """获取所有动作类型的列表用于prompt已废弃保留用于兼容"""
action_types = [tool.name for tool in self.tools.values()] action_types = [tool.name for tool in self.tools.values()]
@ -145,13 +141,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool( def register_memory_retrieval_tool(
name: str, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
) -> None: ) -> None:
"""注册记忆检索工具的便捷函数 """注册记忆检索工具的便捷函数
Args: Args:
name: 工具名称 name: 工具名称
description: 工具描述 description: 工具描述
@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry: def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例""" """获取工具注册器实例"""
return _tool_registry return _tool_registry

View File

@ -1,10 +1,7 @@
import math
import random
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive

View File

@ -6,7 +6,9 @@ logger = get_logger("frequency_api")
def get_current_talk_value(chat_id: str) -> float: def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id) return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:

View File

@ -109,7 +109,7 @@ def get_messages_by_time_in_chat(
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
filter_command=filter_command filter_command=filter_command,
) )

View File

@ -12,11 +12,11 @@ logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]: def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
"""获取公开工具实例 """获取公开工具实例
Args: Args:
tool_name: 工具名称 tool_name: 工具名称
chat_stream: 聊天流对象用于传递聊天上下文信息 chat_stream: 聊天流对象用于传递聊天上下文信息
Returns: Returns:
Optional[BaseTool]: 工具实例如果未找到则返回None Optional[BaseTool]: 工具实例如果未找到则返回None
""" """

View File

@ -77,7 +77,7 @@ class BaseAction(ABC):
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型""" """NORMAL模式下的激活类型"""
self.activation_type = getattr(self.__class__, "activation_type") self.activation_type = self.__class__.activation_type
"""激活类型""" """激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率""" """当激活类型为RANDOM时的概率"""
@ -108,21 +108,16 @@ class BaseAction(ABC):
self.is_group = False self.is_group = False
self.target_id = None self.target_id = None
self.group_id = ( self.group_id = (
str(self.action_message.chat_info.group_info.group_id) str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
if self.action_message.chat_info.group_info
else None
) )
self.group_name = ( self.group_name = (
self.action_message.chat_info.group_info.group_name self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
if self.action_message.chat_info.group_info
else None
) )
self.user_id = str(self.action_message.user_info.user_id) self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname self.user_nickname = self.action_message.user_info.user_nickname
if self.group_id: if self.group_id:
self.is_group = True self.is_group = True
self.target_id = self.group_id self.target_id = self.group_id
@ -132,7 +127,6 @@ class BaseAction(ABC):
self.target_id = self.user_id self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]" self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug( logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
) )
@ -448,7 +442,6 @@ class BaseAction(ABC):
wait_start_time = asyncio.get_event_loop().time() wait_start_time = asyncio.get_event_loop().time()
while True: while True:
# 检查新消息 # 检查新消息
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
@ -497,7 +490,7 @@ class BaseAction(ABC):
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type # 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) _normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type # 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type) activation_type = getattr(cls, "activation_type", focus_activation_type)

View File

@ -34,17 +34,17 @@ class BaseTool(ABC):
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None): def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
"""初始化工具基类 """初始化工具基类
Args: Args:
plugin_config: 插件配置字典 plugin_config: 插件配置字典
chat_stream: 聊天流对象用于获取聊天上下文信息 chat_stream: 聊天流对象用于获取聊天上下文信息
""" """
self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.plugin_config = plugin_config or {} # 直接存储插件配置字典
# ============================================================================= # =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致 # 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致
# ============================================================================= # =============================================================================
# 获取聊天流对象 # 获取聊天流对象
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None self.chat_id = self.chat_stream.stream_id if self.chat_stream else None

View File

@ -346,9 +346,7 @@ class EventsManager:
if not isinstance(result, tuple) or len(result) != 5: if not isinstance(result, tuple) or len(result) != 5:
if isinstance(result, tuple): if isinstance(result, tuple):
annotated = ", ".join( annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
f"{name}={val!r}" for name, val in zip(expected_fields, result)
)
actual_desc = f"{len(result)} 个元素 ({annotated})" actual_desc = f"{len(result)} 个元素 ({annotated})"
else: else:
actual_desc = f"非 tuple 类型: {type(result)}" actual_desc = f"非 tuple 类型: {type(result)}"
@ -380,7 +378,6 @@ class EventsManager:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True) logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True, None # 发生异常时默认不中断其他处理 return True, None # 发生异常时默认不中断其他处理
def _task_done_callback( def _task_done_callback(
self, self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]], task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],

View File

@ -189,9 +189,8 @@ class ToolExecutor:
tool_info["content"] = str(content) tool_info["content"] = str(content)
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组) # 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
content_check = tool_info["content"] content_check = tool_info["content"]
if ( if (isinstance(content_check, str) and not content_check.strip()) or (
(isinstance(content_check, str) and not content_check.strip()) isinstance(content_check, (list, tuple)) and len(content_check) == 0
or (isinstance(content_check, (list, tuple)) and len(content_check) == 0)
): ):
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示") logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
continue continue

View File

@ -8,29 +8,30 @@ import sys
import os import os
from pprint import pprint from pprint import pprint
def view_pkl_file(file_path): def view_pkl_file(file_path):
"""查看 pkl 文件内容""" """查看 pkl 文件内容"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}") print(f"❌ 文件不存在: {file_path}")
return return
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
print(f"📁 文件: {file_path}") print(f"📁 文件: {file_path}")
print(f"📊 数据类型: {type(data)}") print(f"📊 数据类型: {type(data)}")
print("=" * 50) print("=" * 50)
if isinstance(data, dict): if isinstance(data, dict):
print("🔑 字典键:") print("🔑 字典键:")
for key in data.keys(): for key in data.keys():
print(f" - {key}: {type(data[key])}") print(f" - {key}: {type(data[key])}")
print() print()
print("📋 详细内容:") print("📋 详细内容:")
pprint(data, width=120, depth=10) pprint(data, width=120, depth=10)
elif isinstance(data, list): elif isinstance(data, list):
print(f"📝 列表长度: {len(data)}") print(f"📝 列表长度: {len(data)}")
if data: if data:
@ -38,16 +39,16 @@ def view_pkl_file(file_path):
print("📋 前几个元素:") print("📋 前几个元素:")
for i, item in enumerate(data[:3]): for i, item in enumerate(data[:3]):
print(f" [{i}]: {item}") print(f" [{i}]: {item}")
else: else:
print("📋 内容:") print("📋 内容:")
pprint(data, width=120, depth=10) pprint(data, width=120, depth=10)
# 如果是 expressor 模型,特别显示 token_counts 的详细信息 # 如果是 expressor 模型,特别显示 token_counts 的详细信息
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']: if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]:
print("\n" + "="*50) print("\n" + "=" * 50)
print("🔍 详细词汇统计 (token_counts):") print("🔍 详细词汇统计 (token_counts):")
token_counts = data['nb']['token_counts'] token_counts = data["nb"]["token_counts"]
for style_id, tokens in token_counts.items(): for style_id, tokens in token_counts.items():
print(f"\n📝 {style_id}:") print(f"\n📝 {style_id}:")
if tokens: if tokens:
@ -59,18 +60,20 @@ def view_pkl_file(file_path):
print(f" ... 还有 {len(sorted_tokens) - 10} 个词") print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
else: else:
print(" (无词汇数据)") print(" (无词汇数据)")
except Exception as e: except Exception as e:
print(f"❌ 读取文件失败: {e}") print(f"❌ 读取文件失败: {e}")
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("用法: python view_pkl.py <pkl文件路径>") print("用法: python view_pkl.py <pkl文件路径>")
print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl") print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
return return
file_path = sys.argv[1] file_path = sys.argv[1]
view_pkl_file(file_path) view_pkl_file(file_path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -7,57 +7,60 @@ import pickle
import sys import sys
import os import os
def view_token_counts(file_path): def view_token_counts(file_path):
"""查看 expressor.pkl 文件中的词汇统计""" """查看 expressor.pkl 文件中的词汇统计"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}") print(f"❌ 文件不存在: {file_path}")
return return
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
print(f"📁 文件: {file_path}") print(f"📁 文件: {file_path}")
print("=" * 60) print("=" * 60)
if 'nb' not in data or 'token_counts' not in data['nb']: if "nb" not in data or "token_counts" not in data["nb"]:
print("❌ 这不是一个 expressor 模型文件") print("❌ 这不是一个 expressor 模型文件")
return return
token_counts = data['nb']['token_counts'] token_counts = data["nb"]["token_counts"]
candidates = data.get('candidates', {}) candidates = data.get("candidates", {})
print(f"🎯 找到 {len(token_counts)} 个风格") print(f"🎯 找到 {len(token_counts)} 个风格")
print("=" * 60) print("=" * 60)
for style_id, tokens in token_counts.items(): for style_id, tokens in token_counts.items():
style_text = candidates.get(style_id, "未知风格") style_text = candidates.get(style_id, "未知风格")
print(f"\n📝 {style_id}: {style_text}") print(f"\n📝 {style_id}: {style_text}")
print(f"📊 词汇数量: {len(tokens)}") print(f"📊 词汇数量: {len(tokens)}")
if tokens: if tokens:
# 按词频排序 # 按词频排序
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
print("🔤 词汇统计 (按频率排序):") print("🔤 词汇统计 (按频率排序):")
for i, (word, count) in enumerate(sorted_tokens): for i, (word, count) in enumerate(sorted_tokens):
print(f" {i+1:2d}. '{word}': {count}") print(f" {i + 1:2d}. '{word}': {count}")
else: else:
print(" (无词汇数据)") print(" (无词汇数据)")
print("-" * 40) print("-" * 40)
except Exception as e: except Exception as e:
print(f"❌ 读取文件失败: {e}") print(f"❌ 读取文件失败: {e}")
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("用法: python view_tokens.py <expressor.pkl文件路径>") print("用法: python view_tokens.py <expressor.pkl文件路径>")
print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl") print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
return return
file_path = sys.argv[1] file_path = sys.argv[1]
view_token_counts(file_path) view_token_counts(file_path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()