pull/1359/head
SengokuCola 2025-11-13 19:00:59 +08:00
commit d306e40db0
46 changed files with 1000 additions and 1041 deletions

2
bot.py
View File

@ -30,7 +30,7 @@ else:
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
initialize_logging()

View File

@ -1,15 +1,11 @@
from typing import List, Tuple, Type, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField
)
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
from src.plugin_system.apis import send_api, frequency_api
class SetTalkFrequencyCommand(BaseCommand):
"""设置当前聊天的talk_frequency值"""
command_name = "set_talk_frequency"
command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
@ -19,35 +15,35 @@ class SetTalkFrequencyCommand(BaseCommand):
# 获取命令参数 - 使用命名捕获组
if not self.matched_groups or "value" not in self.matched_groups:
return False, "命令格式错误", False
value_str = self.matched_groups["value"]
if not value_str:
return False, "无法获取数值参数", False
value = float(value_str)
# 获取聊天流ID
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
return False, "无法获取聊天流信息", False
chat_id = self.message.chat_stream.stream_id
# 设置talk_frequency
frequency_api.set_talk_frequency_adjust(chat_id, value)
final_value = frequency_api.get_current_talk_value(chat_id)
adjust_value = frequency_api.get_talk_frequency_adjust(chat_id)
base_value = final_value / adjust_value
# 发送反馈消息(不保存到数据库)
await send_api.text_to_stream(
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
chat_id,
storage_message=False
storage_message=False,
)
return True, None, False
except ValueError:
error_msg = "数值格式错误,请输入有效的数字"
await self.send_text(error_msg, storage_message=False)
@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand):
class ShowFrequencyCommand(BaseCommand):
"""显示当前聊天的频率控制状态"""
command_name = "show_frequency"
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
command_pattern = r"^/chat\s+(?:show|s)$"
@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin):
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"frequency": "频率控制配置",
"features": "功能开关配置"
}
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
# 配置Schema定义
config_schema: dict = {
@ -138,13 +131,14 @@ class BetterFrequencyPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
components = []
# 根据配置决定是否注册命令组件
if self.config.get("features", {}).get("enable_commands", True):
components.extend([
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
])
components.extend(
[
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
]
)
return components

View File

@ -6,15 +6,16 @@ import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
# 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60
@ -28,16 +29,16 @@ def clean_output_text(text: str) -> str:
"""
if not text:
return text
# 移除表情包内容:[表情包:...]
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
text = re.sub(r"\[表情包:[^\]]*\]", "", text)
# 移除回复内容:[回复...],说:... 的完整模式
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
# 清理多余的空格和换行
text = re.sub(r'\s+', ' ', text).strip()
text = re.sub(r"\s+", " ", text).strip()
return text
@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
for msg in messages:
groups.setdefault(msg.chat_id, []).append(msg)
# 保证每个分组内按时间升序
for chat_id, msgs in groups.items():
for _chat_id, msgs in groups.items():
msgs.sort(key=lambda m: m.time or 0)
return groups
@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
continue
last = bucket[-1]
same_user = (msg.user_info.user_id == last.user_info.user_id)
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
same_user = msg.user_info.user_id == last.user_info.user_id
close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
if same_user and close_enough:
bucket.append(msg)
@ -199,38 +200,36 @@ def build_pairs_for_chat(
pairs: List[Tuple[str, str, str]] = []
n_merged = len(merged_messages)
n_original = len(original_messages)
if n_merged == 0 or n_original == 0:
return pairs
# 为每个合并后的消息找到对应的原始消息位置
merged_to_original_map = {}
original_idx = 0
for merged_idx, merged_msg in enumerate(merged_messages):
# 找到这个合并消息对应的第一个原始消息
while (original_idx < n_original and
original_messages[original_idx].time < merged_msg.time):
while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
original_idx += 1
# 如果找到了时间匹配的原始消息,建立映射
if (original_idx < n_original and
original_messages[original_idx].time == merged_msg.time):
if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
merged_to_original_map[merged_idx] = original_idx
for merged_idx in range(n_merged):
merged_msg = merged_messages[merged_idx]
# 如果指定了 target_user_id只处理该用户的消息作为 output
if target_user_id and merged_msg.user_info.user_id != target_user_id:
continue
# 找到对应的原始消息位置
if merged_idx not in merged_to_original_map:
continue
original_idx = merged_to_original_map[merged_idx]
# 选择上下文窗口大小
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
start = max(0, original_idx - window)
@ -266,7 +265,7 @@ def build_pairs(
groups = group_by_chat(messages)
all_pairs: List[Tuple[str, str, str]] = []
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
# 对消息进行合并用于output
merged = merge_adjacent_same_user(msgs)
# 传递原始消息和合并后消息input使用原始消息output使用合并后消息
@ -385,5 +384,3 @@ def run_interactive() -> int:
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,4 +1,3 @@
import time
import sys
import os
import matplotlib.pyplot as plt
@ -6,16 +5,17 @@ import matplotlib.dates as mdates
from datetime import datetime
from typing import List, Tuple
import numpy as np
from src.common.database.database_model import Expression, ChatStreams
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
def get_chat_name(chat_id: str) -> str:
@ -39,19 +39,14 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]:
"""获取Expression表中的数据返回(create_date, count, chat_id, expression_type)的列表"""
expressions = Expression.select()
data = []
for expr in expressions:
# 如果create_date为空跳过该记录
if expr.create_date is None:
continue
data.append((
expr.create_date,
expr.count,
expr.chat_id,
expr.type
))
data.append((expr.create_date, expr.count, expr.chat_id, expr.type))
return data
@ -60,71 +55,71 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
if not data:
print("没有找到有效的表达式数据")
return
# 分离数据
create_dates = [item[0] for item in data]
counts = [item[1] for item in data]
chat_ids = [item[2] for item in data]
expression_types = [item[3] for item in data]
_chat_ids = [item[2] for item in data]
_expression_types = [item[3] for item in data]
# 转换时间戳为datetime对象
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
# 计算时间跨度,自动调整显示格式
time_span = max(dates) - min(dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(12, 8))
# 创建散点图
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加网格
ax.grid(True, alpha=0.3)
# 添加颜色条
cbar = plt.colorbar(scatter)
cbar.set_label('数据点顺序', fontsize=10)
cbar.set_label("数据点顺序", fontsize=10)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 数据统计 ===")
print("\n=== 数据统计 ===")
print(f"总数据点数量: {len(data)}")
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')}{max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
print(f"使用次数范围: {min(counts):.1f}{max(counts):.1f}")
print(f"平均使用次数: {np.mean(counts):.2f}")
print(f"中位数使用次数: {np.median(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n散点图已保存到: {save_path}")
# 显示图片
plt.show()
@ -134,7 +129,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
if not data:
print("没有找到有效的表达式数据")
return
# 按chat_id分组
chat_groups = {}
for item in data:
@ -142,75 +137,82 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
if chat_id not in chat_groups:
chat_groups[chat_id] = []
chat_groups[chat_id].append(item)
# 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(14, 10))
# 为每个聊天分配不同颜色
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
create_dates = [item[0] for item in chat_data]
counts = [item[1] for item in chat_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
chat_name = get_chat_name(chat_id)
# 截断过长的聊天名称
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
ax.scatter(dates, counts, alpha=0.7, s=40,
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
edgecolors='black', linewidth=0.5)
ax.scatter(
dates,
counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{display_name} ({len(chat_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
# 添加网格
ax.grid(True, alpha=0.3)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 分组统计 ===")
print("\n=== 分组统计 ===")
print(f"总聊天数量: {len(chat_groups)}")
for chat_id, chat_data in chat_groups.items():
chat_name = get_chat_name(chat_id)
counts = [item[1] for item in chat_data]
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n分组散点图已保存到: {save_path}")
# 显示图片
plt.show()
@ -220,7 +222,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
if not data:
print("没有找到有效的表达式数据")
return
# 按type分组
type_groups = {}
for item in data:
@ -228,69 +230,76 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
if expr_type not in type_groups:
type_groups[expr_type] = []
type_groups[expr_type].append(item)
# 计算时间跨度,自动调整显示格式
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
time_span = max(all_dates) - min(all_dates)
if time_span.days > 30: # 超过30天按月显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.MonthLocator()
minor_locator = mdates.DayLocator(interval=7)
elif time_span.days > 7: # 超过7天按天显示
date_format = '%Y-%m-%d'
date_format = "%Y-%m-%d"
major_locator = mdates.DayLocator(interval=1)
minor_locator = mdates.HourLocator(interval=12)
else: # 7天内按小时显示
date_format = '%Y-%m-%d %H:%M'
date_format = "%Y-%m-%d %H:%M"
major_locator = mdates.HourLocator(interval=6)
minor_locator = mdates.HourLocator(interval=1)
# 创建图形
fig, ax = plt.subplots(figsize=(12, 8))
# 为每个类型分配不同颜色
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
for i, (expr_type, type_data) in enumerate(type_groups.items()):
create_dates = [item[0] for item in type_data]
counts = [item[1] for item in type_data]
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
ax.scatter(dates, counts, alpha=0.7, s=40,
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
edgecolors='black', linewidth=0.5)
ax.scatter(
dates,
counts,
alpha=0.7,
s=40,
c=[colors[i]],
label=f"{expr_type} ({len(type_data)}个)",
edgecolors="black",
linewidth=0.5,
)
# 设置标签和标题
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
ax.set_ylabel('使用次数 (Count)', fontsize=12)
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
ax.set_ylabel("使用次数 (Count)", fontsize=12)
ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
# 设置x轴日期格式 - 根据时间跨度自动调整
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_minor_locator(minor_locator)
plt.xticks(rotation=45)
# 添加图例
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
# 添加网格
ax.grid(True, alpha=0.3)
# 调整布局
plt.tight_layout()
# 显示统计信息
print(f"\n=== 类型统计 ===")
print("\n=== 类型统计 ===")
for expr_type, type_data in type_groups.items():
counts = [item[1] for item in type_data]
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
# 保存图片
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"\n类型散点图已保存到: {save_path}")
# 显示图片
plt.show()
@ -298,35 +307,35 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
def main():
"""主函数"""
print("开始分析表达式数据...")
# 获取数据
data = get_expression_data()
if not data:
print("没有找到有效的表达式数据create_date不为空的数据")
return
print(f"找到 {len(data)} 条有效数据")
# 创建输出目录
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
# 生成时间戳用于文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# 1. 创建基础散点图
print("\n1. 创建基础散点图...")
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
# 2. 创建按聊天分组的散点图
print("\n2. 创建按聊天分组的散点图...")
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
# 3. 创建按类型分组的散点图
print("\n3. 创建按类型分组的散点图...")
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
print("\n分析完成!")

View File

@ -945,9 +945,7 @@ class EmojiManager:
prompt, image_base64, "jpg", temperature=0.5
)
else:
prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
)
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析精简回答"
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.5
)

View File

@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.plugin_system.apis import frequency_api
def init_prompt():
Prompt(
"""{name_block}
@ -28,7 +29,7 @@ def init_prompt():
""",
"frequency_adjust_prompt",
)
logger = get_logger("frequency_control")
@ -40,7 +41,7 @@ class FrequencyControl:
self.chat_id = chat_id
# 发言频率调整值
self.talk_frequency_adjust: float = 1.0
self.last_frequency_adjust_time: float = 0.0
self.frequency_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
@ -53,16 +54,14 @@ class FrequencyControl:
def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value))
async def trigger_frequency_adjust(self) -> None:
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=self.last_frequency_adjust_time,
timestamp_end=time.time(),
)
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
return
else:
@ -73,7 +72,7 @@ class FrequencyControl:
limit=20,
limit_mode="latest",
)
message_str = build_readable_messages(
new_msg_list,
replace_bot_name=True,
@ -97,15 +96,15 @@ class FrequencyControl:
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
prompt,
)
# logger.info(f"频率调整 prompt: {prompt}")
# logger.info(f"频率调整 response: {response}")
if global_config.debug.show_prompt:
logger.info(f"频率调整 prompt: {prompt}")
logger.info(f"频率调整 response: {response}")
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
# LLM依然输出过多内容时取消本次调整。合法最多4个字但有的模型可能会输出一些markdown换行符等需要长度宽限
@ -118,7 +117,8 @@ class FrequencyControl:
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
self.last_frequency_adjust_time = time.time()
else:
logger.info(f"频率调整response不符合要求取消本次调整")
logger.info("频率调整response不符合要求取消本次调整")
class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例"""
@ -143,6 +143,7 @@ class FrequencyControlManager:
"""获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys())
init_prompt()
# 创建全局实例

View File

@ -1,5 +1,4 @@
import asyncio
from multiprocessing import context
import time
import traceback
import random
@ -102,14 +101,14 @@ class HeartFChatting:
self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间
self.last_active_time = time.time() # 记录上一次非noreply时间
self.question_probability_multiplier = 1
self.questioned = False
# 跟踪连续 no_reply 次数,用于动态调整阈值
self.consecutive_no_reply_count = 0
# 聊天内容概括器
self.chat_history_summarizer = ChatHistorySummarizer(chat_id=self.stream_id)
@ -127,10 +126,10 @@ class HeartFChatting:
self._loop_task = asyncio.create_task(self._main_chat_loop())
self._loop_task.add_done_callback(self._handle_loop_completion)
# 启动聊天内容概括器的后台定期检查循环
await self.chat_history_summarizer.start()
logger.info(f"{self.log_prefix} HeartFChatting 启动完成")
except Exception as e:
@ -180,7 +179,7 @@ class HeartFChatting:
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def _loopbody(self):
async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
@ -191,9 +190,6 @@ class HeartFChatting:
filter_command=True,
)
# 根据连续 no_reply 次数动态调整阈值
# 3次 no_reply 时,阈值调高到 1.550%概率为150%概率为2
# 5次 no_reply 时,提高到 2大于等于两条消息的阈值
@ -204,10 +200,10 @@ class HeartFChatting:
threshold = 2 if random.random() < 0.5 else 1
else:
threshold = 1
if len(recent_messages_list) >= threshold:
# for message in recent_messages_list:
# print(message.processed_plain_text)
# print(message.processed_plain_text)
# !处理no_reply_until_call逻辑
if self.no_reply_until_call:
for message in recent_messages_list:
@ -337,7 +333,7 @@ class HeartFChatting:
# 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0
reason = "有人提到了你,进行回复"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
@ -395,15 +391,16 @@ class HeartFChatting:
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust())
asyncio.create_task(
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
)
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
# asyncio.create_task(check_and_make_question(self.stream_id))
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
@ -411,8 +408,7 @@ class HeartFChatting:
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
# 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
# asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@ -427,7 +423,7 @@ class HeartFChatting:
# 如果被提及让回复生成和planner并行执行
if force_reply_message:
logger.info(f"{self.log_prefix} 检测到提及回复生成与planner并行执行")
# 并行执行planner和回复生成
planner_task = asyncio.create_task(
self._run_planner_without_reply(
@ -457,7 +453,12 @@ class HeartFChatting:
# 处理回复结果
if isinstance(reply_result, BaseException):
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
reply_result = {"action_type": "reply", "success": False, "result": "回复生成异常", "loop_info": None}
reply_result = {
"action_type": "reply",
"success": False,
"result": "回复生成异常",
"loop_info": None,
}
else:
# 正常流程只执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
@ -516,7 +517,7 @@ class HeartFChatting:
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 如果有独立的回复结果,添加到结果列表中
if reply_result:
results = list(results) + [reply_result]
@ -558,7 +559,7 @@ class HeartFChatting:
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
_reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
@ -571,7 +572,7 @@ class HeartFChatting:
"taken_time": time.time(),
},
}
reply_text = action_reply_text
_reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@ -647,7 +648,6 @@ class HeartFChatting:
result = await action_handler.execute()
success, action_text = result
return success, action_text
except Exception as e:
@ -655,8 +655,6 @@ class HeartFChatting:
traceback.print_exc()
return False, ""
async def _send_response(
self,
reply_set: "ReplySetModel",
@ -732,7 +730,6 @@ class HeartFChatting:
action_reasoning=reason,
)
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call":
@ -753,7 +750,12 @@ class HeartFChatting:
action_name="no_reply_until_call",
action_reasoning=reason,
)
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
return {
"action_type": "no_reply_until_call",
"success": True,
"result": "保持沉默,直到有人直接叫的名字",
"command": "",
}
elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑
@ -783,19 +785,16 @@ class HeartFChatting:
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
@ -817,12 +816,12 @@ class HeartFChatting:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, result = await self._handle_action(
action = action_planner_info.action_type,
action_reasoning = action_planner_info.action_reasoning or "",
action_data = action_planner_info.action_data or {},
cycle_timers = cycle_timers,
thinking_id = thinking_id,
action_message= action_planner_info.action_message,
action=action_planner_info.action_type,
action_reasoning=action_planner_info.action_reasoning or "",
action_data=action_planner_info.action_data or {},
cycle_timers=cycle_timers,
thinking_id=thinking_id,
action_message=action_planner_info.action_message,
)
self.last_active_time = time.time()

View File

@ -13,10 +13,11 @@ from src.person_info.person_info import Person
from src.common.database.database_model import Images
if TYPE_CHECKING:
from src.chat.heart_flow.heartFC_chat import HeartFChatting
pass
logger = get_logger("chat")
class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""

View File

@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.person_info.person_info import Person
# 定义日志配置
@ -171,7 +170,11 @@ class ChatBot:
# 撤回事件打印;无法获取被撤回者则省略
if sub_type == "recall":
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None))
op_name = (
getattr(op, "user_cardname", None)
or getattr(op, "user_nickname", None)
or str(getattr(op, "user_id", None))
)
recalled_name = None
try:
if isinstance(recalled, dict):
@ -189,7 +192,7 @@ class ChatBot:
logger.info(f"{op_name} 撤回了消息")
else:
logger.debug(
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) "
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) "
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
)
except Exception:
@ -234,7 +237,6 @@ class ChatBot:
# 确保所有任务已启动
await self._ensure_started()
if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"]

View File

@ -143,7 +143,6 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
def find_message_by_id(
@ -306,7 +305,9 @@ class ActionPlanner:
loop_start_time=loop_start_time,
)
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
logger.info(
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
self.add_plan_log(reasoning, actions)
@ -316,7 +317,7 @@ class ActionPlanner:
self.plan_log.append((reasoning, time.time(), actions))
if len(self.plan_log) > 20:
self.plan_log.pop(0)
def add_plan_excute_log(self, result: str):
self.plan_log.append(("", time.time(), result))
if len(self.plan_log) > 20:
@ -325,17 +326,17 @@ class ActionPlanner:
def get_plan_log_str(self, max_action_records: int = 2, max_execution_records: int = 5) -> str:
"""
获取计划日志字符串
Args:
max_action_records: 显示多少条最新的action记录默认2
max_execution_records: 显示多少条最新执行结果记录默认8
Returns:
格式化的日志字符串
"""
action_records = []
execution_records = []
# 从后往前遍历,收集最新的记录
for reasoning, timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
@ -346,13 +347,13 @@ class ActionPlanner:
# 这是执行结果记录
if len(execution_records) < max_execution_records:
execution_records.append((reasoning, timestamp, content, "execution"))
# 合并所有记录并按时间戳排序
all_records = action_records + execution_records
all_records.sort(key=lambda x: x[1]) # 按时间戳排序
plan_log_str = ""
# 按时间顺序添加所有记录
for reasoning, timestamp, content, record_type in all_records:
time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
@ -361,21 +362,21 @@ class ActionPlanner:
plan_log_str += f"{time_str}:{reasoning}\n"
else:
plan_log_str += f"{time_str}:你执行了action:{content}\n"
return plan_log_str
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
"""
检查是否有连续min_count次以上的no_reply
Args:
min_count: 需要连续的最少次数默认3
Returns:
如果有连续min_count次以上no_reply返回True否则返回False
"""
consecutive_count = 0
# 从后往前遍历plan_log检查最新的连续记录
for _reasoning, _timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
@ -387,7 +388,7 @@ class ActionPlanner:
else:
# 如果遇到非no_reply的action重置计数
break
return False
async def build_planner_prompt(
@ -402,8 +403,7 @@ class ActionPlanner:
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
actions_before_now_block=self.get_plan_log_str()
actions_before_now_block = self.get_plan_log_str()
# 构建聊天上下文描述
chat_context_description = "你现在正在一个群聊中"
@ -537,7 +537,7 @@ class ActionPlanner:
for require_item in action_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
if not action_info.parallel_action:
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
else:
@ -564,7 +564,7 @@ class ActionPlanner:
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> Tuple[str,List[ActionPlannerInfo]]:
) -> Tuple[str, List[ActionPlannerInfo]]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
@ -589,7 +589,7 @@ class ActionPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return f"LLM 请求失败,模型出现问题: {req_e}",[
return f"LLM 请求失败,模型出现问题: {req_e}", [
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
@ -608,7 +608,11 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning))
actions.extend(
self._parse_single_action(
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
)
)
else:
# 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
@ -631,7 +635,7 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning,actions
return extracted_reasoning, actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply"""
@ -674,7 +678,7 @@ class ActionPlanner:
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip():
# 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split('\n') if line.strip()]
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines:
try:
# 尝试解析每一行作为独立的JSON对象
@ -688,7 +692,7 @@ class ActionPlanner:
except json.JSONDecodeError:
# 如果单行解析失败尝试将整个块作为一个JSON对象或数组
pass
# 如果按行解析没有成功尝试将整个块作为一个JSON对象或数组
if not json_objects:
json_obj = json.loads(repair_json(json_str))

View File

@ -134,12 +134,12 @@ class DefaultReplyer:
try:
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
# logger.debug(f"replyer生成内容: {content}")
logger.info(f"replyer生成内容: {content}")
if global_config.debug.show_replyer_reasoning:
logger.info(f"replyer生成推理:\n{reasoning_content}")
logger.info(f"replyer生成模型: {model_name}")
llm_response.content = content
llm_response.reasoning = reasoning_content
llm_response.model = model_name
@ -268,14 +268,13 @@ class DefaultReplyer:
expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示"""
if not global_config.mood.enable_mood:
return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -303,7 +302,7 @@ class DefaultReplyer:
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
_result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}】: {content}\n"
@ -343,45 +342,45 @@ class DefaultReplyer:
def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述
Args:
text: 包含picid标记的文本
Returns:
替换后的文本
"""
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式
Args:
target: 目标消息内容包含[picid:xxx]格式
Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
"""
if not target or not target.strip():
return False, False, "", ""
# 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式)
pic_part = ""
if picid_matches:
@ -396,7 +395,7 @@ class DefaultReplyer:
else:
pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
@ -481,7 +480,7 @@ class DefaultReplyer:
)
return all_dialogue_prompt
def core_background_build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
@ -603,25 +602,27 @@ class DefaultReplyer:
# 获取基础personality
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (global_config.personality.states and
global_config.personality.state_probability > 0 and
random.random() < global_config.personality.state_probability):
if (
global_config.personality.states
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state
prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
"""
解析聊天prompt配置字符串并生成对应的 chat_id prompt内容
Args:
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
Returns:
tuple: (chat_id, prompt_content)如果解析失败则返回 None
"""
@ -657,10 +658,10 @@ class DefaultReplyer:
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
"""
根据聊天流ID获取匹配的额外prompt仅匹配group类型
Args:
chat_id: 聊天流ID哈希值
Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串
"""
@ -670,21 +671,21 @@ class DefaultReplyer:
for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str):
continue
# 解析配置字符串检查类型是否为group
parts = chat_prompt_str.split(":", 3)
if len(parts) != 4:
continue
stream_type = parts[2]
# 只匹配group类型
if stream_type != "group":
continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
if result is None:
continue
config_chat_id, prompt_content = result
if config_chat_id == chat_id:
logger.debug(f"匹配到群聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
@ -720,7 +721,7 @@ class DefaultReplyer:
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
_is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform
user_id = "用户ID"
@ -736,10 +737,10 @@ class DefaultReplyer:
target = reply_message.processed_plain_text
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
@ -911,10 +912,10 @@ class DefaultReplyer:
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
@ -956,9 +957,7 @@ class DefaultReplyer:
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
# 只包含文字
reply_target_block = (
@ -975,7 +974,9 @@ class DefaultReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1132,6 +1133,7 @@ class DefaultReplyer:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素

View File

@ -46,6 +46,7 @@ init_memory_retrieval_prompt()
logger = get_logger("replyer")
class PrivateReplyer:
def __init__(
self,
@ -277,9 +278,7 @@ class PrivateReplyer:
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
)
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
@ -291,7 +290,6 @@ class PrivateReplyer:
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -358,45 +356,45 @@ class PrivateReplyer:
def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述
Args:
text: 包含picid标记的文本
Returns:
替换后的文本
"""
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式
Args:
target: 目标消息内容包含[picid:xxx]格式
Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
"""
if not target or not target.strip():
return False, False, "", ""
# 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式)
pic_part = ""
if picid_matches:
@ -411,7 +409,7 @@ class PrivateReplyer:
else:
pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
@ -517,25 +515,27 @@ class PrivateReplyer:
# 获取基础personality
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (global_config.personality.states and
global_config.personality.state_probability > 0 and
random.random() < global_config.personality.state_probability):
if (
global_config.personality.states
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
# 随机选择一个状态替换personality
selected_state = random.choice(global_config.personality.states)
prompt_personality = selected_state
prompt_personality = f"{prompt_personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
def _parse_chat_prompt_config_to_chat_id(self, chat_prompt_str: str) -> Optional[tuple[str, str]]:
"""
解析聊天prompt配置字符串并生成对应的 chat_id prompt内容
Args:
chat_prompt_str: 格式为 "platform:id:type:prompt内容" 的字符串
Returns:
tuple: (chat_id, prompt_content)如果解析失败则返回 None
"""
@ -571,10 +571,10 @@ class PrivateReplyer:
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
"""
根据聊天流ID获取匹配的额外prompt仅匹配private类型
Args:
chat_id: 聊天流ID哈希值
Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串
"""
@ -584,21 +584,21 @@ class PrivateReplyer:
for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str):
continue
# 解析配置字符串检查类型是否为private
parts = chat_prompt_str.split(":", 3)
if len(parts) != 4:
continue
stream_type = parts[2]
# 只匹配private类型
if stream_type != "private":
continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
if result is None:
continue
config_chat_id, prompt_content = result
if config_chat_id == chat_id:
logger.debug(f"匹配到私聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
@ -647,13 +647,11 @@ class PrivateReplyer:
sender = person_name
target = reply_message.processed_plain_text
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
@ -662,7 +660,7 @@ class PrivateReplyer:
timestamp=time.time(),
limit=global_config.chat.max_context_size,
)
dialogue_prompt = build_readable_messages(
message_list_before_now_long,
replace_bot_name=True,
@ -710,9 +708,7 @@ class PrivateReplyer:
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
self._time_and_run_task(
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
),
self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
@ -852,15 +848,13 @@ class PrivateReplyer:
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
@ -900,9 +894,7 @@ class PrivateReplyer:
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
# 只包含文字
reply_target_block = (
@ -919,7 +911,9 @@ class PrivateReplyer:
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
@ -1010,7 +1004,7 @@ class PrivateReplyer:
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
content = content.strip()
logger.info(f"使用 {model_name} 生成回复内容: {content}")
@ -1106,6 +1100,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx)
break
return selected

View File

@ -1,16 +1,13 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
def init_replyer_prompt():
Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片:
@ -27,10 +24,9 @@ def init_replyer_prompt():
现在你说""",
"replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容:
@ -46,10 +42,9 @@ def init_replyer_prompt():
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""",
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}
你正在和{sender_name}聊天这是你们之前聊的内容:
@ -65,4 +60,4 @@ def init_replyer_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )
""",
"private_replyer_self_prompt",
)
)

View File

@ -2,6 +2,7 @@
聊天内容概括器
用于累积打包和压缩聊天记录
"""
import asyncio
import json
import time
@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer")
@dataclass
class MessageBatch:
"""消息批次"""
messages: List[DatabaseMessages]
start_time: float
end_time: float
@ -31,11 +33,11 @@ class MessageBatch:
class ChatHistorySummarizer:
"""聊天内容概括器"""
def __init__(self, chat_id: str, check_interval: int = 60):
"""
初始化聊天内容概括器
Args:
chat_id: 聊天ID
check_interval: 定期检查间隔默认60秒
@ -43,24 +45,23 @@ class ChatHistorySummarizer:
self.chat_id = chat_id
self._chat_display_name = self._get_chat_display_name()
self.log_prefix = f"[{self._chat_display_name}]"
# 记录时间点,用于计算新消息
self.last_check_time = time.time()
# 当前累积的消息批次
self.current_batch: Optional[MessageBatch] = None
# LLM请求器用于压缩聊天内容
self.summarizer_llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="chat_history_summarizer"
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
)
# 后台循环相关
self.check_interval = check_interval # 检查间隔(秒)
self._periodic_task: Optional[asyncio.Task] = None
self._running = False
def _get_chat_display_name(self) -> str:
"""获取聊天显示名称"""
try:
@ -76,17 +77,17 @@ class ChatHistorySummarizer:
if len(self.chat_id) > 20:
return f"{self.chat_id[:8]}..."
return self.chat_id
async def process(self, current_time: Optional[float] = None):
"""
处理聊天内容概括
Args:
current_time: 当前时间戳如果为None则使用time.time()
"""
if current_time is None:
current_time = time.time()
try:
logger.info(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
@ -101,25 +102,23 @@ class ChatHistorySummarizer:
filter_mai=False, # 不过滤bot消息因为需要检查bot是否发言
filter_command=False,
)
if not new_messages:
# 没有新消息,检查是否需要打包
if self.current_batch and self.current_batch.messages:
await self._check_and_package(current_time)
self.last_check_time = current_time
return
# 有新消息,更新最后检查时间
self.last_check_time = current_time
# 如果有当前批次,添加新消息
if self.current_batch:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(
f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
else:
# 创建新批次
self.current_batch = MessageBatch(
@ -127,23 +126,22 @@ class ChatHistorySummarizer:
start_time=new_messages[0].time if new_messages else current_time,
end_time=current_time,
)
logger.info(
f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息"
)
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
# 检查是否需要打包
await self._check_and_package(current_time)
except Exception as e:
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
import traceback
traceback.print_exc()
async def _check_and_package(self, current_time: float):
"""检查是否需要打包"""
if not self.current_batch or not self.current_batch.messages:
return
messages = self.current_batch.messages
message_count = len(messages)
last_message_time = messages[-1].time if messages else current_time
@ -153,48 +151,48 @@ class ChatHistorySummarizer:
if time_since_last_message < 60:
time_str = f"{time_since_last_message:.1f}"
elif time_since_last_message < 3600:
time_str = f"{time_since_last_message/60:.1f}分钟"
time_str = f"{time_since_last_message / 60:.1f}分钟"
else:
time_str = f"{time_since_last_message/3600:.1f}小时"
time_str = f"{time_since_last_message / 3600:.1f}小时"
preparing_status = "" if self.current_batch.is_preparing else ""
logger.info(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距最后消息: {time_str} | 准备结束模式: {preparing_status}"
)
# 检查打包条件
should_package = False
# 条件1: 消息长度超过120直接打包
if message_count >= 120:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 消息数量达到 {message_count} 条(阈值: 120条")
# 条件2: 最后一条消息的时间和当前时间差>600秒直接打包
elif time_since_last_message > 600:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 距最后消息 {time_str}(阈值: 10分钟")
# 条件3: 消息长度超过100进入准备结束模式
elif message_count > 100:
if not self.current_batch.is_preparing:
self.current_batch.is_preparing = True
logger.info(f"{self.log_prefix} 消息数量 {message_count} 条超过阈值100条进入准备结束模式")
# 在准备结束模式下,如果最后一条消息的时间和当前时间差>10秒就打包
if time_since_last_message > 10:
should_package = True
logger.info(f"{self.log_prefix} 触发打包条件: 准备结束模式下,距最后消息 {time_str}(阈值: 10秒")
if should_package:
await self._package_and_store()
async def _package_and_store(self):
"""打包并存储聊天记录"""
if not self.current_batch or not self.current_batch.messages:
return
messages = self.current_batch.messages
start_time = self.current_batch.start_time
end_time = self.current_batch.end_time
@ -202,12 +200,12 @@ class ChatHistorySummarizer:
logger.info(
f"{self.log_prefix} 开始打包批次 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
)
# 检查是否有bot发言
# 第一条消息前推600s到最后一条消息的时间内
check_start_time = max(start_time - 600, 0)
check_end_time = end_time
# 使用包含边界的时间范围查询
bot_messages = message_api.get_messages_by_time_in_chat_inclusive(
chat_id=self.chat_id,
@ -218,7 +216,7 @@ class ChatHistorySummarizer:
filter_mai=False,
filter_command=False,
)
# 检查是否有bot的发言
has_bot_message = False
bot_user_id = str(global_config.bot.qq_account)
@ -226,14 +224,14 @@ class ChatHistorySummarizer:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
break
if not has_bot_message:
logger.info(
f"{self.log_prefix} 批次内无Bot发言丢弃批次 | 检查时间范围: {check_start_time:.2f} - {check_end_time:.2f}"
)
self.current_batch = None
return
# 有bot发言进行压缩和存储
try:
# 构建对话原文
@ -245,39 +243,36 @@ class ChatHistorySummarizer:
truncate=False,
show_actions=False,
)
# 获取参与的所有人的昵称
participants_set: Set[str] = set()
for msg in messages:
# 使用 msg.user_platform扁平化字段或 msg.user_info.platform
platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform
person = Person(
platform=platform,
user_id=msg.user_info.user_id
platform = (
getattr(msg, "user_platform", None)
or (msg.user_info.platform if msg.user_info else None)
or msg.chat_info.platform
)
person = Person(platform=platform, user_id=msg.user_info.user_id)
person_name = person.person_name
if person_name:
participants_set.add(person_name)
participants = list(participants_set)
logger.info(
f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}"
)
logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
# 使用LLM压缩聊天内容
success, theme, keywords, summary = await self._compress_with_llm(original_text)
if not success:
logger.warning(
f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}"
)
logger.warning(f"{self.log_prefix} LLM压缩失败不存储到数据库 | 消息数: {len(messages)}")
# 清空当前批次,避免重复处理
self.current_batch = None
return
logger.info(
f"{self.log_prefix} LLM压缩完成 | 主题: {theme} | 关键词数: {len(keywords)} | 概括长度: {len(summary)}"
)
# 存储到数据库
await self._store_to_database(
start_time=start_time,
@ -288,23 +283,24 @@ class ChatHistorySummarizer:
keywords=keywords,
summary=summary,
)
logger.info(f"{self.log_prefix} 成功打包并存储聊天记录 | 消息数: {len(messages)} | 主题: {theme}")
# 清空当前批次
self.current_batch = None
except Exception as e:
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
import traceback
traceback.print_exc()
# 出错时也清空批次,避免重复处理
self.current_batch = None
async def _compress_with_llm(self, original_text: str) -> tuple[bool, str, List[str], str]:
"""
使用LLM压缩聊天内容
Returns:
tuple[bool, str, List[str], str]: (是否成功, 主题, 关键词列表, 概括)
"""
@ -325,37 +321,37 @@ class ChatHistorySummarizer:
{original_text}
请直接返回JSON不要包含其他内容"""
try:
response, _ = await self.summarizer_llm.generate_response_async(
prompt=prompt,
temperature=0.3,
max_tokens=500,
)
# 解析JSON响应
import re
# 移除可能的markdown代码块标记
json_str = response.strip()
json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE)
json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE)
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
json_str = json_str.strip()
# 尝试找到JSON对象的开始和结束位置
# 查找第一个 { 和最后一个匹配的 }
start_idx = json_str.find('{')
start_idx = json_str.find("{")
if start_idx == -1:
raise ValueError("未找到JSON对象开始标记")
# 从后往前查找最后一个 }
end_idx = json_str.rfind('}')
end_idx = json_str.rfind("}")
if end_idx == -1 or end_idx <= start_idx:
raise ValueError("未找到JSON对象结束标记")
# 提取JSON字符串
json_str = json_str[start_idx:end_idx + 1]
json_str = json_str[start_idx : end_idx + 1]
# 尝试解析JSON
try:
result = json.loads(json_str)
@ -372,7 +368,7 @@ class ChatHistorySummarizer:
if escape_next:
fixed_chars.append(char)
escape_next = False
elif char == '\\':
elif char == "\\":
fixed_chars.append(char)
escape_next = True
elif char == '"' and not escape_next:
@ -384,27 +380,27 @@ class ChatHistorySummarizer:
else:
fixed_chars.append(char)
i += 1
json_str = ''.join(fixed_chars)
json_str = "".join(fixed_chars)
# 再次尝试解析
result = json.loads(json_str)
theme = result.get("theme", "未命名对话")
keywords = result.get("keywords", [])
summary = result.get("summary", "无概括")
# 确保keywords是列表
if isinstance(keywords, str):
keywords = [keywords]
return True, theme, keywords, summary
except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回失败标志和默认值
return False, "未命名对话", [], "压缩失败,无法生成概括"
async def _store_to_database(
self,
start_time: float,
@ -419,7 +415,7 @@ class ChatHistorySummarizer:
try:
from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api
# 准备数据
data = {
"chat_id": self.chat_id,
@ -432,7 +428,7 @@ class ChatHistorySummarizer:
"summary": summary,
"count": 0,
}
# 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time
@ -441,28 +437,29 @@ class ChatHistorySummarizer:
ChatHistory,
data=data,
)
if saved_record:
logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库")
else:
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
except Exception as e:
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
import traceback
traceback.print_exc()
raise
async def start(self):
"""启动后台定期检查循环"""
if self._running:
logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动")
return
self._running = True
self._periodic_task = asyncio.create_task(self._periodic_check_loop())
logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}")
async def stop(self):
"""停止后台定期检查循环"""
self._running = False
@ -474,14 +471,14 @@ class ChatHistorySummarizer:
pass
self._periodic_task = None
logger.info(f"{self.log_prefix} 已停止后台定期检查循环")
async def _periodic_check_loop(self):
"""后台定期检查循环"""
try:
while self._running:
# 执行一次检查
await self.process()
# 等待指定间隔后再次检查
await asyncio.sleep(self.check_interval)
except asyncio.CancelledError:
@ -490,6 +487,6 @@ class ChatHistorySummarizer:
except Exception as e:
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
import traceback
traceback.print_exc()
self._running = False

View File

@ -2,7 +2,7 @@ import time
import random
import re
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from src.config.config import global_config
@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = []
current_time = time.time()
for action in actions:
action_time = action.time or current_time
action_name = action.action_name or "未知动作"
@ -595,7 +594,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line)
return "\n".join(output_lines)
@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)

View File

@ -2,6 +2,7 @@
记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
"""
import time
import random
from typing import List
@ -15,27 +16,27 @@ logger = get_logger("memory_forget_task")
class MemoryForgetTask(AsyncTask):
"""记忆遗忘任务每5分钟执行一次"""
def __init__(self):
# 每5分钟执行一次300秒
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
async def run(self):
"""执行遗忘检查"""
try:
current_time = time.time()
logger.info("[记忆遗忘] 开始遗忘检查...")
# 执行4个阶段的遗忘检查
await self._forget_stage_1(current_time)
await self._forget_stage_2(current_time)
await self._forget_stage_3(current_time)
await self._forget_stage_4(current_time)
logger.info("[记忆遗忘] 遗忘检查完成")
except Exception as e:
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
async def _forget_stage_1(self, current_time: float):
"""
第一次遗忘检查
@ -45,38 +46,34 @@ class MemoryForgetTask(AsyncTask):
try:
# 30分钟 = 1800秒
time_threshold = current_time - 1800
# 查询符合条件的记忆forget_times=0 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 0) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高25%和最低25%
total_count = len(candidates)
delete_count = int(total_count * 0.25) # 25%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段1] 删除数量为0跳过")
return
# 选择要删除的记录处理count相同的情况随机选择
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重避免重复删除使用id去重
seen_ids = set()
unique_to_delete = []
@ -85,7 +82,7 @@ class MemoryForgetTask(AsyncTask):
seen_ids.add(record.id)
unique_to_delete.append(record)
to_delete = unique_to_delete
# 删除记录并更新forget_times
deleted_count = 0
for record in to_delete:
@ -94,22 +91,22 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
# 更新剩余记录的forget_times为1
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
# 批量更新
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=1).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1")
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
async def _forget_stage_2(self, current_time: float):
"""
第二次遗忘检查
@ -119,41 +116,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 8小时 = 28800秒
time_threshold = current_time - 28800
# 查询符合条件的记忆forget_times=1 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 1) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高7%和最低7%
total_count = len(candidates)
delete_count = int(total_count * 0.07) # 7%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段2] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@ -162,21 +155,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
# 更新剩余记录的forget_times为2
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=2).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2")
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
async def _forget_stage_3(self, current_time: float):
"""
第三次遗忘检查
@ -186,41 +179,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 48小时 = 172800秒
time_threshold = current_time - 172800
# 查询符合条件的记忆forget_times=2 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 2) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高5%和最低5%
total_count = len(candidates)
delete_count = int(total_count * 0.05) # 5%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段3] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@ -229,21 +218,21 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
# 更新剩余记录的forget_times为3
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=3).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3")
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
async def _forget_stage_4(self, current_time: float):
"""
第四次遗忘检查
@ -253,41 +242,37 @@ class MemoryForgetTask(AsyncTask):
try:
# 7天 = 604800秒
time_threshold = current_time - 604800
# 查询符合条件的记忆forget_times=3 且 end_time < time_threshold
candidates = list(
ChatHistory.select()
.where(
(ChatHistory.forget_times == 3) &
(ChatHistory.end_time < time_threshold)
)
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
)
if not candidates:
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
return
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
# 按count排序
candidates.sort(key=lambda x: x.count, reverse=True)
# 计算要删除的数量最高2%和最低2%
total_count = len(candidates)
delete_count = int(total_count * 0.02) # 2%
if delete_count == 0:
logger.debug("[记忆遗忘-阶段4] 删除数量为0跳过")
return
# 选择要删除的记录
to_delete = []
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
# 去重
to_delete = list(set(to_delete))
# 删除记录
deleted_count = 0
for record in to_delete:
@ -296,38 +281,40 @@ class MemoryForgetTask(AsyncTask):
deleted_count += 1
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
# 更新剩余记录的forget_times为4
to_delete_ids = {r.id for r in to_delete}
remaining = [r for r in candidates if r.id not in to_delete_ids]
if remaining:
ids_to_update = [r.id for r in remaining]
ChatHistory.update(forget_times=4).where(
ChatHistory.id.in_(ids_to_update)
).execute()
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4")
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
logger.info(
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
)
except Exception as e:
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]:
def _handle_same_count_random(
self, candidates: List[ChatHistory], delete_count: int, mode: str
) -> List[ChatHistory]:
"""
处理count相同的情况随机选择要删除的记录
Args:
candidates: 候选记录列表已按count排序
delete_count: 要删除的数量
mode: "high" 表示选择最高count的记录"low" 表示选择最低count的记录
Returns:
要删除的记录列表
"""
if not candidates or delete_count == 0:
return []
to_delete = []
if mode == "high":
# 从最高count开始选择
start_idx = 0
@ -339,7 +326,7 @@ class MemoryForgetTask(AsyncTask):
while idx < len(candidates) and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx += 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
@ -347,9 +334,9 @@ class MemoryForgetTask(AsyncTask):
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
else: # mode == "low"
# 从最低count开始选择
start_idx = len(candidates) - 1
@ -361,7 +348,7 @@ class MemoryForgetTask(AsyncTask):
while idx >= 0 and candidates[idx].count == current_count:
same_count_records.append(candidates[idx])
idx -= 1
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
needed = delete_count - len(to_delete)
if len(same_count_records) <= needed:
@ -369,8 +356,7 @@ class MemoryForgetTask(AsyncTask):
else:
# 随机选择需要的数量
to_delete.extend(random.sample(same_count_records, needed))
start_idx = idx
return to_delete
start_idx = idx
return to_delete

View File

@ -153,7 +153,7 @@ def _format_large_number(num: float | int, html: bool = False) -> str:
else:
number_part = f"{value:.1f}"
k_suffix = "K"
if html:
# HTML输出K着色为主题色并加粗大写
return f"{number_part}<span style='color: #8b5cf6; font-weight: bold;'>K</span>"
@ -502,9 +502,13 @@ class StatisticOutputTask(AsyncTask):
}
for period_key, _ in collect_period
}
# 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
@ -547,7 +551,7 @@ class StatisticOutputTask(AsyncTask):
is_bot_reply = False
if bot_qq_account and message.user_id == bot_qq_account:
is_bot_reply = True
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
continue
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
self.stat_period = [
item for item in self.stat_period if item[0] != "all_time"
] # 删除"所有时间"的统计时段
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
except Exception as e:
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
@ -640,12 +646,12 @@ class StatisticOutputTask(AsyncTask):
# 更新上次完整统计数据的时间戳
# 将所有defaultdict转换为普通dict以避免类型冲突
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
# 将 name_mapping 中的元组转换为列表因为JSON不支持元组
json_safe_name_mapping = {}
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
local_storage["last_full_statistics"] = {
"name_mapping": json_safe_name_mapping,
"stat_data": clean_stat_data,
@ -682,24 +688,28 @@ class StatisticOutputTask(AsyncTask):
"""
# 计算总token数从所有模型的token数中累加
total_tokens = sum(stats[TOTAL_TOK_BY_MODEL].values()) if stats[TOTAL_TOK_BY_MODEL] else 0
# 计算花费/消息数量指标每100条
cost_per_100_messages = (stats[TOTAL_COST] / stats[TOTAL_MSG_CNT] * 100) if stats[TOTAL_MSG_CNT] > 0 else 0.0
# 计算花费/时间指标(花费/小时)
online_hours = stats[ONLINE_TIME] / 3600.0 if stats[ONLINE_TIME] > 0 else 0.0
cost_per_hour = stats[TOTAL_COST] / online_hours if online_hours > 0 else 0.0
# 计算token/时间指标token/小时)
tokens_per_hour = (total_tokens / online_hours) if online_hours > 0 else 0.0
# 计算花费/回复数量指标每100条
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
cost_per_100_replies = (stats[TOTAL_COST] / total_replies * 100) if total_replies > 0 else 0.0
# 计算花费/消息数量排除自己回复指标每100条
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0
cost_per_100_messages_excluding_replies = (
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
if total_messages_excluding_replies > 0
else 0.0
)
output = [
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask):
f"总Token数: {_format_large_number(total_tokens)}",
f"总花费: {stats[TOTAL_COST]:.2f}¥",
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A",
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
if total_messages_excluding_replies > 0
else "花费/消息数量(排除回复): N/A",
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask):
formatted_out_tokens = _format_large_number(out_tokens)
formatted_tokens = _format_large_number(tokens)
output.append(
data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost)
data_fmt.format(
name,
formatted_count,
formatted_in_tokens,
formatted_out_tokens,
formatted_tokens,
cost,
avg_time_cost,
std_time_cost,
)
)
output.append("")
@ -891,8 +912,12 @@ class StatisticOutputTask(AsyncTask):
except (IndexError, TypeError) as e:
logger.warning(f"生成HTML聊天统计时发生错误chat_id: {chat_id}, 错误: {e}")
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
chat_rows_html = (
"\n".join(chat_rows)
if chat_rows
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
)
# 生成HTML
return f"""
<div id=\"{div_id}\" class=\"tab-content\">
@ -1197,7 +1222,7 @@ class StatisticOutputTask(AsyncTask):
# 添加图表内容
chart_data = self._generate_chart_data(stat)
tab_content_list.append(self._generate_chart_tab(chart_data))
# 添加指标趋势图表
metrics_data = self._generate_metrics_data(now)
tab_content_list.append(self._generate_metrics_tab(metrics_data))
@ -1772,121 +1797,125 @@ class StatisticOutputTask(AsyncTask):
def _generate_metrics_data(self, now: datetime) -> dict:
"""生成指标趋势数据"""
metrics_data = {}
# 24小时尺度1小时为单位
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
# 7天尺度1天为单位
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24)
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
# 30天尺度1天为单位
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24)
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
return metrics_data
def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict:
"""收集指定时间范围内每个间隔的指标数据"""
start_time = now - timedelta(hours=hours)
time_points = []
current_time = start_time
# 生成时间点
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(hours=interval_hours)
# 初始化数据结构
cost_per_100_messages = [0.0] * len(time_points) # 花费/消息数量每100条
cost_per_hour = [0.0] * len(time_points) # 花费/时间(每小时)
tokens_per_hour = [0.0] * len(time_points) # Token/时间(每小时)
cost_per_100_replies = [0.0] * len(time_points) # 花费/回复数量每100条
# 每个时间点的累计数据
total_costs = [0.0] * len(time_points)
total_tokens = [0] * len(time_points)
total_messages = [0] * len(time_points)
total_replies = [0] * len(time_points)
total_online_hours = [0.0] * len(time_points)
# 获取bot的QQ账号
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
interval_seconds = interval_hours * 3600
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
cost = record.cost or 0.0
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_token = prompt_tokens + completion_tokens
total_costs[interval_index] += cost
total_tokens[interval_index] += total_token
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
time_diff = message_time_ts - query_start_timestamp
interval_index = int(time_diff // interval_seconds)
if 0 <= interval_index < len(time_points):
total_messages[interval_index] += 1
# 检查是否是bot发送的消息回复
if bot_qq_account and message.user_id == bot_qq_account:
total_replies[interval_index] += 1
# 查询在线时间记录
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore
record_start = record.start_timestamp
record_end = record.end_timestamp
# 找到记录覆盖的所有时间间隔
for idx, time_point in enumerate(time_points):
interval_start = time_point
interval_end = time_point + timedelta(hours=interval_hours)
# 计算重叠部分
overlap_start = max(record_start, interval_start)
overlap_end = min(record_end, interval_end)
if overlap_end > overlap_start:
overlap_hours = (overlap_end - overlap_start).total_seconds() / 3600.0
total_online_hours[idx] += overlap_hours
# 计算指标
for idx in range(len(time_points)):
# 花费/消息数量每100条
if total_messages[idx] > 0:
cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100)
cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
# 花费/时间(每小时)
if total_online_hours[idx] > 0:
cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx])
cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
# Token/时间(每小时)
if total_online_hours[idx] > 0:
tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx])
tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
# 花费/回复数量每100条
if total_replies[idx] > 0:
cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100)
cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
# 生成时间标签
if interval_hours == 1:
time_labels = [t.strftime("%H:%M") for t in time_points]
else:
time_labels = [t.strftime("%m-%d") for t in time_points]
return {
"time_labels": time_labels,
"cost_per_100_messages": cost_per_100_messages,
@ -1894,7 +1923,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": tokens_per_hour,
"cost_per_100_replies": cost_per_100_replies,
}
def _generate_metrics_tab(self, metrics_data: dict) -> str:
"""生成指标趋势图表选项卡HTML内容"""
colors = {
@ -1903,7 +1932,7 @@ class StatisticOutputTask(AsyncTask):
"tokens_per_hour": "#c7bbff",
"cost_per_100_replies": "#d9ceff",
}
return f"""
<div id="metrics" class="tab-content">
<h2>指标趋势图表</h2>

View File

@ -4,14 +4,11 @@ import time
import jieba
import json
import ast
import numpy as np
from collections import Counter
from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool:
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
"""解析 platforms 列表,返回平台到账号的映射
Args:
platforms: 格式为 ["platform:account"] 的列表 ["tg:123456789", "wx:wxid123"]
Returns:
字典键为平台名值为账号
"""
@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
"""根据当前平台获取对应的账号
Args:
platform: 当前消息的平台
platform_accounts: platforms 列表解析的平台账号映射
qq_account: QQ 账号兼容旧配置
Returns:
当前平台对应的账号
"""
@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text):
elif re.search(
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text
):
is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并
1. 识别分割点, ; 空格但如果分割点左右都是英文字母则不分割
@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
prev_char = text[i - 1]
next_char = text[i + 1]
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == ' ':
if char == " ":
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
"不知道",
"不晓得",
"懒得说",
"()"
"()",
]
return random.choice(default_replies)
@ -469,7 +467,6 @@ def calculate_typing_time(
return total_time # 加上回车时间
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
return f"{message[:max_length]}..." if len(message) > max_length else message
@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars)
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式

View File

@ -103,14 +103,16 @@ class ImageManager:
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = Images.delete().where(
(Images.description >> None) | (Images.description << invalid_values)
).execute()
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
# 清理 ImageDescriptions 表
deleted_descriptions = ImageDescriptions.delete().where(
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
).execute()
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")

View File

@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
chat_id: str,
chat_info_stream_id: str,
chat_info_platform: str,
action_reasoning:str
action_reasoning: str,
):
self.action_id = action_id
self.time = time
@ -235,4 +235,4 @@ class DatabaseActionRecords(BaseDataModel):
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.action_reasoning = action_reasoning
self.action_reasoning = action_reasoning

View File

@ -317,10 +317,12 @@ class Expression(BaseModel):
class Meta:
table_name = "expression"
class Jargon(BaseModel):
"""
用于存储俚语的模型
"""
content = TextField()
raw_content = TextField(null=True)
type = TextField(null=True)
@ -332,14 +334,16 @@ class Jargon(BaseModel):
is_jargon = BooleanField(null=True) # None表示未判定True表示是黑话False表示不是黑话
last_inference_count = IntegerField(null=True) # 最后一次判定的count值用于避免重启后重复判定
is_complete = BooleanField(default=False) # 是否已完成所有推断count>=100后不再推断
class Meta:
table_name = "jargon"
class ChatHistory(BaseModel):
"""
用于存储聊天历史概括的模型
"""
chat_id = TextField(index=True) # 聊天ID
start_time = DoubleField() # 起始时间
end_time = DoubleField() # 结束时间
@ -350,7 +354,7 @@ class ChatHistory(BaseModel):
summary = TextField() # 概括:对这段话的平文本概括
count = IntegerField(default=0) # 被检索次数
forget_times = IntegerField(default=0) # 被遗忘检查的次数
class Meta:
table_name = "chat_history"
@ -359,6 +363,7 @@ class ThinkingBack(BaseModel):
"""
用于存储记忆检索思考过程的模型
"""
chat_id = TextField(index=True) # 聊天ID
question = TextField() # 提出的问题
context = TextField(null=True) # 上下文信息
@ -367,10 +372,11 @@ class ThinkingBack(BaseModel):
thinking_steps = TextField(null=True) # 思考步骤JSON格式
create_time = DoubleField() # 创建时间
update_time = DoubleField() # 更新时间
class Meta:
table_name = "thinking_back"
MODELS = [
ChatStreams,
LLMUsage,
@ -387,6 +393,7 @@ MODELS = [
ThinkingBack,
]
def create_tables():
"""
创建所有在模型中定义的数据库表

View File

@ -27,7 +27,7 @@ class BotConfig(ConfigBase):
nickname: str
"""昵称"""
platforms: list[str] = field(default_factory=lambda: [])
"""其他平台列表"""
@ -311,16 +311,18 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
max_memory_number: int = 100
"""记忆最大数量"""
memory_build_frequency: int = 1
"""记忆构建频率"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
@ -494,13 +496,14 @@ class MoodConfig(ConfigBase):
enable_mood: bool = True
"""是否启用情绪系统"""
mood_update_threshold: float = 1
"""情绪更新阈值,越高,更新越慢"""
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@ -644,16 +647,16 @@ class DebugConfig(ConfigBase):
show_prompt: bool = False
"""是否显示prompt"""
show_replyer_prompt: bool = True
"""是否显示回复器prompt"""
show_replyer_reasoning: bool = True
"""是否显示回复器推理"""
show_jargon_prompt: bool = False
"""是否显示jargon相关提示词"""
show_planner_prompt: bool = False
"""是否显示planner相关提示词"""

View File

@ -3,31 +3,30 @@ import difflib
import random
from datetime import datetime
from typing import Optional, List, Dict
from collections import defaultdict
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r'@<[^>]*>', '', content)
content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID
content = re.sub(r'\[picid:[^\]]*\]', '', content)
content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
@ -35,11 +34,11 @@ def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值范围0-1
"""
@ -49,10 +48,10 @@ def calculate_similarity(text1: str, text2: str) -> float:
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
@ -65,11 +64,11 @@ def format_create_date(timestamp: float) -> str:
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""

View File

@ -1,7 +1,6 @@
import time
import json
import os
from datetime import datetime
from typing import List, Optional, Tuple
import traceback
from src.common.logger import get_logger
@ -158,8 +157,6 @@ class ExpressionLearner:
traceback.print_exc()
return
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
@ -169,7 +166,7 @@ class ExpressionLearner:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
@ -186,7 +183,7 @@ class ExpressionLearner:
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for (
situation,
style,
@ -195,9 +192,7 @@ class ExpressionLearner:
) in learnt_expressions:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
)
if query.exists():
# 表达方式完全相同,只更新时间戳
@ -216,39 +211,37 @@ class ExpressionLearner:
up_content=up_content,
)
has_new_expressions = True
# 训练 style_learnerup_content 和 style 必定存在)
try:
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(
self.chat_id,
up_content,
style
)
success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
logger.debug(
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
+ (f" (situation: {situation})" if situation else "")
)
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions
async def match_expression_context(
@ -334,7 +327,7 @@ class ExpressionLearner:
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"match_responses 内容: {match_responses}")
@ -344,12 +337,12 @@ class ExpressionLearner:
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
@ -367,9 +360,7 @@ class ExpressionLearner:
return matched_expressions
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, str]]]:
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
@ -409,7 +400,6 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
@ -426,17 +416,17 @@ class ExpressionLearner:
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达
continue
# 检查目标句是否为空
target_content = bare_lines[pos][1]
if not target_content:
# 目标句为空,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content:
@ -449,7 +439,6 @@ class ExpressionLearner:
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组
@ -483,21 +472,21 @@ class ExpressionLearner:
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表保留到原消息索引的映射
Args:
messages: 消息列表
Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
"""
bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages):
content = msg.processed_plain_text or ""
content = filter_message_content(content)
# 即使content为空也要记录防止错位
bare_lines.append((idx, content))
return bare_lines

View File

@ -1,8 +1,6 @@
import json
import time
import random
import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
@ -115,30 +113,31 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
def get_model_predicted_expressions(
self, chat_id: str, target_message: str, total_num: int = 10
) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
target_message: 目标消息内容
total_num: 需要预测的数量
Returns:
List[Dict[str, Any]]: 预测的表达方式列表
"""
try:
# 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message)
logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
predicted_expressions = []
# 为每个相关的chat_id进行预测
for related_chat_id in related_chat_ids:
try:
@ -146,59 +145,65 @@ class ExpressionSelector:
best_style, scores = style_learner_manager.predict_style(
related_chat_id, filtered_target_message, top_k=total_num
)
if best_style and scores:
# 获取预测风格的完整信息
learner = style_learner_manager.get_learner(related_chat_id)
style_id, situation = learner.get_style_info(best_style)
if style_id and situation:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) &
(Expression.situation == situation) &
(Expression.style == best_style)
(Expression.chat_id == related_chat_id)
& (Expression.situation == situation)
& (Expression.style == best_style)
)
if expr_query.exists():
expr = expr_query.get()
predicted_expressions.append({
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message
})
predicted_expressions.append(
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date
if expr.create_date is not None
else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message,
}
)
else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
logger.warning(
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
)
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
continue
# 按预测分数排序,取前 total_num 个
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
selected_expressions = predicted_expressions[:total_num]
logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
return selected_expressions
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._random_expressions(chat_id, total_num)
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
Args:
chat_id: 聊天室ID
total_num: 需要选择的数量
Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表
"""
@ -207,9 +212,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
style_exprs = [
{
@ -228,15 +231,14 @@ class ExpressionSelector:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")
return []
async def select_suitable_expressions(
self,
chat_id: str,
@ -246,13 +248,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
根据配置模式选择适合的表达方式
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@ -263,7 +265,7 @@ class ExpressionSelector:
# 获取配置模式
expression_mode = global_config.expression.mode
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
@ -284,12 +286,12 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
exp_model模式直接使用模型预测不经过LLM
Args:
chat_id: 聊天流ID
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@ -297,14 +299,14 @@ class ExpressionSelector:
# 使用模型预测最合适的表达方式
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids
except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], []
@ -318,13 +320,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@ -425,17 +427,13 @@ class ExpressionSelector:
updates_by_key[key] = expr
for chat_id, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
expr_obj.last_active_time = time.time()
expr_obj.save()
logger.debug(
"表达方式激活: 更新last_active_time in db"
)
logger.debug("表达方式激活: 更新last_active_time in db")
init_prompt()

View File

@ -6,18 +6,21 @@ import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
"""
直接使用朴素贝叶斯精排可在线学习
支持存储situation字段不参与计算仅与style对应
"""
def __init__(self,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
vocab_size: int = 200000,
use_jieba: bool = True):
def __init__(
self,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
vocab_size: int = 200000,
use_jieba: bool = True,
):
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
self._candidates: Dict[str, str] = {} # cid -> text (style)
@ -28,7 +31,7 @@ class ExpressorModel:
self._candidates[cid] = text
if situation is not None:
self._situations[cid] = situation
# 确保在nb模型中初始化该候选的计数
if cid not in self.nb.cls_counts:
self.nb.cls_counts[cid] = 0.0
@ -46,7 +49,7 @@ class ExpressorModel:
toks = self.tokenizer.tokenize(text)
if not toks:
return None, {}
if not self._candidates:
return None, {}
@ -58,7 +61,7 @@ class ExpressorModel:
# 取最高分
if not scores:
return None, {}
# 根据k参数限制返回的候选数量
if k is not None and k > 0:
# 按分数降序排序取前k个
@ -81,40 +84,42 @@ class ExpressorModel:
def decay(self, factor: float):
self.nb.decay(factor=factor)
def get_situation(self, cid: str) -> Optional[str]:
"""获取候选对应的situation"""
return self._situations.get(cid)
def get_style(self, cid: str) -> Optional[str]:
"""获取候选对应的style"""
return self._candidates.get(cid)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
"""获取候选的style和situation信息"""
return self._candidates.get(cid), self._situations.get(cid)
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""获取所有候选的style和situation信息"""
return {cid: (style, self._situations.get(cid))
for cid, style in self._candidates.items()}
return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()}
def save(self, path: str):
"""保存模型"""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
pickle.dump({
"candidates": self._candidates,
"situations": self._situations,
"nb": {
"cls_counts": dict(self.nb.cls_counts),
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"alpha": self.nb.alpha,
"beta": self.nb.beta,
"gamma": self.nb.gamma,
"V": self.nb.V,
}
}, f)
pickle.dump(
{
"candidates": self._candidates,
"situations": self._situations,
"nb": {
"cls_counts": dict(self.nb.cls_counts),
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"alpha": self.nb.alpha,
"beta": self.nb.beta,
"gamma": self.nb.gamma,
"V": self.nb.V,
},
},
f,
)
def load(self, path: str):
"""加载模型"""
@ -133,9 +138,11 @@ class ExpressorModel:
self.nb.V = obj["nb"]["V"]
self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
from collections import defaultdict
outer = defaultdict(lambda: defaultdict(float))
for k, inner in d.items():
outer[k].update(inner)
return outer
return outer

View File

@ -2,6 +2,7 @@ import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha
@ -9,9 +10,9 @@ class OnlineNaiveBayes:
self.gamma = gamma
self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str):
if cid in self._logZ:
@ -57,4 +58,4 @@ class OnlineNaiveBayes:
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)
self._invalidate(cid)

View File

@ -3,17 +3,20 @@ from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# 匹配纯符号的正则表达式
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
_SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$")
def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower())
class Tokenizer:
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
self.stopwords = stopwords or set()
@ -28,4 +31,4 @@ class Tokenizer:
else:
toks = simple_en_tokenize(text)
# 过滤掉纯符号和停用词
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]

View File

@ -22,42 +22,42 @@ class StyleLearner:
学习从up_content到style的映射关系
支持动态管理风格集合无数量上限
"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True
"use_jieba": True,
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0 # 下一个可用的style_id
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": defaultdict(int),
"last_update": None,
"style_usage_frequency": defaultdict(int) # 风格使用频率
"style_usage_frequency": defaultdict(int), # 风格使用频率
}
def add_style(self, style: str, situation: str = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 对应的situation文本可选
Returns:
bool: 添加是否成功
"""
@ -66,35 +66,37 @@ class StyleLearner:
if style in self.style_to_id:
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
return True
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
(f", situation: '{situation}'" if situation else ""))
logger.info(
f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
+ (f", situation: '{situation}'" if situation else "")
)
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
return False
def remove_style(self, style: str) -> bool:
"""
删除一个风格
Args:
style: 要删除的风格文本
Returns:
bool: 删除是否成功
"""
@ -102,33 +104,33 @@ class StyleLearner:
if style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
return False
style_id = self.style_to_id[style]
# 从映射中删除
del self.style_to_id[style]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从expressor模型中删除通过重新构建
self._rebuild_expressor()
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
return False
def update_style(self, old_style: str, new_style: str) -> bool:
"""
更新一个风格
Args:
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
@ -136,37 +138,37 @@ class StyleLearner:
if old_style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
return False
if new_style in self.style_to_id and new_style != old_style:
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
return False
style_id = self.style_to_id[old_style]
# 更新映射
del self.style_to_id[old_style]
self.style_to_id[new_style] = style_id
self.id_to_style[style_id] = new_style
# 更新expressor模型保留原有的situation
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, new_style, situation)
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
return False
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
"""
批量添加风格
Args:
styles: 风格文本列表
situations: 对应的situation文本列表可选
Returns:
int: 成功添加的数量
"""
@ -175,55 +177,55 @@ class StyleLearner:
situation = situations[i] if situations and i < len(situations) else None
if self.add_style(style, situation):
success_count += 1
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
return success_count
def get_all_styles(self) -> List[str]:
"""获取所有已注册的风格"""
return list(self.style_to_id.keys())
def get_style_count(self) -> int:
"""获取当前风格数量"""
return len(self.style_to_id)
def get_situation(self, style: str) -> Optional[str]:
"""
获取风格对应的situation
Args:
style: 风格文本
Returns:
Optional[str]: 对应的situation如果不存在则返回None
"""
if style not in self.style_to_id:
return None
style_id = self.style_to_id[style]
return self.id_to_situation.get(style_id)
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取风格的完整信息
Args:
style: 风格文本
Returns:
Tuple[Optional[str], Optional[str]]: (style_id, situation)
"""
if style not in self.style_to_id:
return None, None
style_id = self.style_to_id[style]
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""
获取所有风格的完整信息
Returns:
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
"""
@ -232,32 +234,32 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
result[style] = (style_id, situation)
return result
def _rebuild_expressor(self):
"""重新构建expressor模型删除风格后使用"""
try:
# 重新创建expressor
self.expressor = ExpressorModel(**self.model_config)
# 重新添加所有风格和situation
for style_id, style_text in self.id_to_style.items():
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, style_text, situation)
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
except Exception as e:
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
如果style不存在会自动添加
Args:
up_content: 输入内容
style: 对应的style文本
Returns:
bool: 学习是否成功
"""
@ -267,71 +269,71 @@ class StyleLearner:
if not self.add_style(style):
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["style_usage_frequency"][style] += 1
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
traceback.print_exc()
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style文本, 所有候选的分数]
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
traceback.print_exc()
return None, {}
def decay_learning(self, factor: Optional[float] = None) -> None:
"""
对学习到的知识进行衰减遗忘
Args:
factor: 衰减因子None则使用配置中的gamma
"""
self.expressor.decay(factor)
logger.debug(f"[{self.chat_id}] 执行知识衰减")
def get_stats(self) -> Dict:
"""获取学习统计信息"""
return {
@ -341,20 +343,20 @@ class StyleLearner:
"style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys())
"all_styles": list(self.style_to_id.keys()),
}
def save(self, base_path: str) -> bool:
"""
保存模型到文件
Args:
base_path: 基础路径实际文件为 {base_path}/{chat_id}_style_model.pkl
"""
try:
os.makedirs(base_path, exist_ok=True)
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
# 保存模型和统计信息
save_data = {
"model_config": self.model_config,
@ -362,43 +364,43 @@ class StyleLearner:
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"learning_stats": self.learning_stats
"learning_stats": self.learning_stats,
}
# 先保存expressor模型
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
self.expressor.save(expressor_path)
# 保存其他数据
with open(file_path, "wb") as f:
pickle.dump(save_data, f)
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载模型
Args:
base_path: 基础路径
"""
try:
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
return False
# 加载其他数据
with open(file_path, "rb") as f:
save_data = pickle.load(f)
# 恢复配置和状态
self.model_config = save_data["model_config"]
self.style_to_id = save_data["style_to_id"]
@ -406,14 +408,14 @@ class StyleLearner:
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
self.next_style_id = save_data["next_style_id"]
self.learning_stats = save_data["learning_stats"]
# 重新创建expressor并加载
self.expressor = ExpressorModel(**self.model_config)
self.expressor.load(expressor_path)
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
return False
@ -425,156 +427,156 @@ class StyleLearnerManager:
为每个chat_id维护独立的StyleLearner实例
每个chat_id可以动态管理自己的风格集合无数量上限
"""
def __init__(self, model_save_path: str = "data/style_models"):
self.model_save_path = model_save_path
self.learners: Dict[str, StyleLearner] = {}
# 自动保存配置
self.auto_save_interval = 300 # 5分钟
self._auto_save_task: Optional[asyncio.Task] = None
logger.info("StyleLearnerManager 已初始化")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置None则使用默认配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
return self.learners[chat_id]
def add_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id添加风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 添加是否成功
"""
learner = self.get_learner(chat_id)
return learner.add_style(style)
def remove_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id删除风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 删除是否成功
"""
learner = self.get_learner(chat_id)
return learner.remove_style(style)
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
"""
为指定chat_id更新风格
Args:
chat_id: 聊天室ID
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
learner = self.get_learner(chat_id)
return learner.update_style(old_style, new_style)
def get_chat_styles(self, chat_id: str) -> List[str]:
"""
获取指定chat_id的所有风格
Args:
chat_id: 聊天室ID
Returns:
List[str]: 风格列表
"""
learner = self.get_learner(chat_id)
return learner.get_all_styles()
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 输入内容
style: 对应的style
Returns:
bool: 学习是否成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的style
Args:
chat_id: 聊天室ID
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style, 所有候选分数]
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def decay_all_learners(self, factor: Optional[float] = None) -> None:
"""
对所有学习器执行衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.decay_learning(factor)
logger.info("已对所有学习器执行衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""获取所有学习器的统计信息"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
def save_all_models(self) -> bool:
"""保存所有模型"""
success_count = 0
for learner in self.learners.values():
if learner.save(self.model_save_path):
success_count += 1
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
return success_count == len(self.learners)
def load_all_models(self) -> int:
"""加载所有已保存的模型"""
if not os.path.exists(self.model_save_path):
return 0
loaded_count = 0
for filename in os.listdir(self.model_save_path):
if filename.endswith("_style_model.pkl"):
@ -583,16 +585,16 @@ class StyleLearnerManager:
if learner.load(self.model_save_path):
self.learners[chat_id] = learner
loaded_count += 1
logger.info(f"已加载 {loaded_count} 个模型")
return loaded_count
async def start_auto_save(self) -> None:
"""启动自动保存任务"""
if self._auto_save_task is None or self._auto_save_task.done():
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
logger.info("已启动自动保存任务")
async def stop_auto_save(self) -> None:
"""停止自动保存任务"""
if self._auto_save_task and not self._auto_save_task.done():
@ -602,7 +604,7 @@ class StyleLearnerManager:
except asyncio.CancelledError:
pass
logger.info("已停止自动保存任务")
async def _auto_save_loop(self) -> None:
"""自动保存循环"""
while True:

View File

@ -3,5 +3,3 @@ from .jargon_miner import extract_and_store_jargon
__all__ = [
"extract_and_store_jargon",
]

View File

@ -250,7 +250,7 @@ def _build_stream_api_resp(
if fr:
reason = str(fr)
break
if str(reason).endswith("MAX_TOKENS"):
has_visible_output = bool(resp.content and resp.content.strip())
if has_visible_output:
@ -281,8 +281,8 @@ async def _default_stream_response_handler(
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
resp = APIResponse()
resp = APIResponse()
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
@ -298,7 +298,7 @@ async def _default_stream_response_handler(
chunk,
_fc_delta_buffer,
_tool_calls_buffer,
resp=resp,
resp=resp,
)
if chunk.usage_metadata:
@ -314,7 +314,7 @@ async def _default_stream_response_handler(
_fc_delta_buffer,
_tool_calls_buffer,
last_resp=last_resp,
resp=resp,
resp=resp,
), _usage_record
except Exception:
# 确保缓冲区被关闭

View File

@ -256,7 +256,7 @@ def _build_stream_api_resp(
# 检查 max_tokens 截断(流式的告警改由处理函数统一输出,这里不再输出)
# 保留 finish_reason 仅用于上层判断
if not resp.content and not resp.tool_calls:
raise EmptyResponseException()
@ -310,7 +310,7 @@ async def _default_stream_response_handler(
if hasattr(event.choices[0], "finish_reason") and event.choices[0].finish_reason:
finish_reason = event.choices[0].finish_reason
if hasattr(event, "model") and event.model and not _model_name:
_model_name = event.model # 记录模型名
@ -358,10 +358,7 @@ async def _default_stream_response_handler(
model_dbg = None
# 统一日志格式
logger.info(
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (model_dbg or "")
)
logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (model_dbg or ""))
return resp, _usage_record
except Exception:
@ -404,9 +401,7 @@ def _default_normal_response_parser(
raw_snippet = str(resp)[:300]
except Exception:
raw_snippet = "<unserializable>"
logger.debug(
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
)
logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}")
except Exception:
# 日志采集失败不应影响控制流
pass
@ -464,14 +459,11 @@ def _default_normal_response_parser(
# print(resp)
_model_name = resp.model
# 统一日志格式
logger.info(
"模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整"
% (_model_name or "")
)
logger.info("模型%s因为超过最大max_token限制可能仅输出部分内容可视情况调整" % (_model_name or ""))
return api_response, _usage_record
except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException()

View File

@ -328,9 +328,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
)
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
@ -340,9 +338,7 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
)
logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
await asyncio.sleep(api_provider.retry_interval)
except RespNotOkException as e:

View File

@ -5,6 +5,7 @@ from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask
from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
@ -70,9 +71,10 @@ class MainSystem:
# 添加遥测心跳任务
await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加记忆遗忘任务
from src.chat.utils.memory_forget_task import MemoryForgetTask
await async_task_manager.add_task(MemoryForgetTask())
# 启动API服务器
@ -106,7 +108,6 @@ class MainSystem:
self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
# 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType

View File

@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str
) -> str:
async def query_jargon(keyword: str, chat_id: str) -> str:
"""根据关键词在jargon库中查询
Args:
keyword: 关键词黑话/俚语/缩写
chat_id: 聊天ID
Returns:
str: 查询结果
"""
@ -26,29 +23,17 @@ async def query_jargon(
content = str(keyword).strip()
if not content:
return "关键词为空"
# 先尝试精确匹配
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索
if not results:
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
is_fuzzy_match = True
if results:
# 如果是模糊匹配显示找到的实际jargon内容
if is_fuzzy_match:
@ -71,11 +56,11 @@ async def query_jargon(
output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0]
logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果")
return output
# 未命中
logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}"
@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
execute_func=query_jargon,
)

View File

@ -12,17 +12,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
@ -33,7 +29,7 @@ class MemoryRetrievalTool:
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
@ -44,10 +40,10 @@ class MemoryRetrievalTool:
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
@ -97,10 +93,10 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
@ -108,22 +104,22 @@ class MemoryRetrievalToolRegistry:
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt已废弃保留用于兼容"""
action_types = [tool.name for tool in self.tools.values()]
@ -145,13 +141,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry

View File

@ -1,10 +1,7 @@
import math
import random
import time
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive

View File

@ -6,7 +6,9 @@ logger = get_logger("frequency_api")
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:

View File

@ -109,7 +109,7 @@ def get_messages_by_time_in_chat(
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command
filter_command=filter_command,
)

View File

@ -12,11 +12,11 @@ logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
"""获取公开工具实例
Args:
tool_name: 工具名称
chat_stream: 聊天流对象用于传递聊天上下文信息
Returns:
Optional[BaseTool]: 工具实例如果未找到则返回None
"""

View File

@ -77,7 +77,7 @@ class BaseAction(ABC):
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
"""NORMAL模式下的激活类型"""
self.activation_type = getattr(self.__class__, "activation_type")
self.activation_type = self.__class__.activation_type
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
@ -108,21 +108,16 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = None
self.group_id = (
str(self.action_message.chat_info.group_info.group_id)
if self.action_message.chat_info.group_info
else None
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
)
self.group_name = (
self.action_message.chat_info.group_info.group_name
if self.action_message.chat_info.group_info
else None
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
)
self.user_id = str(self.action_message.user_info.user_id)
self.user_nickname = self.action_message.user_info.user_nickname
if self.group_id:
self.is_group = True
self.target_id = self.group_id
@ -132,7 +127,6 @@ class BaseAction(ABC):
self.target_id = self.user_id
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
@ -448,7 +442,6 @@ class BaseAction(ABC):
wait_start_time = asyncio.get_event_loop().time()
while True:
# 检查新消息
current_time = time.time()
new_message_count = message_api.count_new_messages(
@ -497,7 +490,7 @@ class BaseAction(ABC):
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
# 处理activation_type如果插件中声明了就用插件的值否则默认使用focus_activation_type
activation_type = getattr(cls, "activation_type", focus_activation_type)

View File

@ -34,17 +34,17 @@ class BaseTool(ABC):
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
"""初始化工具基类
Args:
plugin_config: 插件配置字典
chat_stream: 聊天流对象用于获取聊天上下文信息
"""
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息与BaseAction保持一致
# =============================================================================
# 获取聊天流对象
self.chat_stream = chat_stream
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None

View File

@ -346,9 +346,7 @@ class EventsManager:
if not isinstance(result, tuple) or len(result) != 5:
if isinstance(result, tuple):
annotated = ", ".join(
f"{name}={val!r}" for name, val in zip(expected_fields, result)
)
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
actual_desc = f"{len(result)} 个元素 ({annotated})"
else:
actual_desc = f"非 tuple 类型: {type(result)}"
@ -380,7 +378,6 @@ class EventsManager:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],

View File

@ -189,9 +189,8 @@ class ToolExecutor:
tool_info["content"] = str(content)
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
content_check = tool_info["content"]
if (
(isinstance(content_check, str) and not content_check.strip())
or (isinstance(content_check, (list, tuple)) and len(content_check) == 0)
if (isinstance(content_check, str) and not content_check.strip()) or (
isinstance(content_check, (list, tuple)) and len(content_check) == 0
):
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
continue

View File

@ -8,29 +8,30 @@ import sys
import os
from pprint import pprint
def view_pkl_file(file_path):
"""查看 pkl 文件内容"""
if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}")
return
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = pickle.load(f)
print(f"📁 文件: {file_path}")
print(f"📊 数据类型: {type(data)}")
print("=" * 50)
if isinstance(data, dict):
print("🔑 字典键:")
for key in data.keys():
print(f" - {key}: {type(data[key])}")
print()
print("📋 详细内容:")
pprint(data, width=120, depth=10)
elif isinstance(data, list):
print(f"📝 列表长度: {len(data)}")
if data:
@ -38,16 +39,16 @@ def view_pkl_file(file_path):
print("📋 前几个元素:")
for i, item in enumerate(data[:3]):
print(f" [{i}]: {item}")
else:
print("📋 内容:")
pprint(data, width=120, depth=10)
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
print("\n" + "="*50)
if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]:
print("\n" + "=" * 50)
print("🔍 详细词汇统计 (token_counts):")
token_counts = data['nb']['token_counts']
token_counts = data["nb"]["token_counts"]
for style_id, tokens in token_counts.items():
print(f"\n📝 {style_id}:")
if tokens:
@ -59,18 +60,20 @@ def view_pkl_file(file_path):
print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
else:
print(" (无词汇数据)")
except Exception as e:
print(f"❌ 读取文件失败: {e}")
def main():
if len(sys.argv) != 2:
print("用法: python view_pkl.py <pkl文件路径>")
print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
return
file_path = sys.argv[1]
view_pkl_file(file_path)
if __name__ == "__main__":
main()

View File

@ -7,57 +7,60 @@ import pickle
import sys
import os
def view_token_counts(file_path):
"""查看 expressor.pkl 文件中的词汇统计"""
if not os.path.exists(file_path):
print(f"❌ 文件不存在: {file_path}")
return
try:
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = pickle.load(f)
print(f"📁 文件: {file_path}")
print("=" * 60)
if 'nb' not in data or 'token_counts' not in data['nb']:
if "nb" not in data or "token_counts" not in data["nb"]:
print("❌ 这不是一个 expressor 模型文件")
return
token_counts = data['nb']['token_counts']
candidates = data.get('candidates', {})
token_counts = data["nb"]["token_counts"]
candidates = data.get("candidates", {})
print(f"🎯 找到 {len(token_counts)} 个风格")
print("=" * 60)
for style_id, tokens in token_counts.items():
style_text = candidates.get(style_id, "未知风格")
print(f"\n📝 {style_id}: {style_text}")
print(f"📊 词汇数量: {len(tokens)}")
if tokens:
# 按词频排序
sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
print("🔤 词汇统计 (按频率排序):")
for i, (word, count) in enumerate(sorted_tokens):
print(f" {i+1:2d}. '{word}': {count}")
print(f" {i + 1:2d}. '{word}': {count}")
else:
print(" (无词汇数据)")
print("-" * 40)
except Exception as e:
print(f"❌ 读取文件失败: {e}")
def main():
if len(sys.argv) != 2:
print("用法: python view_tokens.py <expressor.pkl文件路径>")
print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
return
file_path = sys.argv[1]
view_token_counts(file_path)
if __name__ == "__main__":
main()