pull/1359/head
SengokuCola 2025-11-13 19:00:59 +08:00
commit d306e40db0
46 changed files with 1000 additions and 1041 deletions

2
bot.py
View File

@ -30,7 +30,7 @@ else:
raise raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
initialize_logging() initialize_logging()

View File

@ -1,15 +1,11 @@
from typing import List, Tuple, Type, Optional from typing import List, Tuple, Type, Optional
from src.plugin_system import ( from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField
)
from src.plugin_system.apis import send_api, frequency_api from src.plugin_system.apis import send_api, frequency_api
class SetTalkFrequencyCommand(BaseCommand): class SetTalkFrequencyCommand(BaseCommand):
"""设置当前聊天的talk_frequency值""" """设置当前聊天的talk_frequency值"""
command_name = "set_talk_frequency" command_name = "set_talk_frequency"
command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>" command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$" command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
@ -43,7 +39,7 @@ class SetTalkFrequencyCommand(BaseCommand):
await send_api.text_to_stream( await send_api.text_to_stream(
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}", f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
chat_id, chat_id,
storage_message=False storage_message=False,
) )
return True, None, False return True, None, False
@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand):
class ShowFrequencyCommand(BaseCommand): class ShowFrequencyCommand(BaseCommand):
"""显示当前聊天的频率控制状态""" """显示当前聊天的频率控制状态"""
command_name = "show_frequency" command_name = "show_frequency"
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s" command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
command_pattern = r"^/chat\s+(?:show|s)$" command_pattern = r"^/chat\s+(?:show|s)$"
@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin):
config_file_name: str = "config.toml" config_file_name: str = "config.toml"
# 配置节描述 # 配置节描述
config_section_descriptions = { config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
"plugin": "插件基本信息",
"frequency": "频率控制配置",
"features": "功能开关配置"
}
# 配置Schema定义 # 配置Schema定义
config_schema: dict = { config_schema: dict = {
@ -141,10 +134,11 @@ class BetterFrequencyPlugin(BasePlugin):
# 根据配置决定是否注册命令组件 # 根据配置决定是否注册命令组件
if self.config.get("features", {}).get("enable_commands", True): if self.config.get("features", {}).get("enable_commands", True):
components.extend([ components.extend(
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand), [
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
]) (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
]
)
return components return components

View File

@ -6,15 +6,16 @@ import sys
import os import os
from datetime import datetime from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional, Tuple
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
# 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级) # 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path: if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT) sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60 SECONDS_5_MINUTES = 5 * 60
@ -30,13 +31,13 @@ def clean_output_text(text: str) -> str:
return text return text
# 移除表情包内容:[表情包:...] # 移除表情包内容:[表情包:...]
text = re.sub(r'\[表情包:[^\]]*\]', '', text) text = re.sub(r"\[表情包:[^\]]*\]", "", text)
# 移除回复内容:[回复...],说:... 的完整模式 # 移除回复内容:[回复...],说:... 的完整模式
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text) text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
# 清理多余的空格和换行 # 清理多余的空格和换行
text = re.sub(r'\s+', ' ', text).strip() text = re.sub(r"\s+", " ", text).strip()
return text return text
@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
for msg in messages: for msg in messages:
groups.setdefault(msg.chat_id, []).append(msg) groups.setdefault(msg.chat_id, []).append(msg)
# 保证每个分组内按时间升序 # 保证每个分组内按时间升序
for chat_id, msgs in groups.items(): for _chat_id, msgs in groups.items():
msgs.sort(key=lambda m: m.time or 0) msgs.sort(key=lambda m: m.time or 0)
return groups return groups
@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
continue continue
last = bucket[-1] last = bucket[-1]
same_user = (msg.user_info.user_id == last.user_info.user_id) same_user = msg.user_info.user_id == last.user_info.user_id
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES) close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
if same_user and close_enough: if same_user and close_enough:
bucket.append(msg) bucket.append(msg)
@ -209,13 +210,11 @@ def build_pairs_for_chat(
for merged_idx, merged_msg in enumerate(merged_messages): for merged_idx, merged_msg in enumerate(merged_messages):
# 找到这个合并消息对应的第一个原始消息 # 找到这个合并消息对应的第一个原始消息
while (original_idx < n_original and while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
original_messages[original_idx].time < merged_msg.time):
original_idx += 1 original_idx += 1
# 如果找到了时间匹配的原始消息,建立映射 # 如果找到了时间匹配的原始消息,建立映射
if (original_idx < n_original and if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
original_messages[original_idx].time == merged_msg.time):
merged_to_original_map[merged_idx] = original_idx merged_to_original_map[merged_idx] = original_idx
for merged_idx in range(n_merged): for merged_idx in range(n_merged):
@ -266,7 +265,7 @@ def build_pairs(
groups = group_by_chat(messages) groups = group_by_chat(messages)
all_pairs: List[Tuple[str, str, str]] = [] all_pairs: List[Tuple[str, str, str]] = []
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用 for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
# 对消息进行合并用于output # 对消息进行合并用于output
merged = merge_adjacent_same_user(msgs) merged = merge_adjacent_same_user(msgs)
# 传递原始消息和合并后消息input使用原始消息output使用合并后消息 # 传递原始消息和合并后消息input使用原始消息output使用合并后消息
@ -385,5 +384,3 @@ def run_interactive() -> int:
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -1,4 +1,3 @@
import time
import sys import sys
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -6,16 +5,17 @@ import matplotlib.dates as mdates
from datetime import datetime from datetime import datetime
from typing import List, Tuple from typing import List, Tuple
import numpy as np import numpy as np
from src.common.database.database_model import Expression, ChatStreams
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams
# 设置中文字体 # 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'] plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams['axes.unicode_minus'] = False plt.rcParams["axes.unicode_minus"] = False
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
@ -45,12 +45,7 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]:
if expr.create_date is None: if expr.create_date is None:
continue continue
data.append(( data.append((expr.create_date, expr.count, expr.chat_id, expr.type))
expr.create_date,
expr.count,
expr.chat_id,
expr.type
))
return data return data
@ -64,8 +59,8 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
# 分离数据 # 分离数据
create_dates = [item[0] for item in data] create_dates = [item[0] for item in data]
counts = [item[1] for item in data] counts = [item[1] for item in data]
chat_ids = [item[2] for item in data] _chat_ids = [item[2] for item in data]
expression_types = [item[3] for item in data] _expression_types = [item[3] for item in data]
# 转换时间戳为datetime对象 # 转换时间戳为datetime对象
dates = [datetime.fromtimestamp(ts) for ts in create_dates] dates = [datetime.fromtimestamp(ts) for ts in create_dates]
@ -73,15 +68,15 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
# 计算时间跨度,自动调整显示格式 # 计算时间跨度,自动调整显示格式
time_span = max(dates) - min(dates) time_span = max(dates) - min(dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
@ -89,12 +84,12 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
fig, ax = plt.subplots(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
# 创建散点图 # 创建散点图
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis') scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold') ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
@ -107,13 +102,13 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
# 添加颜色条 # 添加颜色条
cbar = plt.colorbar(scatter) cbar = plt.colorbar(scatter)
cbar.set_label('数据点顺序', fontsize=10) cbar.set_label("数据点顺序", fontsize=10)
# 调整布局 # 调整布局
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 数据统计 ===") print("\n=== 数据统计 ===")
print(f"总数据点数量: {len(data)}") print(f"总数据点数量: {len(data)}")
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}") print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}") print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}")
@ -122,7 +117,7 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n散点图已保存到: {save_path}") print(f"\n散点图已保存到: {save_path}")
# 显示图片 # 显示图片
@ -147,15 +142,15 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
all_dates = [datetime.fromtimestamp(item[0]) for item in data] all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates) time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
@ -174,14 +169,21 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
# 截断过长的聊天名称 # 截断过长的聊天名称
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
ax.scatter(dates, counts, alpha=0.7, s=40, ax.scatter(
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)", dates,
edgecolors='black', linewidth=0.5) counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{display_name} ({len(chat_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold') ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
@ -190,7 +192,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
plt.xticks(rotation=45) plt.xticks(rotation=45)
# 添加图例 # 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8) ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
# 添加网格 # 添加网格
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
@ -199,7 +201,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 分组统计 ===") print("\n=== 分组统计 ===")
print(f"总聊天数量: {len(chat_groups)}") print(f"总聊天数量: {len(chat_groups)}")
for chat_id, chat_data in chat_groups.items(): for chat_id, chat_data in chat_groups.items():
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
@ -208,7 +210,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n分组散点图已保存到: {save_path}") print(f"\n分组散点图已保存到: {save_path}")
# 显示图片 # 显示图片
@ -233,15 +235,15 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
all_dates = [datetime.fromtimestamp(item[0]) for item in data] all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates) time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示 if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator() major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7) minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示 elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d' date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1) major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12) minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示 else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M' date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6) major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1) minor_locator = mdates.HourLocator(interval=1)
@ -256,14 +258,21 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
counts = [item[1] for item in type_data] counts = [item[1] for item in type_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates] dates = [datetime.fromtimestamp(ts) for ts in create_dates]
ax.scatter(dates, counts, alpha=0.7, s=40, ax.scatter(
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)", dates,
edgecolors='black', linewidth=0.5) counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{expr_type} ({len(type_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题 # 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12) ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12) ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold') ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整 # 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format)) ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
@ -272,7 +281,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
plt.xticks(rotation=45) plt.xticks(rotation=45)
# 添加图例 # 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
# 添加网格 # 添加网格
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
@ -281,14 +290,14 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
plt.tight_layout() plt.tight_layout()
# 显示统计信息 # 显示统计信息
print(f"\n=== 类型统计 ===") print("\n=== 类型统计 ===")
for expr_type, type_data in type_groups.items(): for expr_type, type_data in type_groups.items():
counts = [item[1] for item in type_data] counts = [item[1] for item in type_data]
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}") print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片 # 保存图片
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n类型散点图已保存到: {save_path}") print(f"\n类型散点图已保存到: {save_path}")
# 显示图片 # 显示图片

View File

@ -945,9 +945,7 @@ class EmojiManager:
prompt, image_base64, "jpg", temperature=0.5 prompt, image_base64, "jpg", temperature=0.5
) )
else: else:
prompt = ( prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
)
description, _ = await self.vlm.generate_response_for_image( description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.5 prompt, image_base64, image_format, temperature=0.5
) )

View File

@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import frequency_api from src.plugin_system.apis import frequency_api
def init_prompt(): def init_prompt():
Prompt( Prompt(
"""{name_block} """{name_block}
@ -54,7 +55,6 @@ class FrequencyControl:
"""设置发言频率调整值""" """设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value)) self.talk_frequency_adjust = max(0.1, min(5.0, value))
async def trigger_frequency_adjust(self) -> None: async def trigger_frequency_adjust(self) -> None:
msg_list = get_raw_msg_by_timestamp_with_chat( msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id, chat_id=self.chat_id,
@ -62,7 +62,6 @@ class FrequencyControl:
timestamp_end=time.time(), timestamp_end=time.time(),
) )
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20: if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
return return
else: else:
@ -118,7 +117,8 @@ class FrequencyControl:
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2)) self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
self.last_frequency_adjust_time = time.time() self.last_frequency_adjust_time = time.time()
else: else:
logger.info(f"频率调整response不符合要求取消本次调整") logger.info("频率调整response不符合要求取消本次调整")
class FrequencyControlManager: class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例""" """频率控制管理器,管理多个聊天流的频率控制实例"""
@ -143,6 +143,7 @@ class FrequencyControlManager:
"""获取所有有频率控制的聊天ID""" """获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys()) return list(self.frequency_control_dict.keys())
init_prompt() init_prompt()
# 创建全局实例 # 创建全局实例

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
from multiprocessing import context
import time import time
import traceback import traceback
import random import random
@ -102,7 +101,7 @@ class HeartFChatting:
self.is_mute = False self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间 self.last_active_time = time.time() # 记录上一次非noreply时间
self.question_probability_multiplier = 1 self.question_probability_multiplier = 1
self.questioned = False self.questioned = False
@ -191,9 +190,6 @@ class HeartFChatting:
filter_command=True, filter_command=True,
) )
# 根据连续 no_reply 次数动态调整阈值 # 根据连续 no_reply 次数动态调整阈值
# 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2 # 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2
# 5次 no_reply 时,提高到 2大于等于两条消息的阈值 # 5次 no_reply 时,提高到 2大于等于两条消息的阈值
@ -207,7 +203,7 @@ class HeartFChatting:
if len(recent_messages_list) >= threshold: if len(recent_messages_list) >= threshold:
# for message in recent_messages_list: # for message in recent_messages_list:
# print(message.processed_plain_text) # print(message.processed_plain_text)
# !处理no_reply_until_call逻辑 # !处理no_reply_until_call逻辑
if self.no_reply_until_call: if self.no_reply_until_call:
for message in recent_messages_list: for message in recent_messages_list:
@ -395,14 +391,15 @@ class HeartFChatting:
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None: if recent_messages_list is None:
recent_messages_list = [] recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError _reply_text = "" # 初始化reply_text变量避免UnboundLocalError
start_time = time.time() start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()) asyncio.create_task(
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
)
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
# asyncio.create_task(check_and_make_question(self.stream_id)) # asyncio.create_task(check_and_make_question(self.stream_id))
@ -412,7 +409,6 @@ class HeartFChatting:
# 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 # 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
# asyncio.create_task(self.chat_history_summarizer.process()) # asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle() cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@ -457,7 +453,12 @@ class HeartFChatting:
# 处理回复结果 # 处理回复结果
if isinstance(reply_result, BaseException): if isinstance(reply_result, BaseException):
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}") logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
reply_result = {"action_type": "reply", "success": False, "result": "回复生成异常", "loop_info": None} reply_result = {
"action_type": "reply",
"success": False,
"result": "回复生成异常",
"loop_info": None,
}
else: else:
# 正常流程只执行planner # 正常流程只执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
@ -558,7 +559,7 @@ class HeartFChatting:
"taken_time": time.time(), "taken_time": time.time(),
} }
) )
reply_text = reply_text_from_reply _reply_text = reply_text_from_reply
else: else:
# 没有回复信息构建纯动作的loop_info # 没有回复信息构建纯动作的loop_info
loop_info = { loop_info = {
@ -571,7 +572,7 @@ class HeartFChatting:
"taken_time": time.time(), "taken_time": time.time(),
}, },
} }
reply_text = action_reply_text _reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers) self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers) self.print_cycle_info(cycle_timers)
@ -647,7 +648,6 @@ class HeartFChatting:
result = await action_handler.execute() result = await action_handler.execute()
success, action_text = result success, action_text = result
return success, action_text return success, action_text
except Exception as e: except Exception as e:
@ -655,8 +655,6 @@ class HeartFChatting:
traceback.print_exc() traceback.print_exc()
return False, "" return False, ""
async def _send_response( async def _send_response(
self, self,
reply_set: "ReplySetModel", reply_set: "ReplySetModel",
@ -732,7 +730,6 @@ class HeartFChatting:
action_reasoning=reason, action_reasoning=reason,
) )
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call": elif action_planner_info.action_type == "no_reply_until_call":
@ -753,7 +750,12 @@ class HeartFChatting:
action_name="no_reply_until_call", action_name="no_reply_until_call",
action_reasoning=reason, action_reasoning=reason,
) )
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""} return {
"action_type": "no_reply_until_call",
"success": True,
"result": "保持沉默,直到有人直接叫的名字",
"command": "",
}
elif action_planner_info.action_type == "reply": elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑 # 直接当场执行reply逻辑
@ -783,19 +785,16 @@ class HeartFChatting:
enable_tool=global_config.tool.enable_tool, enable_tool=global_config.tool.enable_tool,
request_type="replyer", request_type="replyer",
from_plugin=False, from_plugin=False,
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()), reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
) )
if not success or not llm_response or not llm_response.reply_set: if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message: if action_planner_info.action_message:
logger.info( logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else: else:
logger.info("回复生成失败") logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
response_set = llm_response.reply_set response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply( loop_info, reply_text, _ = await self._send_and_store_reply(
@ -817,12 +816,12 @@ class HeartFChatting:
# 执行普通动作 # 执行普通动作
with Timer("动作执行", cycle_timers): with Timer("动作执行", cycle_timers):
success, result = await self._handle_action( success, result = await self._handle_action(
action = action_planner_info.action_type, action=action_planner_info.action_type,
action_reasoning = action_planner_info.action_reasoning or "", action_reasoning=action_planner_info.action_reasoning or "",
action_data = action_planner_info.action_data or {}, action_data=action_planner_info.action_data or {},
cycle_timers = cycle_timers, cycle_timers=cycle_timers,
thinking_id = thinking_id, thinking_id=thinking_id,
action_message= action_planner_info.action_message, action_message=action_planner_info.action_message,
) )
self.last_active_time = time.time() self.last_active_time = time.time()

View File

@ -13,10 +13,11 @@ from src.person_info.person_info import Person
from src.common.database.database_model import Images from src.common.database.database_model import Images
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.heart_flow.heartFC_chat import HeartFChatting pass
logger = get_logger("chat") logger = get_logger("chat")
class HeartFCMessageReceiver: class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度""" """心流处理器,负责处理接收到的消息并计算兴趣度"""

View File

@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.person_info.person_info import Person
# 定义日志配置 # 定义日志配置
@ -171,7 +170,11 @@ class ChatBot:
# 撤回事件打印;无法获取被撤回者则省略 # 撤回事件打印;无法获取被撤回者则省略
if sub_type == "recall": if sub_type == "recall":
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None)) op_name = (
getattr(op, "user_cardname", None)
or getattr(op, "user_nickname", None)
or str(getattr(op, "user_id", None))
)
recalled_name = None recalled_name = None
try: try:
if isinstance(recalled, dict): if isinstance(recalled, dict):
@ -189,7 +192,7 @@ class ChatBot:
logger.info(f"{op_name} 撤回了消息") logger.info(f"{op_name} 撤回了消息")
else: else:
logger.debug( logger.debug(
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) " f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) "
f"gid={gid} msg_id={msg_id} recalled={recalled_id}" f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
) )
except Exception: except Exception:
@ -234,7 +237,6 @@ class ChatBot:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
if message_data["message_info"].get("group_info") is not None: if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str( message_data["message_info"]["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"] message_data["message_info"]["group_info"]["group_id"]

View File

@ -143,7 +143,6 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = [] self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
def find_message_by_id( def find_message_by_id(
@ -306,7 +305,9 @@ class ActionPlanner:
loop_start_time=loop_start_time, loop_start_time=loop_start_time,
) )
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") logger.info(
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
self.add_plan_log(reasoning, actions) self.add_plan_log(reasoning, actions)
@ -402,8 +403,7 @@ class ActionPlanner:
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]: ) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)""" """构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try: try:
actions_before_now_block = self.get_plan_log_str()
actions_before_now_block=self.get_plan_log_str()
# 构建聊天上下文描述 # 构建聊天上下文描述
chat_context_description = "你现在正在一个群聊中" chat_context_description = "你现在正在一个群聊中"
@ -564,7 +564,7 @@ class ActionPlanner:
filtered_actions: Dict[str, ActionInfo], filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo], available_actions: Dict[str, ActionInfo],
loop_start_time: float, loop_start_time: float,
) -> Tuple[str,List[ActionPlannerInfo]]: ) -> Tuple[str, List[ActionPlannerInfo]]:
"""执行主规划器""" """执行主规划器"""
llm_content = None llm_content = None
actions: List[ActionPlannerInfo] = [] actions: List[ActionPlannerInfo] = []
@ -589,7 +589,7 @@ class ActionPlanner:
except Exception as req_e: except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}",[ return f"LLM 请求失败,模型出现问题: {req_e}", [
ActionPlannerInfo( ActionPlannerInfo(
action_type="no_reply", action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}", reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
@ -608,7 +608,11 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items()) filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects: for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning)) actions.extend(
self._parse_single_action(
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
)
)
else: else:
# 尝试解析为直接的JSON # 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}") logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
@ -631,7 +635,7 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning,actions return extracted_reasoning, actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply""" """创建no_reply"""
@ -674,7 +678,7 @@ class ActionPlanner:
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释 json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip(): if json_str := json_str.strip():
# 尝试按行分割每行可能是一个JSON对象 # 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split('\n') if line.strip()] lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines: for line in lines:
try: try:
# 尝试解析每一行作为独立的JSON对象 # 尝试解析每一行作为独立的JSON对象

View File

@ -276,7 +276,6 @@ class DefaultReplyer:
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}" return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@ -303,7 +302,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") _result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}】: {content}\n" tool_info_str += f"- 【{tool_name}】: {content}\n"
@ -605,9 +604,11 @@ class DefaultReplyer:
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if (global_config.personality.states and if (
global_config.personality.state_probability > 0 and global_config.personality.states
random.random() < global_config.personality.state_probability): and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality # 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states) selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state prompt_personality = selected_state
@ -720,7 +721,7 @@ class DefaultReplyer:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) _is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform platform = chat_stream.platform
user_id = "用户ID" user_id = "用户ID"
@ -956,9 +957,7 @@ class DefaultReplyer:
) )
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = ( reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = ( reply_target_block = (
@ -975,7 +974,9 @@ class DefaultReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1132,6 +1133,7 @@ class DefaultReplyer:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
return "" return ""
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:
""" """
加权且不放回地随机抽取k个元素 加权且不放回地随机抽取k个元素

View File

@ -46,6 +46,7 @@ init_memory_retrieval_prompt()
logger = get_logger("replyer") logger = get_logger("replyer")
class PrivateReplyer: class PrivateReplyer:
def __init__( def __init__(
self, self,
@ -277,9 +278,7 @@ class PrivateReplyer:
expression_habits_block = "" expression_habits_block = ""
expression_habits_title = "" expression_habits_title = ""
if style_habits_str.strip(): if style_habits_str.strip():
expression_habits_title = ( expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
)
expression_habits_block += f"{style_habits_str}\n" expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
@ -291,7 +290,6 @@ class PrivateReplyer:
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood() mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}" return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@ -519,9 +517,11 @@ class PrivateReplyer:
prompt_personality = global_config.personality.personality prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态 # 检查是否需要随机替换为状态
if (global_config.personality.states and if (
global_config.personality.state_probability > 0 and global_config.personality.states
random.random() < global_config.personality.state_probability): and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality # 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states) selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state prompt_personality = selected_state
@ -647,8 +647,6 @@ class PrivateReplyer:
sender = person_name sender = person_name
target = reply_message.processed_plain_text target = reply_message.processed_plain_text
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入 # 在picid替换之前分析内容类型防止prompt注入
@ -710,9 +708,7 @@ class PrivateReplyer:
self._time_and_run_task( self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
), ),
self._time_and_run_task( self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
), ),
@ -859,8 +855,6 @@ class PrivateReplyer:
# 将[picid:xxx]替换为具体的图片描述 # 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target) target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
@ -900,9 +894,7 @@ class PrivateReplyer:
) )
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = ( reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = ( reply_target_block = (
@ -919,7 +911,9 @@ class PrivateReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part: elif has_text and pic_part:
# 既有图片又有文字 # 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else: else:
# 只包含文字 # 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。" reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1106,6 +1100,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx) pool.pop(idx)
break break
return selected return selected

View File

@ -1,16 +1,13 @@
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator # from src.chat.memory_system.memory_activator import MemoryActivator
def init_replyer_prompt(): def init_replyer_prompt():
Prompt("正在群里聊天", "chat_target_group2") Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2") Prompt("{sender_name}聊天", "chat_target_private2")
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片: 你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片:
@ -28,9 +25,8 @@ def init_replyer_prompt():
"replyer_prompt", "replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容: 你正在和{sender_name}聊天这是你们之前聊的内容:
@ -47,9 +43,8 @@ def init_replyer_prompt():
"private_replyer_prompt", "private_replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval} {expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容: 你正在和{sender_name}聊天这是你们之前聊的内容:

View File

@ -2,6 +2,7 @@
聊天内容概括器 聊天内容概括器
用于累积打包和压缩聊天记录 用于累积打包和压缩聊天记录
""" """
import asyncio import asyncio
import json import json
import time import time
@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer")
@dataclass @dataclass
class MessageBatch: class MessageBatch:
"""消息批次""" """消息批次"""
messages: List[DatabaseMessages] messages: List[DatabaseMessages]
start_time: float start_time: float
end_time: float end_time: float
@ -52,8 +54,7 @@ class ChatHistorySummarizer:
# LLM请求器用于压缩聊天内容 # LLM请求器用于压缩聊天内容
self.summarizer_llm = LLMRequest( self.summarizer_llm = LLMRequest(
model_set=model_config.model_task_config.utils, model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
request_type="chat_history_summarizer"
) )
# 后台循环相关 # 后台循环相关
@ -117,9 +118,7 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info( logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
else: else:
# 创建新批次 # 创建新批次
self.current_batch = MessageBatch( self.current_batch = MessageBatch(
@ -127,9 +126,7 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time, start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time, end_time=current_time,
) )
logger.info( logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息"
)
# 检查是否需要打包 # 检查是否需要打包
await self._check_and_package(current_time) await self._check_and_package(current_time)
@ -137,6 +134,7 @@ class ChatHistorySummarizer:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
async def _check_and_package(self, current_time: float): async def _check_and_package(self, current_time: float):
@ -153,9 +151,9 @@ class ChatHistorySummarizer:
if time_since_last_message < 60: if time_since_last_message < 60:
time_str = f"{time_since_last_message:.1f}" time_str = f"{time_since_last_message:.1f}"
elif time_since_last_message < 3600: elif time_since_last_message < 3600:
time_str = f"{time_since_last_message/60:.1f}分钟" time_str = f"{time_since_last_message / 60:.1f}分钟"
else: else:
time_str = f"{time_since_last_message/3600:.1f}小时" time_str = f"{time_since_last_message / 3600:.1f}小时"
preparing_status = "" if self.current_batch.is_preparing else "" preparing_status = "" if self.current_batch.is_preparing else ""
@ -250,26 +248,23 @@ class ChatHistorySummarizer:
participants_set: Set[str] = set() participants_set: Set[str] = set()
for msg in messages: for msg in messages:
# 使用 msg.user_platform扁平化字段或 msg.user_info.platform # 使用 msg.user_platform扁平化字段或 msg.user_info.platform
platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform platform = (
person = Person( getattr(msg, "user_platform", None)
platform=platform, or (msg.user_info.platform if msg.user_info else None)
user_id=msg.user_info.user_id or msg.chat_info.platform
) )
person = Person(platform=platform, user_id=msg.user_info.user_id)
person_name = person.person_name person_name = person.person_name
if person_name: if person_name:
participants_set.add(person_name) participants_set.add(person_name)
participants = list(participants_set) participants = list(participants_set)
logger.info( logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}"
)
# 使用LLM压缩聊天内容 # 使用LLM压缩聊天内容
success, theme, keywords, summary = await self._compress_with_llm(original_text) success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success: if not success:
logger.warning( logger.warning(f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}")
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
# 清空当前批次,避免重复处理 # 清空当前批次,避免重复处理
self.current_batch = None self.current_batch = None
return return
@ -297,6 +292,7 @@ class ChatHistorySummarizer:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}") logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# 出错时也清空批次,避免重复处理 # 出错时也清空批次,避免重复处理
self.current_batch = None self.current_batch = None
@ -338,23 +334,23 @@ class ChatHistorySummarizer:
# 移除可能的markdown代码块标记 # 移除可能的markdown代码块标记
json_str = response.strip() json_str = response.strip()
json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE) json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE) json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
json_str = json_str.strip() json_str = json_str.strip()
# 尝试找到JSON对象的开始和结束位置 # 尝试找到JSON对象的开始和结束位置
# 查找第一个 { 和最后一个匹配的 } # 查找第一个 { 和最后一个匹配的 }
start_idx = json_str.find('{') start_idx = json_str.find("{")
if start_idx == -1: if start_idx == -1:
raise ValueError("未找到JSON对象开始标记") raise ValueError("未找到JSON对象开始标记")
# 从后往前查找最后一个 } # 从后往前查找最后一个 }
end_idx = json_str.rfind('}') end_idx = json_str.rfind("}")
if end_idx == -1 or end_idx <= start_idx: if end_idx == -1 or end_idx <= start_idx:
raise ValueError("未找到JSON对象结束标记") raise ValueError("未找到JSON对象结束标记")
# 提取JSON字符串 # 提取JSON字符串
json_str = json_str[start_idx:end_idx + 1] json_str = json_str[start_idx : end_idx + 1]
# 尝试解析JSON # 尝试解析JSON
try: try:
@ -372,7 +368,7 @@ class ChatHistorySummarizer:
if escape_next: if escape_next:
fixed_chars.append(char) fixed_chars.append(char)
escape_next = False escape_next = False
elif char == '\\': elif char == "\\":
fixed_chars.append(char) fixed_chars.append(char)
escape_next = True escape_next = True
elif char == '"' and not escape_next: elif char == '"' and not escape_next:
@ -385,7 +381,7 @@ class ChatHistorySummarizer:
fixed_chars.append(char) fixed_chars.append(char)
i += 1 i += 1
json_str = ''.join(fixed_chars) json_str = "".join(fixed_chars)
# 再次尝试解析 # 再次尝试解析
result = json.loads(json_str) result = json.loads(json_str)
@ -450,6 +446,7 @@ class ChatHistorySummarizer:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
@ -490,6 +487,6 @@ class ChatHistorySummarizer:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
self._running = False self._running = False

View File

@ -2,7 +2,7 @@ import time
import random import random
import re import re
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install from rich.traceback import install
from src.config.config import global_config from src.config.config import global_config
@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = [] output_lines = []
current_time = time.time() current_time = time.time()
for action in actions: for action in actions:
action_time = action.time or current_time action_time = action.time or current_time
action_name = action.action_name or "未知动作" action_name = action.action_name or "未知动作"
@ -596,7 +595,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}" line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line) output_lines.append(line)
return "\n".join(output_lines) return "\n".join(output_lines)
@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
""" """
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身) 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)

View File

@ -2,6 +2,7 @@
记忆遗忘任务 记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆 每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
""" """
import time import time
import random import random
from typing import List from typing import List
@ -48,11 +49,7 @@ class MemoryForgetTask(AsyncTask):
# 查询符合条件的记忆forget_times=0 且 end_time < time_threshold # 查询符合条件的记忆forget_times=0 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 0) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
@ -101,11 +98,11 @@ class MemoryForgetTask(AsyncTask):
if remaining: if remaining:
# 批量更新 # 批量更新
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=1).where( ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1") logger.info(
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
)
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
@ -122,11 +119,7 @@ class MemoryForgetTask(AsyncTask):
# 查询符合条件的记忆forget_times=1 且 end_time < time_threshold # 查询符合条件的记忆forget_times=1 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 1) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
@ -168,11 +161,11 @@ class MemoryForgetTask(AsyncTask):
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=2).where( ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2") logger.info(
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
)
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
@ -189,11 +182,7 @@ class MemoryForgetTask(AsyncTask):
# 查询符合条件的记忆forget_times=2 且 end_time < time_threshold # 查询符合条件的记忆forget_times=2 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 2) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
@ -235,11 +224,11 @@ class MemoryForgetTask(AsyncTask):
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=3).where( ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3") logger.info(
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
)
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
@ -256,11 +245,7 @@ class MemoryForgetTask(AsyncTask):
# 查询符合条件的记忆forget_times=3 且 end_time < time_threshold # 查询符合条件的记忆forget_times=3 且 end_time < time_threshold
candidates = list( candidates = list(
ChatHistory.select() ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
.where(
(ChatHistory.forget_times == 3) &
(ChatHistory.end_time < time_threshold)
)
) )
if not candidates: if not candidates:
@ -302,16 +287,18 @@ class MemoryForgetTask(AsyncTask):
remaining = [r for r in candidates if r.id not in to_delete_ids] remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining: if remaining:
ids_to_update = [r.id for r in remaining] ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=4).where( ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4") logger.info(
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
)
except Exception as e: except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True) logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]: def _handle_same_count_random(
self, candidates: List[ChatHistory], delete_count: int, mode: str
) -> List[ChatHistory]:
""" """
处理count相同的情况随机选择要删除的记录 处理count相同的情况随机选择要删除的记录
@ -373,4 +360,3 @@ class MemoryForgetTask(AsyncTask):
start_idx = idx start_idx = idx
return to_delete return to_delete

View File

@ -504,7 +504,11 @@ class StatisticOutputTask(AsyncTask):
} }
# 获取bot的QQ账号 # 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
continue continue
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳 last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段 self.stat_period = [
item for item in self.stat_period if item[0] != "all_time"
] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的")) self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
except Exception as e: except Exception as e:
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}") logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
@ -699,7 +705,11 @@ class StatisticOutputTask(AsyncTask):
# 计算花费/消息数量排除自己回复指标每100条 # 计算花费/消息数量排除自己回复指标每100条
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0 cost_per_100_messages_excluding_replies = (
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
if total_messages_excluding_replies > 0
else 0.0
)
output = [ output = [
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}", f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask):
f"总Token数: {_format_large_number(total_tokens)}", f"总Token数: {_format_large_number(total_tokens)}",
f"总花费: {stats[TOTAL_COST]:.2f}¥", f"总花费: {stats[TOTAL_COST]:.2f}¥",
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A", f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A", f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
if total_messages_excluding_replies > 0
else "花费/消息数量(排除回复): N/A",
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A", f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A", f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A", f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask):
formatted_out_tokens = _format_large_number(out_tokens) formatted_out_tokens = _format_large_number(out_tokens)
formatted_tokens = _format_large_number(tokens) formatted_tokens = _format_large_number(tokens)
output.append( output.append(
data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost) data_fmt.format(
name,
formatted_count,
formatted_in_tokens,
formatted_out_tokens,
formatted_tokens,
cost,
avg_time_cost,
std_time_cost,
)
) )
output.append("") output.append("")
@ -892,7 +913,11 @@ class StatisticOutputTask(AsyncTask):
logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}") logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>") chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>" chat_rows_html = (
"\n".join(chat_rows)
if chat_rows
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
)
# 生成HTML # 生成HTML
return f""" return f"""
<div id=\"{div_id}\" class=\"tab-content\"> <div id=\"{div_id}\" class=\"tab-content\">
@ -1777,10 +1802,10 @@ class StatisticOutputTask(AsyncTask):
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1) metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
# 7天尺度1天为单位 # 7天尺度1天为单位
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24) metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
# 30天尺度1天为单位 # 30天尺度1天为单位
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24) metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
return metrics_data return metrics_data
@ -1809,7 +1834,11 @@ class StatisticOutputTask(AsyncTask):
total_online_hours = [0.0] * len(time_points) total_online_hours = [0.0] * len(time_points)
# 获取bot的QQ账号 # 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else "" bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
interval_seconds = interval_hours * 3600 interval_seconds = interval_hours * 3600
@ -1867,19 +1896,19 @@ class StatisticOutputTask(AsyncTask):
for idx in range(len(time_points)): for idx in range(len(time_points)):
# 花费/消息数量每100条 # 花费/消息数量每100条
if total_messages[idx] > 0: if total_messages[idx] > 0:
cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100) cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
# 花费/时间(每小时) # 花费/时间(每小时)
if total_online_hours[idx] > 0: if total_online_hours[idx] > 0:
cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx]) cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
# Token/时间(每小时) # Token/时间(每小时)
if total_online_hours[idx] > 0: if total_online_hours[idx] > 0:
tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx]) tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
# 花费/回复数量每100条 # 花费/回复数量每100条
if total_replies[idx] > 0: if total_replies[idx] > 0:
cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100) cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
# 生成时间标签 # 生成时间标签
if interval_hours == 1: if interval_hours == 1:

View File

@ -4,14 +4,11 @@ import time
import jieba import jieba
import json import json
import ast import ast
import numpy as np
from collections import Counter
from typing import Optional, Tuple, List, TYPE_CHECKING from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
elif current_account: elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text): if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text): elif re.search(
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text
):
is_mentioned = True is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配) # 6) 名称/别名 提及(去除 @/回复标记后再匹配)
@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding return embedding
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并 """将文本分割成句子,并根据概率合并
1. 识别分割点, ; 空格但如果分割点左右都是英文字母则不分割 1. 识别分割点, ; 空格但如果分割点左右都是英文字母则不分割
@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
prev_char = text[i - 1] prev_char = text[i - 1]
next_char = text[i + 1] next_char = text[i + 1]
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则 # 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ': if char == " ":
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char) prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char) next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum: if prev_is_alnum and next_is_alnum:
@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
"不知道", "不知道",
"不晓得", "不晓得",
"懒得说", "懒得说",
"()" "()",
] ]
return random.choice(default_replies) return random.choice(default_replies)
@ -469,7 +467,6 @@ def calculate_typing_time(
return total_time # 加上回车时间 return total_time # 加上回车时间
def truncate_message(message: str, max_length=20) -> str: def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度""" """截断消息,使其不超过指定长度"""
return f"{message[:max_length]}..." if len(message) > max_length else message return f"{message[:max_length]}..." if len(message) > max_length else message
@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars) return western_count / len(alnum_chars)
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式 """将时间戳转换为人类可读的时间格式

View File

@ -103,14 +103,16 @@ class ImageManager:
invalid_values = ["", "None"] invalid_values = ["", "None"]
# 清理 Images 表 # 清理 Images 表
deleted_images = Images.delete().where( deleted_images = (
(Images.description >> None) | (Images.description << invalid_values) Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
).execute() )
# 清理 ImageDescriptions 表 # 清理 ImageDescriptions 表
deleted_descriptions = ImageDescriptions.delete().where( deleted_descriptions = (
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values) ImageDescriptions.delete()
).execute() .where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions: if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}") logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")

View File

@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
chat_id: str, chat_id: str,
chat_info_stream_id: str, chat_info_stream_id: str,
chat_info_platform: str, chat_info_platform: str,
action_reasoning:str action_reasoning: str,
): ):
self.action_id = action_id self.action_id = action_id
self.time = time self.time = time

View File

@ -317,10 +317,12 @@ class Expression(BaseModel):
class Meta: class Meta:
table_name = "expression" table_name = "expression"
class Jargon(BaseModel): class Jargon(BaseModel):
""" """
用于存储俚语的模型 用于存储俚语的模型
""" """
content = TextField() content = TextField()
raw_content = TextField(null=True) raw_content = TextField(null=True)
type = TextField(null=True) type = TextField(null=True)
@ -336,10 +338,12 @@ class Jargon(BaseModel):
class Meta: class Meta:
table_name = "jargon" table_name = "jargon"
class ChatHistory(BaseModel): class ChatHistory(BaseModel):
""" """
用于存储聊天历史概括的模型 用于存储聊天历史概括的模型
""" """
chat_id = TextField(index=True) # 聊天ID chat_id = TextField(index=True) # 聊天ID
start_time = DoubleField() # 起始时间 start_time = DoubleField() # 起始时间
end_time = DoubleField() # 结束时间 end_time = DoubleField() # 结束时间
@ -359,6 +363,7 @@ class ThinkingBack(BaseModel):
""" """
用于存储记忆检索思考过程的模型 用于存储记忆检索思考过程的模型
""" """
chat_id = TextField(index=True) # 聊天ID chat_id = TextField(index=True) # 聊天ID
question = TextField() # 提出的问题 question = TextField() # 提出的问题
context = TextField(null=True) # 上下文信息 context = TextField(null=True) # 上下文信息
@ -371,6 +376,7 @@ class ThinkingBack(BaseModel):
class Meta: class Meta:
table_name = "thinking_back" table_name = "thinking_back"
MODELS = [ MODELS = [
ChatStreams, ChatStreams,
LLMUsage, LLMUsage,
@ -387,6 +393,7 @@ MODELS = [
ThinkingBack, ThinkingBack,
] ]
def create_tables(): def create_tables():
""" """
创建所有在模型中定义的数据库表 创建所有在模型中定义的数据库表

View File

@ -311,6 +311,7 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set()) ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表""" """过滤正则表达式列表"""
@dataclass @dataclass
class MemoryConfig(ConfigBase): class MemoryConfig(ConfigBase):
"""记忆配置类""" """记忆配置类"""
@ -321,6 +322,7 @@ class MemoryConfig(ConfigBase):
memory_build_frequency: int = 1 memory_build_frequency: int = 1
"""记忆构建频率""" """记忆构建频率"""
@dataclass @dataclass
class ExpressionConfig(ConfigBase): class ExpressionConfig(ConfigBase):
"""表达配置类""" """表达配置类"""
@ -501,6 +503,7 @@ class MoodConfig(ConfigBase):
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大" emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况""" """情感特征,影响情绪的变化情况"""
@dataclass @dataclass
class VoiceConfig(ConfigBase): class VoiceConfig(ConfigBase):
"""语音识别配置类""" """语音识别配置类"""

View File

@ -3,7 +3,6 @@ import difflib
import random import random
from datetime import datetime from datetime import datetime
from typing import Optional, List, Dict from typing import Optional, List, Dict
from collections import defaultdict
def filter_message_content(content: Optional[str]) -> str: def filter_message_content(content: Optional[str]) -> str:
@ -20,13 +19,13 @@ def filter_message_content(content: Optional[str]) -> str:
return "" return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r'\[回复.*?\],说:\s*', '', content) content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容 # 移除@<...>格式的内容
content = re.sub(r'@<[^>]*>', '', content) content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID # 移除[picid:...]格式的图片ID
content = re.sub(r'\[picid:[^\]]*\]', '', content) content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容 # 移除[表情包:...]格式的内容
content = re.sub(r'\[表情包:[^\]]*\]', '', content) content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip() return content.strip()

View File

@ -1,7 +1,6 @@
import time import time
import json import json
import os import os
from datetime import datetime
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import traceback import traceback
from src.common.logger import get_logger from src.common.logger import get_logger
@ -158,8 +157,6 @@ class ExpressionLearner:
traceback.print_exc() traceback.print_exc()
return return
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]: async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
""" """
学习并存储表达方式 学习并存储表达方式
@ -195,9 +192,7 @@ class ExpressionLearner:
) in learnt_expressions: ) in learnt_expressions:
# 查找是否已存在相似表达方式 # 查找是否已存在相似表达方式
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == self.chat_id) (Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
& (Expression.situation == situation)
& (Expression.style == style)
) )
if query.exists(): if query.exists():
# 表达方式完全相同,只更新时间戳 # 表达方式完全相同,只更新时间戳
@ -222,19 +217,17 @@ class ExpressionLearner:
learner.add_style(style, situation) learner.add_style(style, situation)
# 学习映射关系 # 学习映射关系
success = style_learner_manager.learn_mapping( success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
self.chat_id,
up_content,
style
)
if success: if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else "")) logger.debug(
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
+ (f" (situation: {situation})" if situation else "")
)
else: else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}") logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e: except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}") logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型 # 保存当前聊天室的 style_learner 模型
if has_new_expressions: if has_new_expressions:
try: try:
@ -367,9 +360,7 @@ class ExpressionLearner:
return matched_expressions return matched_expressions
async def learn_expression( async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式 """从指定聊天流学习表达方式
Args: Args:
@ -409,7 +400,6 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response) expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}") # logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源 # 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str expressions, random_msg_match_str
@ -449,7 +439,6 @@ class ExpressionLearner:
return filtered_with_up return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
""" """
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组 解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组

View File

@ -1,8 +1,6 @@
import json import json
import time import time
import random
import hashlib import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json from json_repair import repair_json
@ -115,7 +113,9 @@ class ExpressionSelector:
return group_chat_ids return group_chat_ids
return [chat_id] return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]: def get_model_predicted_expressions(
self, chat_id: str, target_message: str, total_num: int = 10
) -> List[Dict[str, Any]]:
""" """
使用 style_learner 模型预测最合适的表达方式 使用 style_learner 模型预测最合适的表达方式
@ -136,7 +136,6 @@ class ExpressionSelector:
# 支持多chat_id合并预测 # 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
predicted_expressions = [] predicted_expressions = []
# 为每个相关的chat_id进行预测 # 为每个相关的chat_id进行预测
@ -155,25 +154,31 @@ class ExpressionSelector:
if style_id and situation: if style_id and situation:
# 从数据库查找对应的表达记录 # 从数据库查找对应的表达记录
expr_query = Expression.select().where( expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) & (Expression.chat_id == related_chat_id)
(Expression.situation == situation) & & (Expression.situation == situation)
(Expression.style == best_style) & (Expression.style == best_style)
) )
if expr_query.exists(): if expr_query.exists():
expr = expr_query.get() expr = expr_query.get()
predicted_expressions.append({ predicted_expressions.append(
"id": expr.id, {
"situation": expr.situation, "id": expr.id,
"style": expr.style, "situation": expr.situation,
"last_active_time": expr.last_active_time, "style": expr.style,
"source_id": expr.chat_id, "last_active_time": expr.last_active_time,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "source_id": expr.chat_id,
"prediction_score": scores.get(best_style, 0.0), "create_date": expr.create_date
"prediction_input": filtered_target_message if expr.create_date is not None
}) else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message,
}
)
else: else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式") logger.warning(
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
)
except Exception as e: except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
@ -207,9 +212,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式 # 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
(Expression.chat_id.in_(related_chat_ids))
)
style_exprs = [ style_exprs = [
{ {
@ -236,7 +239,6 @@ class ExpressionSelector:
logger.error(f"随机选择表达方式失败: {e}") logger.error(f"随机选择表达方式失败: {e}")
return [] return []
async def select_suitable_expressions( async def select_suitable_expressions(
self, self,
chat_id: str, chat_id: str,
@ -425,17 +427,13 @@ class ExpressionSelector:
updates_by_key[key] = expr updates_by_key[key] = expr
for chat_id, situation, style in updates_by_key: for chat_id, situation, style in updates_by_key:
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) (Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
& (Expression.situation == situation)
& (Expression.style == style)
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
expr_obj.last_active_time = time.time() expr_obj.last_active_time = time.time()
expr_obj.save() expr_obj.save()
logger.debug( logger.debug("表达方式激活: 更新last_active_time in db")
"表达方式激活: 更新last_active_time in db"
)
init_prompt() init_prompt()

View File

@ -6,18 +6,21 @@ import os
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes from .online_nb import OnlineNaiveBayes
class ExpressorModel: class ExpressorModel:
""" """
直接使用朴素贝叶斯精排可在线学习 直接使用朴素贝叶斯精排可在线学习
支持存储situation字段不参与计算仅与style对应 支持存储situation字段不参与计算仅与style对应
""" """
def __init__(self, def __init__(
alpha: float = 0.5, self,
beta: float = 0.5, alpha: float = 0.5,
gamma: float = 1.0, beta: float = 0.5,
vocab_size: int = 200000, gamma: float = 1.0,
use_jieba: bool = True): vocab_size: int = 200000,
use_jieba: bool = True,
):
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
self._candidates: Dict[str, str] = {} # cid -> text (style) self._candidates: Dict[str, str] = {} # cid -> text (style)
@ -96,25 +99,27 @@ class ExpressorModel:
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]: def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""获取所有候选的style和situation信息""" """获取所有候选的style和situation信息"""
return {cid: (style, self._situations.get(cid)) return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()}
for cid, style in self._candidates.items()}
def save(self, path: str): def save(self, path: str):
"""保存模型""" """保存模型"""
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump({ pickle.dump(
"candidates": self._candidates, {
"situations": self._situations, "candidates": self._candidates,
"nb": { "situations": self._situations,
"cls_counts": dict(self.nb.cls_counts), "nb": {
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, "cls_counts": dict(self.nb.cls_counts),
"alpha": self.nb.alpha, "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"beta": self.nb.beta, "alpha": self.nb.alpha,
"gamma": self.nb.gamma, "beta": self.nb.beta,
"V": self.nb.V, "gamma": self.nb.gamma,
} "V": self.nb.V,
}, f) },
},
f,
)
def load(self, path: str): def load(self, path: str):
"""加载模型""" """加载模型"""
@ -133,8 +138,10 @@ class ExpressorModel:
self.nb.V = obj["nb"]["V"] self.nb.V = obj["nb"]["V"]
self.nb._logZ.clear() self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]): def defaultdict_dict(d: Dict[str, Dict[str, float]]):
from collections import defaultdict from collections import defaultdict
outer = defaultdict(lambda: defaultdict(float)) outer = defaultdict(lambda: defaultdict(float))
for k, inner in d.items(): for k, inner in d.items():
outer[k].update(inner) outer[k].update(inner)

View File

@ -2,6 +2,7 @@ import math
from typing import Dict, List from typing import Dict, List
from collections import defaultdict, Counter from collections import defaultdict, Counter
class OnlineNaiveBayes: class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha self.alpha = alpha
@ -9,9 +10,9 @@ class OnlineNaiveBayes:
self.gamma = gamma self.gamma = gamma
self.V = vocab_size self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str): def _invalidate(self, cid: str):
if cid in self._logZ: if cid in self._logZ:

View File

@ -3,17 +3,20 @@ from typing import List, Optional, Set
try: try:
import jieba import jieba
_HAS_JIEBA = True _HAS_JIEBA = True
except Exception: except Exception:
_HAS_JIEBA = False _HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+") _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# 匹配纯符号的正则表达式 # 匹配纯符号的正则表达式
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$') _SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$")
def simple_en_tokenize(text: str) -> List[str]: def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower()) return _WORD_RE.findall(text.lower())
class Tokenizer: class Tokenizer:
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True): def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
self.stopwords = stopwords or set() self.stopwords = stopwords or set()

View File

@ -30,7 +30,7 @@ class StyleLearner:
"beta": 0.5, "beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘 "gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000, "vocab_size": 200000,
"use_jieba": True "use_jieba": True,
} }
# 初始化表达模型 # 初始化表达模型
@ -47,7 +47,7 @@ class StyleLearner:
"total_samples": 0, "total_samples": 0,
"style_counts": defaultdict(int), "style_counts": defaultdict(int),
"last_update": None, "last_update": None,
"style_usage_frequency": defaultdict(int) # 风格使用频率 "style_usage_frequency": defaultdict(int), # 风格使用频率
} }
def add_style(self, style: str, situation: str = None) -> bool: def add_style(self, style: str, situation: str = None) -> bool:
@ -80,8 +80,10 @@ class StyleLearner:
# 添加到expressor模型 # 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation) self.expressor.add_candidate(style_id, style, situation)
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" + logger.info(
(f", situation: '{situation}'" if situation else "")) f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
+ (f", situation: '{situation}'" if situation else "")
)
return True return True
except Exception as e: except Exception as e:
@ -341,7 +343,7 @@ class StyleLearner:
"style_counts": dict(self.learning_stats["style_counts"]), "style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]), "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"], "last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys()) "all_styles": list(self.style_to_id.keys()),
} }
def save(self, base_path: str) -> bool: def save(self, base_path: str) -> bool:
@ -362,7 +364,7 @@ class StyleLearner:
"id_to_style": self.id_to_style, "id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation, "id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id, "next_style_id": self.next_style_id,
"learning_stats": self.learning_stats "learning_stats": self.learning_stats,
} }
# 先保存expressor模型 # 先保存expressor模型

View File

@ -3,5 +3,3 @@ from .jargon_miner import extract_and_store_jargon
__all__ = [ __all__ = [
"extract_and_store_jargon", "extract_and_store_jargon",
] ]

View File

@ -358,10 +358,7 @@ async def _default_stream_response_handler(
model_dbg = None model_dbg = None
# 统一日志格式 # 统一日志格式
logger.info( logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (model_dbg or ""))
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (model_dbg or "")
)
return resp, _usage_record return resp, _usage_record
except Exception: except Exception:
@ -404,9 +401,7 @@ def _default_normal_response_parser(
raw_snippet = str(resp)[:300] raw_snippet = str(resp)[:300]
except Exception: except Exception:
raw_snippet = "<unserializable>" raw_snippet = "<unserializable>"
logger.debug( logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}")
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
)
except Exception: except Exception:
# 日志采集失败不应影响控制流 # 日志采集失败不应影响控制流
pass pass
@ -464,10 +459,7 @@ def _default_normal_response_parser(
# print(resp) # print(resp)
_model_name = resp.model _model_name = resp.model
# 统一日志格式 # 统一日志格式
logger.info( logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (_model_name or ""))
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (_model_name or "")
)
return api_response, _usage_record return api_response, _usage_record
except Exception as e: except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}") logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")

View File

@ -328,9 +328,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。") logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e: except NetworkConnectionError as e:
@ -340,9 +338,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。") logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except RespNotOkException as e: except RespNotOkException as e:

View File

@ -5,6 +5,7 @@ from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask from src.common.remote import TelemetryHeartBeatTask
from src.manager.async_task_manager import async_task_manager from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask # from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@ -73,6 +74,7 @@ class MainSystem:
# 添加记忆遗忘任务 # 添加记忆遗忘任务
from src.chat.utils.memory_forget_task import MemoryForgetTask from src.chat.utils.memory_forget_task import MemoryForgetTask
await async_task_manager.add_task(MemoryForgetTask()) await async_task_manager.add_task(MemoryForgetTask())
# 启动API服务器 # 启动API服务器
@ -106,7 +108,6 @@ class MainSystem:
self.app.register_message_handler(chat_bot.message_process) self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process) self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
# 触发 ON_START 事件 # 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType

View File

@ -9,10 +9,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def query_jargon( async def query_jargon(keyword: str, chat_id: str) -> str:
keyword: str,
chat_id: str
) -> str:
"""根据关键词在jargon库中查询 """根据关键词在jargon库中查询
Args: Args:
@ -28,25 +25,13 @@ async def query_jargon(
return "关键词为空" return "关键词为空"
# 先尝试精确匹配 # 先尝试精确匹配
results = search_jargon( results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
is_fuzzy_match = False is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索 # 如果精确匹配未找到,尝试模糊搜索
if not results: if not results:
results = search_jargon( results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
is_fuzzy_match = True is_fuzzy_match = True
if results: if results:
@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool( register_memory_retrieval_tool(
name="query_jargon", name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。", description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[ parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
{ execute_func=query_jargon,
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
) )

View File

@ -14,11 +14,7 @@ class MemoryRetrievalTool:
"""记忆检索工具基类""" """记忆检索工具基类"""
def __init__( def __init__(
self, self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
): ):
""" """
初始化工具 初始化工具
@ -145,10 +141,7 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool( def register_memory_retrieval_tool(
name: str, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
) -> None: ) -> None:
"""注册记忆检索工具的便捷函数 """注册记忆检索工具的便捷函数
@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry: def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例""" """获取工具注册器实例"""
return _tool_registry return _tool_registry

View File

@ -1,10 +1,7 @@
import math
import random
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive

View File

@ -6,7 +6,9 @@ logger = get_logger("frequency_api")
def get_current_talk_value(chat_id: str) -> float: def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id) return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:

View File

@ -109,7 +109,7 @@ def get_messages_by_time_in_chat(
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
filter_command=filter_command filter_command=filter_command,
) )

View File

@ -77,7 +77,7 @@ class BaseAction(ABC):
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy() self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型""" """NORMAL模式下的激活类型"""
self.activation_type = getattr(self.__class__, "activation_type") self.activation_type = self.__class__.activation_type
"""激活类型""" """激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0) self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率""" """当激活类型为RANDOM时的概率"""
@ -108,16 +108,11 @@ class BaseAction(ABC):
self.is_group = False self.is_group = False
self.target_id = None self.target_id = None
self.group_id = ( self.group_id = (
str(self.action_message.chat_info.group_info.group_id) str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
if self.action_message.chat_info.group_info
else None
) )
self.group_name = ( self.group_name = (
self.action_message.chat_info.group_info.group_name self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
if self.action_message.chat_info.group_info
else None
) )
self.user_id = str(self.action_message.user_info.user_id) self.user_id = str(self.action_message.user_info.user_id)
@ -132,7 +127,6 @@ class BaseAction(ABC):
self.target_id = self.user_id self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]" self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug( logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
) )
@ -448,7 +442,6 @@ class BaseAction(ABC):
wait_start_time = asyncio.get_event_loop().time() wait_start_time = asyncio.get_event_loop().time()
while True: while True:
# 检查新消息 # 检查新消息
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
@ -497,7 +490,7 @@ class BaseAction(ABC):
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type # 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) _normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type # 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type) activation_type = getattr(cls, "activation_type", focus_activation_type)

View File

@ -346,9 +346,7 @@ class EventsManager:
if not isinstance(result, tuple) or len(result) != 5: if not isinstance(result, tuple) or len(result) != 5:
if isinstance(result, tuple): if isinstance(result, tuple):
annotated = ", ".join( annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
f"{name}={val!r}" for name, val in zip(expected_fields, result)
)
actual_desc = f"{len(result)} 个元素 ({annotated})" actual_desc = f"{len(result)} 个元素 ({annotated})"
else: else:
actual_desc = f"非 tuple 类型: {type(result)}" actual_desc = f"非 tuple 类型: {type(result)}"
@ -380,7 +378,6 @@ class EventsManager:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True) logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True, None # 发生异常时默认不中断其他处理 return True, None # 发生异常时默认不中断其他处理
def _task_done_callback( def _task_done_callback(
self, self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]], task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],

View File

@ -189,9 +189,8 @@ class ToolExecutor:
tool_info["content"] = str(content) tool_info["content"] = str(content)
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组) # 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
content_check = tool_info["content"] content_check = tool_info["content"]
if ( if (isinstance(content_check, str) and not content_check.strip()) or (
(isinstance(content_check, str) and not content_check.strip()) isinstance(content_check, (list, tuple)) and len(content_check) == 0
or (isinstance(content_check, (list, tuple)) and len(content_check) == 0)
): ):
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示") logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
continue continue

View File

@ -8,6 +8,7 @@ import sys
import os import os
from pprint import pprint from pprint import pprint
def view_pkl_file(file_path): def view_pkl_file(file_path):
"""查看 pkl 文件内容""" """查看 pkl 文件内容"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
@ -15,7 +16,7 @@ def view_pkl_file(file_path):
return return
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
print(f"📁 文件: {file_path}") print(f"📁 文件: {file_path}")
@ -44,10 +45,10 @@ def view_pkl_file(file_path):
pprint(data, width=120, depth=10) pprint(data, width=120, depth=10)
# 如果是 expressor 模型,特别显示 token_counts 的详细信息 # 如果是 expressor 模型,特别显示 token_counts 的详细信息
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']: if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]:
print("\n" + "="*50) print("\n" + "=" * 50)
print("🔍 详细词汇统计 (token_counts):") print("🔍 详细词汇统计 (token_counts):")
token_counts = data['nb']['token_counts'] token_counts = data["nb"]["token_counts"]
for style_id, tokens in token_counts.items(): for style_id, tokens in token_counts.items():
print(f"\n📝 {style_id}:") print(f"\n📝 {style_id}:")
if tokens: if tokens:
@ -63,6 +64,7 @@ def view_pkl_file(file_path):
except Exception as e: except Exception as e:
print(f"❌ 读取文件失败: {e}") print(f"❌ 读取文件失败: {e}")
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("用法: python view_pkl.py <pkl文件路径>") print("用法: python view_pkl.py <pkl文件路径>")
@ -72,5 +74,6 @@ def main():
file_path = sys.argv[1] file_path = sys.argv[1]
view_pkl_file(file_path) view_pkl_file(file_path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -7,6 +7,7 @@ import pickle
import sys import sys
import os import os
def view_token_counts(file_path): def view_token_counts(file_path):
"""查看 expressor.pkl 文件中的词汇统计""" """查看 expressor.pkl 文件中的词汇统计"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
@ -14,18 +15,18 @@ def view_token_counts(file_path):
return return
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
print(f"📁 文件: {file_path}") print(f"📁 文件: {file_path}")
print("=" * 60) print("=" * 60)
if 'nb' not in data or 'token_counts' not in data['nb']: if "nb" not in data or "token_counts" not in data["nb"]:
print("❌ 这不是一个 expressor 模型文件") print("❌ 这不是一个 expressor 模型文件")
return return
token_counts = data['nb']['token_counts'] token_counts = data["nb"]["token_counts"]
candidates = data.get('candidates', {}) candidates = data.get("candidates", {})
print(f"🎯 找到 {len(token_counts)} 个风格") print(f"🎯 找到 {len(token_counts)} 个风格")
print("=" * 60) print("=" * 60)
@ -41,7 +42,7 @@ def view_token_counts(file_path):
print("🔤 词汇统计 (按频率排序):") print("🔤 词汇统计 (按频率排序):")
for i, (word, count) in enumerate(sorted_tokens): for i, (word, count) in enumerate(sorted_tokens):
print(f" {i+1:2d}. '{word}': {count}") print(f" {i + 1:2d}. '{word}': {count}")
else: else:
print(" (无词汇数据)") print(" (无词汇数据)")
@ -50,6 +51,7 @@ def view_token_counts(file_path):
except Exception as e: except Exception as e:
print(f"❌ 读取文件失败: {e}") print(f"❌ 读取文件失败: {e}")
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("用法: python view_tokens.py <expressor.pkl文件路径>") print("用法: python view_tokens.py <expressor.pkl文件路径>")
@ -59,5 +61,6 @@ def main():
file_path = sys.argv[1] file_path = sys.argv[1]
view_token_counts(file_path) view_token_counts(file_path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()