mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
commit
d306e40db0
2
bot.py
2
bot.py
|
|
@ -30,7 +30,7 @@ else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
||||||
|
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
from typing import List, Tuple, Type, Optional
|
from typing import List, Tuple, Type, Optional
|
||||||
from src.plugin_system import (
|
from src.plugin_system import BasePlugin, register_plugin, BaseCommand, ComponentInfo, ConfigField
|
||||||
BasePlugin,
|
|
||||||
register_plugin,
|
|
||||||
BaseCommand,
|
|
||||||
ComponentInfo,
|
|
||||||
ConfigField
|
|
||||||
)
|
|
||||||
from src.plugin_system.apis import send_api, frequency_api
|
from src.plugin_system.apis import send_api, frequency_api
|
||||||
|
|
||||||
|
|
||||||
class SetTalkFrequencyCommand(BaseCommand):
|
class SetTalkFrequencyCommand(BaseCommand):
|
||||||
"""设置当前聊天的talk_frequency值"""
|
"""设置当前聊天的talk_frequency值"""
|
||||||
|
|
||||||
command_name = "set_talk_frequency"
|
command_name = "set_talk_frequency"
|
||||||
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||||
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
||||||
|
|
@ -43,7 +39,7 @@ class SetTalkFrequencyCommand(BaseCommand):
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
|
f"已设置当前聊天的talk_frequency调整值为: {value}\n当前talk_value: {final_value:.2f}\n发言频率调整: {adjust_value:.2f}\n基础值: {base_value:.2f}",
|
||||||
chat_id,
|
chat_id,
|
||||||
storage_message=False
|
storage_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return True, None, False
|
return True, None, False
|
||||||
|
|
@ -60,6 +56,7 @@ class SetTalkFrequencyCommand(BaseCommand):
|
||||||
|
|
||||||
class ShowFrequencyCommand(BaseCommand):
|
class ShowFrequencyCommand(BaseCommand):
|
||||||
"""显示当前聊天的频率控制状态"""
|
"""显示当前聊天的频率控制状态"""
|
||||||
|
|
||||||
command_name = "show_frequency"
|
command_name = "show_frequency"
|
||||||
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
||||||
command_pattern = r"^/chat\s+(?:show|s)$"
|
command_pattern = r"^/chat\s+(?:show|s)$"
|
||||||
|
|
@ -116,11 +113,7 @@ class BetterFrequencyPlugin(BasePlugin):
|
||||||
config_file_name: str = "config.toml"
|
config_file_name: str = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
config_section_descriptions = {
|
config_section_descriptions = {"plugin": "插件基本信息", "frequency": "频率控制配置", "features": "功能开关配置"}
|
||||||
"plugin": "插件基本信息",
|
|
||||||
"frequency": "频率控制配置",
|
|
||||||
"features": "功能开关配置"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 配置Schema定义
|
# 配置Schema定义
|
||||||
config_schema: dict = {
|
config_schema: dict = {
|
||||||
|
|
@ -141,10 +134,11 @@ class BetterFrequencyPlugin(BasePlugin):
|
||||||
|
|
||||||
# 根据配置决定是否注册命令组件
|
# 根据配置决定是否注册命令组件
|
||||||
if self.config.get("features", {}).get("enable_commands", True):
|
if self.config.get("features", {}).get("enable_commands", True):
|
||||||
components.extend([
|
components.extend(
|
||||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
[
|
||||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
||||||
])
|
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return components
|
return components
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,16 @@ import sys
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Iterable, List, Optional, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.message_repository import find_messages
|
||||||
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
|
|
||||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
if PROJECT_ROOT not in sys.path:
|
if PROJECT_ROOT not in sys.path:
|
||||||
sys.path.insert(0, PROJECT_ROOT)
|
sys.path.insert(0, PROJECT_ROOT)
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.common.message_repository import find_messages
|
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
|
||||||
|
|
||||||
|
|
||||||
SECONDS_5_MINUTES = 5 * 60
|
SECONDS_5_MINUTES = 5 * 60
|
||||||
|
|
@ -30,13 +31,13 @@ def clean_output_text(text: str) -> str:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# 移除表情包内容:[表情包:...]
|
# 移除表情包内容:[表情包:...]
|
||||||
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
|
text = re.sub(r"\[表情包:[^\]]*\]", "", text)
|
||||||
|
|
||||||
# 移除回复内容:[回复...],说:... 的完整模式
|
# 移除回复内容:[回复...],说:... 的完整模式
|
||||||
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
|
text = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", text)
|
||||||
|
|
||||||
# 清理多余的空格和换行
|
# 清理多余的空格和换行
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
@ -89,7 +90,7 @@ def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[Databa
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
groups.setdefault(msg.chat_id, []).append(msg)
|
groups.setdefault(msg.chat_id, []).append(msg)
|
||||||
# 保证每个分组内按时间升序
|
# 保证每个分组内按时间升序
|
||||||
for chat_id, msgs in groups.items():
|
for _chat_id, msgs in groups.items():
|
||||||
msgs.sort(key=lambda m: m.time or 0)
|
msgs.sort(key=lambda m: m.time or 0)
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
|
@ -170,8 +171,8 @@ def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseM
|
||||||
continue
|
continue
|
||||||
|
|
||||||
last = bucket[-1]
|
last = bucket[-1]
|
||||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
same_user = msg.user_info.user_id == last.user_info.user_id
|
||||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
|
||||||
|
|
||||||
if same_user and close_enough:
|
if same_user and close_enough:
|
||||||
bucket.append(msg)
|
bucket.append(msg)
|
||||||
|
|
@ -209,13 +210,11 @@ def build_pairs_for_chat(
|
||||||
|
|
||||||
for merged_idx, merged_msg in enumerate(merged_messages):
|
for merged_idx, merged_msg in enumerate(merged_messages):
|
||||||
# 找到这个合并消息对应的第一个原始消息
|
# 找到这个合并消息对应的第一个原始消息
|
||||||
while (original_idx < n_original and
|
while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
|
||||||
original_messages[original_idx].time < merged_msg.time):
|
|
||||||
original_idx += 1
|
original_idx += 1
|
||||||
|
|
||||||
# 如果找到了时间匹配的原始消息,建立映射
|
# 如果找到了时间匹配的原始消息,建立映射
|
||||||
if (original_idx < n_original and
|
if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
|
||||||
original_messages[original_idx].time == merged_msg.time):
|
|
||||||
merged_to_original_map[merged_idx] = original_idx
|
merged_to_original_map[merged_idx] = original_idx
|
||||||
|
|
||||||
for merged_idx in range(n_merged):
|
for merged_idx in range(n_merged):
|
||||||
|
|
@ -266,7 +265,7 @@ def build_pairs(
|
||||||
groups = group_by_chat(messages)
|
groups = group_by_chat(messages)
|
||||||
|
|
||||||
all_pairs: List[Tuple[str, str, str]] = []
|
all_pairs: List[Tuple[str, str, str]] = []
|
||||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
for _chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||||
# 对消息进行合并,用于output
|
# 对消息进行合并,用于output
|
||||||
merged = merge_adjacent_same_user(msgs)
|
merged = merge_adjacent_same_user(msgs)
|
||||||
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
||||||
|
|
@ -385,5 +384,3 @@ def run_interactive() -> int:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import time
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
@ -6,16 +5,17 @@ import matplotlib.dates as mdates
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from src.common.database.database_model import Expression, ChatStreams
|
||||||
|
|
||||||
# Add project root to Python path
|
# Add project root to Python path
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
from src.common.database.database_model import Expression, ChatStreams
|
|
||||||
|
|
||||||
# 设置中文字体
|
# 设置中文字体
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
|
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||||
plt.rcParams['axes.unicode_minus'] = False
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
|
|
||||||
|
|
||||||
def get_chat_name(chat_id: str) -> str:
|
def get_chat_name(chat_id: str) -> str:
|
||||||
|
|
@ -45,12 +45,7 @@ def get_expression_data() -> List[Tuple[float, float, str, str]]:
|
||||||
if expr.create_date is None:
|
if expr.create_date is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data.append((
|
data.append((expr.create_date, expr.count, expr.chat_id, expr.type))
|
||||||
expr.create_date,
|
|
||||||
expr.count,
|
|
||||||
expr.chat_id,
|
|
||||||
expr.type
|
|
||||||
))
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
@ -64,8 +59,8 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||||
# 分离数据
|
# 分离数据
|
||||||
create_dates = [item[0] for item in data]
|
create_dates = [item[0] for item in data]
|
||||||
counts = [item[1] for item in data]
|
counts = [item[1] for item in data]
|
||||||
chat_ids = [item[2] for item in data]
|
_chat_ids = [item[2] for item in data]
|
||||||
expression_types = [item[3] for item in data]
|
_expression_types = [item[3] for item in data]
|
||||||
|
|
||||||
# 转换时间戳为datetime对象
|
# 转换时间戳为datetime对象
|
||||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||||
|
|
@ -73,15 +68,15 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||||
# 计算时间跨度,自动调整显示格式
|
# 计算时间跨度,自动调整显示格式
|
||||||
time_span = max(dates) - min(dates)
|
time_span = max(dates) - min(dates)
|
||||||
if time_span.days > 30: # 超过30天,按月显示
|
if time_span.days > 30: # 超过30天,按月显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.MonthLocator()
|
major_locator = mdates.MonthLocator()
|
||||||
minor_locator = mdates.DayLocator(interval=7)
|
minor_locator = mdates.DayLocator(interval=7)
|
||||||
elif time_span.days > 7: # 超过7天,按天显示
|
elif time_span.days > 7: # 超过7天,按天显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.DayLocator(interval=1)
|
major_locator = mdates.DayLocator(interval=1)
|
||||||
minor_locator = mdates.HourLocator(interval=12)
|
minor_locator = mdates.HourLocator(interval=12)
|
||||||
else: # 7天内,按小时显示
|
else: # 7天内,按小时显示
|
||||||
date_format = '%Y-%m-%d %H:%M'
|
date_format = "%Y-%m-%d %H:%M"
|
||||||
major_locator = mdates.HourLocator(interval=6)
|
major_locator = mdates.HourLocator(interval=6)
|
||||||
minor_locator = mdates.HourLocator(interval=1)
|
minor_locator = mdates.HourLocator(interval=1)
|
||||||
|
|
||||||
|
|
@ -89,12 +84,12 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||||
fig, ax = plt.subplots(figsize=(12, 8))
|
fig, ax = plt.subplots(figsize=(12, 8))
|
||||||
|
|
||||||
# 创建散点图
|
# 创建散点图
|
||||||
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
|
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
|
||||||
|
|
||||||
# 设置标签和标题
|
# 设置标签和标题
|
||||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||||
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
|
ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
|
||||||
|
|
||||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||||
|
|
@ -107,13 +102,13 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||||
|
|
||||||
# 添加颜色条
|
# 添加颜色条
|
||||||
cbar = plt.colorbar(scatter)
|
cbar = plt.colorbar(scatter)
|
||||||
cbar.set_label('数据点顺序', fontsize=10)
|
cbar.set_label("数据点顺序", fontsize=10)
|
||||||
|
|
||||||
# 调整布局
|
# 调整布局
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# 显示统计信息
|
# 显示统计信息
|
||||||
print(f"\n=== 数据统计 ===")
|
print("\n=== 数据统计 ===")
|
||||||
print(f"总数据点数量: {len(data)}")
|
print(f"总数据点数量: {len(data)}")
|
||||||
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
|
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
|
print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
|
||||||
|
|
@ -122,7 +117,7 @@ def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: st
|
||||||
|
|
||||||
# 保存图片
|
# 保存图片
|
||||||
if save_path:
|
if save_path:
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||||
print(f"\n散点图已保存到: {save_path}")
|
print(f"\n散点图已保存到: {save_path}")
|
||||||
|
|
||||||
# 显示图片
|
# 显示图片
|
||||||
|
|
@ -147,15 +142,15 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||||
time_span = max(all_dates) - min(all_dates)
|
time_span = max(all_dates) - min(all_dates)
|
||||||
if time_span.days > 30: # 超过30天,按月显示
|
if time_span.days > 30: # 超过30天,按月显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.MonthLocator()
|
major_locator = mdates.MonthLocator()
|
||||||
minor_locator = mdates.DayLocator(interval=7)
|
minor_locator = mdates.DayLocator(interval=7)
|
||||||
elif time_span.days > 7: # 超过7天,按天显示
|
elif time_span.days > 7: # 超过7天,按天显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.DayLocator(interval=1)
|
major_locator = mdates.DayLocator(interval=1)
|
||||||
minor_locator = mdates.HourLocator(interval=12)
|
minor_locator = mdates.HourLocator(interval=12)
|
||||||
else: # 7天内,按小时显示
|
else: # 7天内,按小时显示
|
||||||
date_format = '%Y-%m-%d %H:%M'
|
date_format = "%Y-%m-%d %H:%M"
|
||||||
major_locator = mdates.HourLocator(interval=6)
|
major_locator = mdates.HourLocator(interval=6)
|
||||||
minor_locator = mdates.HourLocator(interval=1)
|
minor_locator = mdates.HourLocator(interval=1)
|
||||||
|
|
||||||
|
|
@ -174,14 +169,21 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||||
# 截断过长的聊天名称
|
# 截断过长的聊天名称
|
||||||
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
|
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
|
||||||
|
|
||||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
ax.scatter(
|
||||||
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
|
dates,
|
||||||
edgecolors='black', linewidth=0.5)
|
counts,
|
||||||
|
alpha=0.7,
|
||||||
|
s=40,
|
||||||
|
c=[colors[i]],
|
||||||
|
label=f"{display_name} ({len(chat_data)}个)",
|
||||||
|
edgecolors="black",
|
||||||
|
linewidth=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
# 设置标签和标题
|
# 设置标签和标题
|
||||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||||
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
|
ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
|
||||||
|
|
||||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||||
|
|
@ -190,7 +192,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||||
plt.xticks(rotation=45)
|
plt.xticks(rotation=45)
|
||||||
|
|
||||||
# 添加图例
|
# 添加图例
|
||||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
|
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
|
||||||
|
|
||||||
# 添加网格
|
# 添加网格
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
@ -199,7 +201,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# 显示统计信息
|
# 显示统计信息
|
||||||
print(f"\n=== 分组统计 ===")
|
print("\n=== 分组统计 ===")
|
||||||
print(f"总聊天数量: {len(chat_groups)}")
|
print(f"总聊天数量: {len(chat_groups)}")
|
||||||
for chat_id, chat_data in chat_groups.items():
|
for chat_id, chat_data in chat_groups.items():
|
||||||
chat_name = get_chat_name(chat_id)
|
chat_name = get_chat_name(chat_id)
|
||||||
|
|
@ -208,7 +210,7 @@ def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_
|
||||||
|
|
||||||
# 保存图片
|
# 保存图片
|
||||||
if save_path:
|
if save_path:
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||||
print(f"\n分组散点图已保存到: {save_path}")
|
print(f"\n分组散点图已保存到: {save_path}")
|
||||||
|
|
||||||
# 显示图片
|
# 显示图片
|
||||||
|
|
@ -233,15 +235,15 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||||
time_span = max(all_dates) - min(all_dates)
|
time_span = max(all_dates) - min(all_dates)
|
||||||
if time_span.days > 30: # 超过30天,按月显示
|
if time_span.days > 30: # 超过30天,按月显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.MonthLocator()
|
major_locator = mdates.MonthLocator()
|
||||||
minor_locator = mdates.DayLocator(interval=7)
|
minor_locator = mdates.DayLocator(interval=7)
|
||||||
elif time_span.days > 7: # 超过7天,按天显示
|
elif time_span.days > 7: # 超过7天,按天显示
|
||||||
date_format = '%Y-%m-%d'
|
date_format = "%Y-%m-%d"
|
||||||
major_locator = mdates.DayLocator(interval=1)
|
major_locator = mdates.DayLocator(interval=1)
|
||||||
minor_locator = mdates.HourLocator(interval=12)
|
minor_locator = mdates.HourLocator(interval=12)
|
||||||
else: # 7天内,按小时显示
|
else: # 7天内,按小时显示
|
||||||
date_format = '%Y-%m-%d %H:%M'
|
date_format = "%Y-%m-%d %H:%M"
|
||||||
major_locator = mdates.HourLocator(interval=6)
|
major_locator = mdates.HourLocator(interval=6)
|
||||||
minor_locator = mdates.HourLocator(interval=1)
|
minor_locator = mdates.HourLocator(interval=1)
|
||||||
|
|
||||||
|
|
@ -256,14 +258,21 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||||
counts = [item[1] for item in type_data]
|
counts = [item[1] for item in type_data]
|
||||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||||
|
|
||||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
ax.scatter(
|
||||||
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
|
dates,
|
||||||
edgecolors='black', linewidth=0.5)
|
counts,
|
||||||
|
alpha=0.7,
|
||||||
|
s=40,
|
||||||
|
c=[colors[i]],
|
||||||
|
label=f"{expr_type} ({len(type_data)}个)",
|
||||||
|
edgecolors="black",
|
||||||
|
linewidth=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
# 设置标签和标题
|
# 设置标签和标题
|
||||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||||
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
|
ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
|
||||||
|
|
||||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||||
|
|
@ -272,7 +281,7 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||||
plt.xticks(rotation=45)
|
plt.xticks(rotation=45)
|
||||||
|
|
||||||
# 添加图例
|
# 添加图例
|
||||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
|
||||||
|
|
||||||
# 添加网格
|
# 添加网格
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
@ -281,14 +290,14 @@ def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_pat
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# 显示统计信息
|
# 显示统计信息
|
||||||
print(f"\n=== 类型统计 ===")
|
print("\n=== 类型统计 ===")
|
||||||
for expr_type, type_data in type_groups.items():
|
for expr_type, type_data in type_groups.items():
|
||||||
counts = [item[1] for item in type_data]
|
counts = [item[1] for item in type_data]
|
||||||
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||||
|
|
||||||
# 保存图片
|
# 保存图片
|
||||||
if save_path:
|
if save_path:
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||||
print(f"\n类型散点图已保存到: {save_path}")
|
print(f"\n类型散点图已保存到: {save_path}")
|
||||||
|
|
||||||
# 显示图片
|
# 显示图片
|
||||||
|
|
|
||||||
|
|
@ -945,9 +945,7 @@ class EmojiManager:
|
||||||
prompt, image_base64, "jpg", temperature=0.5
|
prompt, image_base64, "jpg", temperature=0.5
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = (
|
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答"
|
||||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗,meme的角度去分析,精简回答"
|
|
||||||
)
|
|
||||||
description, _ = await self.vlm.generate_response_for_image(
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
prompt, image_base64, image_format, temperature=0.5
|
prompt, image_base64, image_format, temperature=0.5
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis import frequency_api
|
from src.plugin_system.apis import frequency_api
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{name_block}
|
"""{name_block}
|
||||||
|
|
@ -54,7 +55,6 @@ class FrequencyControl:
|
||||||
"""设置发言频率调整值"""
|
"""设置发言频率调整值"""
|
||||||
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
self.talk_frequency_adjust = max(0.1, min(5.0, value))
|
||||||
|
|
||||||
|
|
||||||
async def trigger_frequency_adjust(self) -> None:
|
async def trigger_frequency_adjust(self) -> None:
|
||||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
|
|
@ -62,7 +62,6 @@ class FrequencyControl:
|
||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
|
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|
@ -118,7 +117,8 @@ class FrequencyControl:
|
||||||
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
|
||||||
self.last_frequency_adjust_time = time.time()
|
self.last_frequency_adjust_time = time.time()
|
||||||
else:
|
else:
|
||||||
logger.info(f"频率调整:response不符合要求,取消本次调整")
|
logger.info("频率调整:response不符合要求,取消本次调整")
|
||||||
|
|
||||||
|
|
||||||
class FrequencyControlManager:
|
class FrequencyControlManager:
|
||||||
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
"""频率控制管理器,管理多个聊天流的频率控制实例"""
|
||||||
|
|
@ -143,6 +143,7 @@ class FrequencyControlManager:
|
||||||
"""获取所有有频率控制的聊天ID"""
|
"""获取所有有频率控制的聊天ID"""
|
||||||
return list(self.frequency_control_dict.keys())
|
return list(self.frequency_control_dict.keys())
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
# 创建全局实例
|
# 创建全局实例
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from multiprocessing import context
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import random
|
import random
|
||||||
|
|
@ -102,7 +101,7 @@ class HeartFChatting:
|
||||||
|
|
||||||
self.is_mute = False
|
self.is_mute = False
|
||||||
|
|
||||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||||
|
|
||||||
self.question_probability_multiplier = 1
|
self.question_probability_multiplier = 1
|
||||||
self.questioned = False
|
self.questioned = False
|
||||||
|
|
@ -191,9 +190,6 @@ class HeartFChatting:
|
||||||
filter_command=True,
|
filter_command=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 根据连续 no_reply 次数动态调整阈值
|
# 根据连续 no_reply 次数动态调整阈值
|
||||||
# 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
|
# 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2)
|
||||||
# 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
|
# 5次 no_reply 时,提高到 2(大于等于两条消息的阈值)
|
||||||
|
|
@ -207,7 +203,7 @@ class HeartFChatting:
|
||||||
|
|
||||||
if len(recent_messages_list) >= threshold:
|
if len(recent_messages_list) >= threshold:
|
||||||
# for message in recent_messages_list:
|
# for message in recent_messages_list:
|
||||||
# print(message.processed_plain_text)
|
# print(message.processed_plain_text)
|
||||||
# !处理no_reply_until_call逻辑
|
# !处理no_reply_until_call逻辑
|
||||||
if self.no_reply_until_call:
|
if self.no_reply_until_call:
|
||||||
for message in recent_messages_list:
|
for message in recent_messages_list:
|
||||||
|
|
@ -395,14 +391,15 @@ class HeartFChatting:
|
||||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
if recent_messages_list is None:
|
if recent_messages_list is None:
|
||||||
recent_messages_list = []
|
recent_messages_list = []
|
||||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||||
asyncio.create_task(frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust())
|
asyncio.create_task(
|
||||||
|
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
|
||||||
|
)
|
||||||
|
|
||||||
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||||
# asyncio.create_task(check_and_make_question(self.stream_id))
|
# asyncio.create_task(check_and_make_question(self.stream_id))
|
||||||
|
|
@ -412,7 +409,6 @@ class HeartFChatting:
|
||||||
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
# 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
|
||||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||||
|
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||||
|
|
||||||
|
|
@ -457,7 +453,12 @@ class HeartFChatting:
|
||||||
# 处理回复结果
|
# 处理回复结果
|
||||||
if isinstance(reply_result, BaseException):
|
if isinstance(reply_result, BaseException):
|
||||||
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
|
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
|
||||||
reply_result = {"action_type": "reply", "success": False, "result": "回复生成异常", "loop_info": None}
|
reply_result = {
|
||||||
|
"action_type": "reply",
|
||||||
|
"success": False,
|
||||||
|
"result": "回复生成异常",
|
||||||
|
"loop_info": None,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# 正常流程:只执行planner
|
# 正常流程:只执行planner
|
||||||
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
|
||||||
|
|
@ -558,7 +559,7 @@ class HeartFChatting:
|
||||||
"taken_time": time.time(),
|
"taken_time": time.time(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
reply_text = reply_text_from_reply
|
_reply_text = reply_text_from_reply
|
||||||
else:
|
else:
|
||||||
# 没有回复信息,构建纯动作的loop_info
|
# 没有回复信息,构建纯动作的loop_info
|
||||||
loop_info = {
|
loop_info = {
|
||||||
|
|
@ -571,7 +572,7 @@ class HeartFChatting:
|
||||||
"taken_time": time.time(),
|
"taken_time": time.time(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply_text = action_reply_text
|
_reply_text = action_reply_text
|
||||||
|
|
||||||
self.end_cycle(loop_info, cycle_timers)
|
self.end_cycle(loop_info, cycle_timers)
|
||||||
self.print_cycle_info(cycle_timers)
|
self.print_cycle_info(cycle_timers)
|
||||||
|
|
@ -647,7 +648,6 @@ class HeartFChatting:
|
||||||
result = await action_handler.execute()
|
result = await action_handler.execute()
|
||||||
success, action_text = result
|
success, action_text = result
|
||||||
|
|
||||||
|
|
||||||
return success, action_text
|
return success, action_text
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -655,8 +655,6 @@ class HeartFChatting:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _send_response(
|
async def _send_response(
|
||||||
self,
|
self,
|
||||||
reply_set: "ReplySetModel",
|
reply_set: "ReplySetModel",
|
||||||
|
|
@ -732,7 +730,6 @@ class HeartFChatting:
|
||||||
action_reasoning=reason,
|
action_reasoning=reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||||
|
|
||||||
elif action_planner_info.action_type == "no_reply_until_call":
|
elif action_planner_info.action_type == "no_reply_until_call":
|
||||||
|
|
@ -753,7 +750,12 @@ class HeartFChatting:
|
||||||
action_name="no_reply_until_call",
|
action_name="no_reply_until_call",
|
||||||
action_reasoning=reason,
|
action_reasoning=reason,
|
||||||
)
|
)
|
||||||
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
|
return {
|
||||||
|
"action_type": "no_reply_until_call",
|
||||||
|
"success": True,
|
||||||
|
"result": "保持沉默,直到有人直接叫的名字",
|
||||||
|
"command": "",
|
||||||
|
}
|
||||||
|
|
||||||
elif action_planner_info.action_type == "reply":
|
elif action_planner_info.action_type == "reply":
|
||||||
# 直接当场执行reply逻辑
|
# 直接当场执行reply逻辑
|
||||||
|
|
@ -783,19 +785,16 @@ class HeartFChatting:
|
||||||
enable_tool=global_config.tool.enable_tool,
|
enable_tool=global_config.tool.enable_tool,
|
||||||
request_type="replyer",
|
request_type="replyer",
|
||||||
from_plugin=False,
|
from_plugin=False,
|
||||||
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
|
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success or not llm_response or not llm_response.reply_set:
|
if not success or not llm_response or not llm_response.reply_set:
|
||||||
if action_planner_info.action_message:
|
if action_planner_info.action_message:
|
||||||
logger.info(
|
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("回复生成失败")
|
logger.info("回复生成失败")
|
||||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||||
|
|
||||||
|
|
||||||
response_set = llm_response.reply_set
|
response_set = llm_response.reply_set
|
||||||
selected_expressions = llm_response.selected_expressions
|
selected_expressions = llm_response.selected_expressions
|
||||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||||
|
|
@ -817,12 +816,12 @@ class HeartFChatting:
|
||||||
# 执行普通动作
|
# 执行普通动作
|
||||||
with Timer("动作执行", cycle_timers):
|
with Timer("动作执行", cycle_timers):
|
||||||
success, result = await self._handle_action(
|
success, result = await self._handle_action(
|
||||||
action = action_planner_info.action_type,
|
action=action_planner_info.action_type,
|
||||||
action_reasoning = action_planner_info.action_reasoning or "",
|
action_reasoning=action_planner_info.action_reasoning or "",
|
||||||
action_data = action_planner_info.action_data or {},
|
action_data=action_planner_info.action_data or {},
|
||||||
cycle_timers = cycle_timers,
|
cycle_timers=cycle_timers,
|
||||||
thinking_id = thinking_id,
|
thinking_id=thinking_id,
|
||||||
action_message= action_planner_info.action_message,
|
action_message=action_planner_info.action_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.last_active_time = time.time()
|
self.last_active_time = time.time()
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,11 @@ from src.person_info.person_info import Person
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
pass
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
|
||||||
class HeartFCMessageReceiver:
|
class HeartFCMessageReceiver:
|
||||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.person_info.person_info import Person
|
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
|
|
@ -171,7 +170,11 @@ class ChatBot:
|
||||||
|
|
||||||
# 撤回事件打印;无法获取被撤回者则省略
|
# 撤回事件打印;无法获取被撤回者则省略
|
||||||
if sub_type == "recall":
|
if sub_type == "recall":
|
||||||
op_name = getattr(op, "user_cardname", None) or getattr(op, "user_nickname", None) or str(getattr(op, "user_id", None))
|
op_name = (
|
||||||
|
getattr(op, "user_cardname", None)
|
||||||
|
or getattr(op, "user_nickname", None)
|
||||||
|
or str(getattr(op, "user_id", None))
|
||||||
|
)
|
||||||
recalled_name = None
|
recalled_name = None
|
||||||
try:
|
try:
|
||||||
if isinstance(recalled, dict):
|
if isinstance(recalled, dict):
|
||||||
|
|
@ -189,7 +192,7 @@ class ChatBot:
|
||||||
logger.info(f"{op_name} 撤回了消息")
|
logger.info(f"{op_name} 撤回了消息")
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op,'user_nickname',None)}({getattr(op,'user_id',None)}) "
|
f"[notice] sub_type={sub_type} scene={scene} op={getattr(op, 'user_nickname', None)}({getattr(op, 'user_id', None)}) "
|
||||||
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
|
f"gid={gid} msg_id={msg_id} recalled={recalled_id}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -234,7 +237,6 @@ class ChatBot:
|
||||||
# 确保所有任务已启动
|
# 确保所有任务已启动
|
||||||
await self._ensure_started()
|
await self._ensure_started()
|
||||||
|
|
||||||
|
|
||||||
if message_data["message_info"].get("group_info") is not None:
|
if message_data["message_info"].get("group_info") is not None:
|
||||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||||
message_data["message_info"]["group_info"]["group_id"]
|
message_data["message_info"]["group_info"]["group_id"]
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,6 @@ class ActionPlanner:
|
||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
|
|
||||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||||
|
|
||||||
def find_message_by_id(
|
def find_message_by_id(
|
||||||
|
|
@ -306,7 +305,9 @@ class ActionPlanner:
|
||||||
loop_start_time=loop_start_time,
|
loop_start_time=loop_start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
logger.info(
|
||||||
|
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||||
|
)
|
||||||
|
|
||||||
self.add_plan_log(reasoning, actions)
|
self.add_plan_log(reasoning, actions)
|
||||||
|
|
||||||
|
|
@ -402,8 +403,7 @@ class ActionPlanner:
|
||||||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||||
try:
|
try:
|
||||||
|
actions_before_now_block = self.get_plan_log_str()
|
||||||
actions_before_now_block=self.get_plan_log_str()
|
|
||||||
|
|
||||||
# 构建聊天上下文描述
|
# 构建聊天上下文描述
|
||||||
chat_context_description = "你现在正在一个群聊中"
|
chat_context_description = "你现在正在一个群聊中"
|
||||||
|
|
@ -564,7 +564,7 @@ class ActionPlanner:
|
||||||
filtered_actions: Dict[str, ActionInfo],
|
filtered_actions: Dict[str, ActionInfo],
|
||||||
available_actions: Dict[str, ActionInfo],
|
available_actions: Dict[str, ActionInfo],
|
||||||
loop_start_time: float,
|
loop_start_time: float,
|
||||||
) -> Tuple[str,List[ActionPlannerInfo]]:
|
) -> Tuple[str, List[ActionPlannerInfo]]:
|
||||||
"""执行主规划器"""
|
"""执行主规划器"""
|
||||||
llm_content = None
|
llm_content = None
|
||||||
actions: List[ActionPlannerInfo] = []
|
actions: List[ActionPlannerInfo] = []
|
||||||
|
|
@ -589,7 +589,7 @@ class ActionPlanner:
|
||||||
|
|
||||||
except Exception as req_e:
|
except Exception as req_e:
|
||||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||||
return f"LLM 请求失败,模型出现问题: {req_e}",[
|
return f"LLM 请求失败,模型出现问题: {req_e}", [
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
action_type="no_reply",
|
action_type="no_reply",
|
||||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||||
|
|
@ -608,7 +608,11 @@ class ActionPlanner:
|
||||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||||
filtered_actions_list = list(filtered_actions.items())
|
filtered_actions_list = list(filtered_actions.items())
|
||||||
for json_obj in json_objects:
|
for json_obj in json_objects:
|
||||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning))
|
actions.extend(
|
||||||
|
self._parse_single_action(
|
||||||
|
json_obj, message_id_list, filtered_actions_list, extracted_reasoning
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 尝试解析为直接的JSON
|
# 尝试解析为直接的JSON
|
||||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||||
|
|
@ -631,7 +635,7 @@ class ActionPlanner:
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||||
|
|
||||||
return extracted_reasoning,actions
|
return extracted_reasoning, actions
|
||||||
|
|
||||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||||
"""创建no_reply"""
|
"""创建no_reply"""
|
||||||
|
|
@ -674,7 +678,7 @@ class ActionPlanner:
|
||||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||||
if json_str := json_str.strip():
|
if json_str := json_str.strip():
|
||||||
# 尝试按行分割,每行可能是一个JSON对象
|
# 尝试按行分割,每行可能是一个JSON对象
|
||||||
lines = [line.strip() for line in json_str.split('\n') if line.strip()]
|
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
try:
|
try:
|
||||||
# 尝试解析每一行作为独立的JSON对象
|
# 尝试解析每一行作为独立的JSON对象
|
||||||
|
|
|
||||||
|
|
@ -276,7 +276,6 @@ class DefaultReplyer:
|
||||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||||
return f"你现在的心情是:{mood_state}"
|
return f"你现在的心情是:{mood_state}"
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
|
|
@ -303,7 +302,7 @@ class DefaultReplyer:
|
||||||
for tool_result in tool_results:
|
for tool_result in tool_results:
|
||||||
tool_name = tool_result.get("tool_name", "unknown")
|
tool_name = tool_result.get("tool_name", "unknown")
|
||||||
content = tool_result.get("content", "")
|
content = tool_result.get("content", "")
|
||||||
result_type = tool_result.get("type", "tool_result")
|
_result_type = tool_result.get("type", "tool_result")
|
||||||
|
|
||||||
tool_info_str += f"- 【{tool_name}】: {content}\n"
|
tool_info_str += f"- 【{tool_name}】: {content}\n"
|
||||||
|
|
||||||
|
|
@ -605,9 +604,11 @@ class DefaultReplyer:
|
||||||
prompt_personality = global_config.personality.personality
|
prompt_personality = global_config.personality.personality
|
||||||
|
|
||||||
# 检查是否需要随机替换为状态
|
# 检查是否需要随机替换为状态
|
||||||
if (global_config.personality.states and
|
if (
|
||||||
global_config.personality.state_probability > 0 and
|
global_config.personality.states
|
||||||
random.random() < global_config.personality.state_probability):
|
and global_config.personality.state_probability > 0
|
||||||
|
and random.random() < global_config.personality.state_probability
|
||||||
|
):
|
||||||
# 随机选择一个状态替换personality
|
# 随机选择一个状态替换personality
|
||||||
selected_state = random.choice(global_config.personality.states)
|
selected_state = random.choice(global_config.personality.states)
|
||||||
prompt_personality = selected_state
|
prompt_personality = selected_state
|
||||||
|
|
@ -720,7 +721,7 @@ class DefaultReplyer:
|
||||||
available_actions = {}
|
available_actions = {}
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
_is_group_chat = bool(chat_stream.group_info)
|
||||||
platform = chat_stream.platform
|
platform = chat_stream.platform
|
||||||
|
|
||||||
user_id = "用户ID"
|
user_id = "用户ID"
|
||||||
|
|
@ -956,9 +957,7 @@ class DefaultReplyer:
|
||||||
)
|
)
|
||||||
elif has_text and pic_part:
|
elif has_text and pic_part:
|
||||||
# 既有图片又有文字
|
# 既有图片又有文字
|
||||||
reply_target_block = (
|
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 只包含文字
|
# 只包含文字
|
||||||
reply_target_block = (
|
reply_target_block = (
|
||||||
|
|
@ -975,7 +974,9 @@ class DefaultReplyer:
|
||||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||||
elif has_text and pic_part:
|
elif has_text and pic_part:
|
||||||
# 既有图片又有文字
|
# 既有图片又有文字
|
||||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = (
|
||||||
|
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 只包含文字
|
# 只包含文字
|
||||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||||
|
|
@ -1132,6 +1133,7 @@ class DefaultReplyer:
|
||||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
"""
|
"""
|
||||||
加权且不放回地随机抽取k个元素。
|
加权且不放回地随机抽取k个元素。
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ init_memory_retrieval_prompt()
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
||||||
|
|
||||||
class PrivateReplyer:
|
class PrivateReplyer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -277,9 +278,7 @@ class PrivateReplyer:
|
||||||
expression_habits_block = ""
|
expression_habits_block = ""
|
||||||
expression_habits_title = ""
|
expression_habits_title = ""
|
||||||
if style_habits_str.strip():
|
if style_habits_str.strip():
|
||||||
expression_habits_title = (
|
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||||
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
|
||||||
)
|
|
||||||
expression_habits_block += f"{style_habits_str}\n"
|
expression_habits_block += f"{style_habits_str}\n"
|
||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
@ -291,7 +290,6 @@ class PrivateReplyer:
|
||||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||||
return f"你现在的心情是:{mood_state}"
|
return f"你现在的心情是:{mood_state}"
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
|
|
@ -519,9 +517,11 @@ class PrivateReplyer:
|
||||||
prompt_personality = global_config.personality.personality
|
prompt_personality = global_config.personality.personality
|
||||||
|
|
||||||
# 检查是否需要随机替换为状态
|
# 检查是否需要随机替换为状态
|
||||||
if (global_config.personality.states and
|
if (
|
||||||
global_config.personality.state_probability > 0 and
|
global_config.personality.states
|
||||||
random.random() < global_config.personality.state_probability):
|
and global_config.personality.state_probability > 0
|
||||||
|
and random.random() < global_config.personality.state_probability
|
||||||
|
):
|
||||||
# 随机选择一个状态替换personality
|
# 随机选择一个状态替换personality
|
||||||
selected_state = random.choice(global_config.personality.states)
|
selected_state = random.choice(global_config.personality.states)
|
||||||
prompt_personality = selected_state
|
prompt_personality = selected_state
|
||||||
|
|
@ -647,8 +647,6 @@ class PrivateReplyer:
|
||||||
sender = person_name
|
sender = person_name
|
||||||
target = reply_message.processed_plain_text
|
target = reply_message.processed_plain_text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||||
|
|
||||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||||
|
|
@ -710,9 +708,7 @@ class PrivateReplyer:
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||||
),
|
),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
|
||||||
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
|
|
||||||
),
|
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||||
),
|
),
|
||||||
|
|
@ -859,8 +855,6 @@ class PrivateReplyer:
|
||||||
# 将[picid:xxx]替换为具体的图片描述
|
# 将[picid:xxx]替换为具体的图片描述
|
||||||
target = self._replace_picids_with_descriptions(target)
|
target = self._replace_picids_with_descriptions(target)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
|
|
@ -900,9 +894,7 @@ class PrivateReplyer:
|
||||||
)
|
)
|
||||||
elif has_text and pic_part:
|
elif has_text and pic_part:
|
||||||
# 既有图片又有文字
|
# 既有图片又有文字
|
||||||
reply_target_block = (
|
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 只包含文字
|
# 只包含文字
|
||||||
reply_target_block = (
|
reply_target_block = (
|
||||||
|
|
@ -919,7 +911,9 @@ class PrivateReplyer:
|
||||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||||
elif has_text and pic_part:
|
elif has_text and pic_part:
|
||||||
# 既有图片又有文字
|
# 既有图片又有文字
|
||||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = (
|
||||||
|
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 只包含文字
|
# 只包含文字
|
||||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||||
|
|
@ -1106,6 +1100,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
pool.pop(idx)
|
pool.pop(idx)
|
||||||
break
|
break
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
|
|
||||||
from src.chat.utils.prompt_builder import Prompt
|
from src.chat.utils.prompt_builder import Prompt
|
||||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_replyer_prompt():
|
def init_replyer_prompt():
|
||||||
Prompt("正在群里聊天", "chat_target_group2")
|
Prompt("正在群里聊天", "chat_target_group2")
|
||||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_retrieval}
|
{expression_habits_block}{memory_retrieval}
|
||||||
|
|
||||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片:
|
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片:
|
||||||
|
|
@ -28,9 +25,8 @@ def init_replyer_prompt():
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_retrieval}
|
{expression_habits_block}{memory_retrieval}
|
||||||
|
|
||||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||||
|
|
@ -47,9 +43,8 @@ def init_replyer_prompt():
|
||||||
"private_replyer_prompt",
|
"private_replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_retrieval}
|
{expression_habits_block}{memory_retrieval}
|
||||||
|
|
||||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
聊天内容概括器
|
聊天内容概括器
|
||||||
用于累积、打包和压缩聊天记录
|
用于累积、打包和压缩聊天记录
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
@ -23,6 +24,7 @@ logger = get_logger("chat_history_summarizer")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageBatch:
|
class MessageBatch:
|
||||||
"""消息批次"""
|
"""消息批次"""
|
||||||
|
|
||||||
messages: List[DatabaseMessages]
|
messages: List[DatabaseMessages]
|
||||||
start_time: float
|
start_time: float
|
||||||
end_time: float
|
end_time: float
|
||||||
|
|
@ -52,8 +54,7 @@ class ChatHistorySummarizer:
|
||||||
|
|
||||||
# LLM请求器,用于压缩聊天内容
|
# LLM请求器,用于压缩聊天内容
|
||||||
self.summarizer_llm = LLMRequest(
|
self.summarizer_llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.utils,
|
model_set=model_config.model_task_config.utils, request_type="chat_history_summarizer"
|
||||||
request_type="chat_history_summarizer"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 后台循环相关
|
# 后台循环相关
|
||||||
|
|
@ -117,9 +118,7 @@ class ChatHistorySummarizer:
|
||||||
before_count = len(self.current_batch.messages)
|
before_count = len(self.current_batch.messages)
|
||||||
self.current_batch.messages.extend(new_messages)
|
self.current_batch.messages.extend(new_messages)
|
||||||
self.current_batch.end_time = current_time
|
self.current_batch.end_time = current_time
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||||
f"{self.log_prefix} 批次更新: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 创建新批次
|
# 创建新批次
|
||||||
self.current_batch = MessageBatch(
|
self.current_batch = MessageBatch(
|
||||||
|
|
@ -127,9 +126,7 @@ class ChatHistorySummarizer:
|
||||||
start_time=new_messages[0].time if new_messages else current_time,
|
start_time=new_messages[0].time if new_messages else current_time,
|
||||||
end_time=current_time,
|
end_time=current_time,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息")
|
||||||
f"{self.log_prefix} 新建批次: {len(new_messages)} 条消息"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查是否需要打包
|
# 检查是否需要打包
|
||||||
await self._check_and_package(current_time)
|
await self._check_and_package(current_time)
|
||||||
|
|
@ -137,6 +134,7 @@ class ChatHistorySummarizer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
|
logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _check_and_package(self, current_time: float):
|
async def _check_and_package(self, current_time: float):
|
||||||
|
|
@ -153,9 +151,9 @@ class ChatHistorySummarizer:
|
||||||
if time_since_last_message < 60:
|
if time_since_last_message < 60:
|
||||||
time_str = f"{time_since_last_message:.1f}秒"
|
time_str = f"{time_since_last_message:.1f}秒"
|
||||||
elif time_since_last_message < 3600:
|
elif time_since_last_message < 3600:
|
||||||
time_str = f"{time_since_last_message/60:.1f}分钟"
|
time_str = f"{time_since_last_message / 60:.1f}分钟"
|
||||||
else:
|
else:
|
||||||
time_str = f"{time_since_last_message/3600:.1f}小时"
|
time_str = f"{time_since_last_message / 3600:.1f}小时"
|
||||||
|
|
||||||
preparing_status = "是" if self.current_batch.is_preparing else "否"
|
preparing_status = "是" if self.current_batch.is_preparing else "否"
|
||||||
|
|
||||||
|
|
@ -250,26 +248,23 @@ class ChatHistorySummarizer:
|
||||||
participants_set: Set[str] = set()
|
participants_set: Set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# 使用 msg.user_platform(扁平化字段)或 msg.user_info.platform
|
# 使用 msg.user_platform(扁平化字段)或 msg.user_info.platform
|
||||||
platform = getattr(msg, 'user_platform', None) or (msg.user_info.platform if msg.user_info else None) or msg.chat_info.platform
|
platform = (
|
||||||
person = Person(
|
getattr(msg, "user_platform", None)
|
||||||
platform=platform,
|
or (msg.user_info.platform if msg.user_info else None)
|
||||||
user_id=msg.user_info.user_id
|
or msg.chat_info.platform
|
||||||
)
|
)
|
||||||
|
person = Person(platform=platform, user_id=msg.user_info.user_id)
|
||||||
person_name = person.person_name
|
person_name = person.person_name
|
||||||
if person_name:
|
if person_name:
|
||||||
participants_set.add(person_name)
|
participants_set.add(person_name)
|
||||||
participants = list(participants_set)
|
participants = list(participants_set)
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}")
|
||||||
f"{self.log_prefix} 批次参与者: {', '.join(participants) if participants else '未知'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用LLM压缩聊天内容
|
# 使用LLM压缩聊天内容
|
||||||
success, theme, keywords, summary = await self._compress_with_llm(original_text)
|
success, theme, keywords, summary = await self._compress_with_llm(original_text)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning(
|
logger.warning(f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}")
|
||||||
f"{self.log_prefix} LLM压缩失败,不存储到数据库 | 消息数: {len(messages)}"
|
|
||||||
)
|
|
||||||
# 清空当前批次,避免重复处理
|
# 清空当前批次,避免重复处理
|
||||||
self.current_batch = None
|
self.current_batch = None
|
||||||
return
|
return
|
||||||
|
|
@ -297,6 +292,7 @@ class ChatHistorySummarizer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
|
logger.error(f"{self.log_prefix} 打包和存储聊天记录时出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 出错时也清空批次,避免重复处理
|
# 出错时也清空批次,避免重复处理
|
||||||
self.current_batch = None
|
self.current_batch = None
|
||||||
|
|
@ -338,23 +334,23 @@ class ChatHistorySummarizer:
|
||||||
|
|
||||||
# 移除可能的markdown代码块标记
|
# 移除可能的markdown代码块标记
|
||||||
json_str = response.strip()
|
json_str = response.strip()
|
||||||
json_str = re.sub(r'^```json\s*', '', json_str, flags=re.MULTILINE)
|
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||||
json_str = re.sub(r'^```\s*', '', json_str, flags=re.MULTILINE)
|
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
# 尝试找到JSON对象的开始和结束位置
|
# 尝试找到JSON对象的开始和结束位置
|
||||||
# 查找第一个 { 和最后一个匹配的 }
|
# 查找第一个 { 和最后一个匹配的 }
|
||||||
start_idx = json_str.find('{')
|
start_idx = json_str.find("{")
|
||||||
if start_idx == -1:
|
if start_idx == -1:
|
||||||
raise ValueError("未找到JSON对象开始标记")
|
raise ValueError("未找到JSON对象开始标记")
|
||||||
|
|
||||||
# 从后往前查找最后一个 }
|
# 从后往前查找最后一个 }
|
||||||
end_idx = json_str.rfind('}')
|
end_idx = json_str.rfind("}")
|
||||||
if end_idx == -1 or end_idx <= start_idx:
|
if end_idx == -1 or end_idx <= start_idx:
|
||||||
raise ValueError("未找到JSON对象结束标记")
|
raise ValueError("未找到JSON对象结束标记")
|
||||||
|
|
||||||
# 提取JSON字符串
|
# 提取JSON字符串
|
||||||
json_str = json_str[start_idx:end_idx + 1]
|
json_str = json_str[start_idx : end_idx + 1]
|
||||||
|
|
||||||
# 尝试解析JSON
|
# 尝试解析JSON
|
||||||
try:
|
try:
|
||||||
|
|
@ -372,7 +368,7 @@ class ChatHistorySummarizer:
|
||||||
if escape_next:
|
if escape_next:
|
||||||
fixed_chars.append(char)
|
fixed_chars.append(char)
|
||||||
escape_next = False
|
escape_next = False
|
||||||
elif char == '\\':
|
elif char == "\\":
|
||||||
fixed_chars.append(char)
|
fixed_chars.append(char)
|
||||||
escape_next = True
|
escape_next = True
|
||||||
elif char == '"' and not escape_next:
|
elif char == '"' and not escape_next:
|
||||||
|
|
@ -385,7 +381,7 @@ class ChatHistorySummarizer:
|
||||||
fixed_chars.append(char)
|
fixed_chars.append(char)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
json_str = ''.join(fixed_chars)
|
json_str = "".join(fixed_chars)
|
||||||
# 再次尝试解析
|
# 再次尝试解析
|
||||||
result = json.loads(json_str)
|
result = json.loads(json_str)
|
||||||
|
|
||||||
|
|
@ -450,6 +446,7 @@ class ChatHistorySummarizer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
|
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
@ -490,6 +487,6 @@ class ChatHistorySummarizer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
|
logger.error(f"{self.log_prefix} 后台检查循环出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import time
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
|
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -568,7 +568,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||||
output_lines = []
|
output_lines = []
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
for action in actions:
|
for action in actions:
|
||||||
action_time = action.time or current_time
|
action_time = action.time or current_time
|
||||||
action_name = action.action_name or "未知动作"
|
action_name = action.action_name or "未知动作"
|
||||||
|
|
@ -596,7 +595,6 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
||||||
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
||||||
output_lines.append(line)
|
output_lines.append(line)
|
||||||
|
|
||||||
|
|
||||||
return "\n".join(output_lines)
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -936,7 +934,6 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||||
return formatted_string
|
return formatted_string
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
记忆遗忘任务
|
记忆遗忘任务
|
||||||
每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆
|
每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
@ -48,11 +49,7 @@ class MemoryForgetTask(AsyncTask):
|
||||||
|
|
||||||
# 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold
|
# 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold
|
||||||
candidates = list(
|
candidates = list(
|
||||||
ChatHistory.select()
|
ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
|
||||||
.where(
|
|
||||||
(ChatHistory.forget_times == 0) &
|
|
||||||
(ChatHistory.end_time < time_threshold)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
|
@ -101,11 +98,11 @@ class MemoryForgetTask(AsyncTask):
|
||||||
if remaining:
|
if remaining:
|
||||||
# 批量更新
|
# 批量更新
|
||||||
ids_to_update = [r.id for r in remaining]
|
ids_to_update = [r.id for r in remaining]
|
||||||
ChatHistory.update(forget_times=1).where(
|
ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||||
ChatHistory.id.in_(ids_to_update)
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1")
|
logger.info(
|
||||||
|
f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
|
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
|
||||||
|
|
@ -122,11 +119,7 @@ class MemoryForgetTask(AsyncTask):
|
||||||
|
|
||||||
# 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold
|
# 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold
|
||||||
candidates = list(
|
candidates = list(
|
||||||
ChatHistory.select()
|
ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
|
||||||
.where(
|
|
||||||
(ChatHistory.forget_times == 1) &
|
|
||||||
(ChatHistory.end_time < time_threshold)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
|
@ -168,11 +161,11 @@ class MemoryForgetTask(AsyncTask):
|
||||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
if remaining:
|
if remaining:
|
||||||
ids_to_update = [r.id for r in remaining]
|
ids_to_update = [r.id for r in remaining]
|
||||||
ChatHistory.update(forget_times=2).where(
|
ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||||
ChatHistory.id.in_(ids_to_update)
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2")
|
logger.info(
|
||||||
|
f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
|
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
|
||||||
|
|
@ -189,11 +182,7 @@ class MemoryForgetTask(AsyncTask):
|
||||||
|
|
||||||
# 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold
|
# 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold
|
||||||
candidates = list(
|
candidates = list(
|
||||||
ChatHistory.select()
|
ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
|
||||||
.where(
|
|
||||||
(ChatHistory.forget_times == 2) &
|
|
||||||
(ChatHistory.end_time < time_threshold)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
|
@ -235,11 +224,11 @@ class MemoryForgetTask(AsyncTask):
|
||||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
if remaining:
|
if remaining:
|
||||||
ids_to_update = [r.id for r in remaining]
|
ids_to_update = [r.id for r in remaining]
|
||||||
ChatHistory.update(forget_times=3).where(
|
ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||||
ChatHistory.id.in_(ids_to_update)
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3")
|
logger.info(
|
||||||
|
f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
|
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
|
||||||
|
|
@ -256,11 +245,7 @@ class MemoryForgetTask(AsyncTask):
|
||||||
|
|
||||||
# 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold
|
# 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold
|
||||||
candidates = list(
|
candidates = list(
|
||||||
ChatHistory.select()
|
ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
|
||||||
.where(
|
|
||||||
(ChatHistory.forget_times == 3) &
|
|
||||||
(ChatHistory.end_time < time_threshold)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
|
@ -302,16 +287,18 @@ class MemoryForgetTask(AsyncTask):
|
||||||
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
if remaining:
|
if remaining:
|
||||||
ids_to_update = [r.id for r in remaining]
|
ids_to_update = [r.id for r in remaining]
|
||||||
ChatHistory.update(forget_times=4).where(
|
ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()
|
||||||
ChatHistory.id.in_(ids_to_update)
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4")
|
logger.info(
|
||||||
|
f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
|
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
|
||||||
|
|
||||||
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]:
|
def _handle_same_count_random(
|
||||||
|
self, candidates: List[ChatHistory], delete_count: int, mode: str
|
||||||
|
) -> List[ChatHistory]:
|
||||||
"""
|
"""
|
||||||
处理count相同的情况,随机选择要删除的记录
|
处理count相同的情况,随机选择要删除的记录
|
||||||
|
|
||||||
|
|
@ -373,4 +360,3 @@ class MemoryForgetTask(AsyncTask):
|
||||||
start_idx = idx
|
start_idx = idx
|
||||||
|
|
||||||
return to_delete
|
return to_delete
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -504,7 +504,11 @@ class StatisticOutputTask(AsyncTask):
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取bot的QQ账号
|
# 获取bot的QQ账号
|
||||||
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
|
bot_qq_account = (
|
||||||
|
str(global_config.bot.qq_account)
|
||||||
|
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||||
|
|
@ -588,7 +592,9 @@ class StatisticOutputTask(AsyncTask):
|
||||||
continue
|
continue
|
||||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||||
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
||||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
self.stat_period = [
|
||||||
|
item for item in self.stat_period if item[0] != "all_time"
|
||||||
|
] # 删除"所有时间"的统计时段
|
||||||
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
|
self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
|
logger.warning(f"加载上次完整统计数据失败,进行全量统计,错误信息:{e}")
|
||||||
|
|
@ -699,7 +705,11 @@ class StatisticOutputTask(AsyncTask):
|
||||||
|
|
||||||
# 计算花费/消息数量(排除自己回复)指标(每100条)
|
# 计算花费/消息数量(排除自己回复)指标(每100条)
|
||||||
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
|
total_messages_excluding_replies = stats[TOTAL_MSG_CNT] - total_replies
|
||||||
cost_per_100_messages_excluding_replies = (stats[TOTAL_COST] / total_messages_excluding_replies * 100) if total_messages_excluding_replies > 0 else 0.0
|
cost_per_100_messages_excluding_replies = (
|
||||||
|
(stats[TOTAL_COST] / total_messages_excluding_replies * 100)
|
||||||
|
if total_messages_excluding_replies > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
output = [
|
output = [
|
||||||
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
|
f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
|
||||||
|
|
@ -709,7 +719,9 @@ class StatisticOutputTask(AsyncTask):
|
||||||
f"总Token数: {_format_large_number(total_tokens)}",
|
f"总Token数: {_format_large_number(total_tokens)}",
|
||||||
f"总花费: {stats[TOTAL_COST]:.2f}¥",
|
f"总花费: {stats[TOTAL_COST]:.2f}¥",
|
||||||
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
|
f"花费/消息数量: {cost_per_100_messages:.4f}¥/100条" if stats[TOTAL_MSG_CNT] > 0 else "花费/消息数量: N/A",
|
||||||
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条" if total_messages_excluding_replies > 0 else "花费/消息数量(排除回复): N/A",
|
f"花费/接受消息数量: {cost_per_100_messages_excluding_replies:.4f}¥/100条"
|
||||||
|
if total_messages_excluding_replies > 0
|
||||||
|
else "花费/消息数量(排除回复): N/A",
|
||||||
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
|
f"花费/回复消息数量: {cost_per_100_replies:.4f}¥/100条" if total_replies > 0 else "花费/回复数量: N/A",
|
||||||
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
|
f"花费/时间: {cost_per_hour:.2f}¥/小时" if online_hours > 0 else "花费/时间: N/A",
|
||||||
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
|
f"Token/时间: {_format_large_number(tokens_per_hour)}/小时" if online_hours > 0 else "Token/时间: N/A",
|
||||||
|
|
@ -745,7 +757,16 @@ class StatisticOutputTask(AsyncTask):
|
||||||
formatted_out_tokens = _format_large_number(out_tokens)
|
formatted_out_tokens = _format_large_number(out_tokens)
|
||||||
formatted_tokens = _format_large_number(tokens)
|
formatted_tokens = _format_large_number(tokens)
|
||||||
output.append(
|
output.append(
|
||||||
data_fmt.format(name, formatted_count, formatted_in_tokens, formatted_out_tokens, formatted_tokens, cost, avg_time_cost, std_time_cost)
|
data_fmt.format(
|
||||||
|
name,
|
||||||
|
formatted_count,
|
||||||
|
formatted_in_tokens,
|
||||||
|
formatted_out_tokens,
|
||||||
|
formatted_tokens,
|
||||||
|
cost,
|
||||||
|
avg_time_cost,
|
||||||
|
std_time_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
output.append("")
|
output.append("")
|
||||||
|
|
@ -892,7 +913,11 @@ class StatisticOutputTask(AsyncTask):
|
||||||
logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||||
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
|
chat_rows.append(f"<tr><td>未知聊天</td><td>{_format_large_number(count, html=True)}</td></tr>")
|
||||||
|
|
||||||
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
chat_rows_html = (
|
||||||
|
"\n".join(chat_rows)
|
||||||
|
if chat_rows
|
||||||
|
else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
||||||
|
)
|
||||||
# 生成HTML
|
# 生成HTML
|
||||||
return f"""
|
return f"""
|
||||||
<div id=\"{div_id}\" class=\"tab-content\">
|
<div id=\"{div_id}\" class=\"tab-content\">
|
||||||
|
|
@ -1777,10 +1802,10 @@ class StatisticOutputTask(AsyncTask):
|
||||||
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
|
metrics_data["24h"] = self._collect_metrics_interval_data(now, hours=24, interval_hours=1)
|
||||||
|
|
||||||
# 7天尺度:1天为单位
|
# 7天尺度:1天为单位
|
||||||
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24*7, interval_hours=24)
|
metrics_data["7d"] = self._collect_metrics_interval_data(now, hours=24 * 7, interval_hours=24)
|
||||||
|
|
||||||
# 30天尺度:1天为单位
|
# 30天尺度:1天为单位
|
||||||
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24*30, interval_hours=24)
|
metrics_data["30d"] = self._collect_metrics_interval_data(now, hours=24 * 30, interval_hours=24)
|
||||||
|
|
||||||
return metrics_data
|
return metrics_data
|
||||||
|
|
||||||
|
|
@ -1809,7 +1834,11 @@ class StatisticOutputTask(AsyncTask):
|
||||||
total_online_hours = [0.0] * len(time_points)
|
total_online_hours = [0.0] * len(time_points)
|
||||||
|
|
||||||
# 获取bot的QQ账号
|
# 获取bot的QQ账号
|
||||||
bot_qq_account = str(global_config.bot.qq_account) if hasattr(global_config, 'bot') and hasattr(global_config.bot, 'qq_account') else ""
|
bot_qq_account = (
|
||||||
|
str(global_config.bot.qq_account)
|
||||||
|
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
interval_seconds = interval_hours * 3600
|
interval_seconds = interval_hours * 3600
|
||||||
|
|
||||||
|
|
@ -1867,19 +1896,19 @@ class StatisticOutputTask(AsyncTask):
|
||||||
for idx in range(len(time_points)):
|
for idx in range(len(time_points)):
|
||||||
# 花费/消息数量(每100条)
|
# 花费/消息数量(每100条)
|
||||||
if total_messages[idx] > 0:
|
if total_messages[idx] > 0:
|
||||||
cost_per_100_messages[idx] = (total_costs[idx] / total_messages[idx] * 100)
|
cost_per_100_messages[idx] = total_costs[idx] / total_messages[idx] * 100
|
||||||
|
|
||||||
# 花费/时间(每小时)
|
# 花费/时间(每小时)
|
||||||
if total_online_hours[idx] > 0:
|
if total_online_hours[idx] > 0:
|
||||||
cost_per_hour[idx] = (total_costs[idx] / total_online_hours[idx])
|
cost_per_hour[idx] = total_costs[idx] / total_online_hours[idx]
|
||||||
|
|
||||||
# Token/时间(每小时)
|
# Token/时间(每小时)
|
||||||
if total_online_hours[idx] > 0:
|
if total_online_hours[idx] > 0:
|
||||||
tokens_per_hour[idx] = (total_tokens[idx] / total_online_hours[idx])
|
tokens_per_hour[idx] = total_tokens[idx] / total_online_hours[idx]
|
||||||
|
|
||||||
# 花费/回复数量(每100条)
|
# 花费/回复数量(每100条)
|
||||||
if total_replies[idx] > 0:
|
if total_replies[idx] > 0:
|
||||||
cost_per_100_replies[idx] = (total_costs[idx] / total_replies[idx] * 100)
|
cost_per_100_replies[idx] = total_costs[idx] / total_replies[idx] * 100
|
||||||
|
|
||||||
# 生成时间标签
|
# 生成时间标签
|
||||||
if interval_hours == 1:
|
if interval_hours == 1:
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,11 @@ import time
|
||||||
import jieba
|
import jieba
|
||||||
import json
|
import json
|
||||||
import ast
|
import ast
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.message_repository import find_messages, count_messages
|
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||||
elif current_account:
|
elif current_account:
|
||||||
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text):
|
elif re.search(
|
||||||
|
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text
|
||||||
|
):
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
|
|
||||||
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
||||||
|
|
@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||||
"""将文本分割成句子,并根据概率合并
|
"""将文本分割成句子,并根据概率合并
|
||||||
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
||||||
|
|
@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||||
prev_char = text[i - 1]
|
prev_char = text[i - 1]
|
||||||
next_char = text[i + 1]
|
next_char = text[i + 1]
|
||||||
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
||||||
if char == ' ':
|
if char == " ":
|
||||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||||
if prev_is_alnum and next_is_alnum:
|
if prev_is_alnum and next_is_alnum:
|
||||||
|
|
@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
|
||||||
"不知道",
|
"不知道",
|
||||||
"不晓得",
|
"不晓得",
|
||||||
"懒得说",
|
"懒得说",
|
||||||
"()"
|
"()",
|
||||||
]
|
]
|
||||||
return random.choice(default_replies)
|
return random.choice(default_replies)
|
||||||
|
|
||||||
|
|
@ -469,7 +467,6 @@ def calculate_typing_time(
|
||||||
return total_time # 加上回车时间
|
return total_time # 加上回车时间
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_message(message: str, max_length=20) -> str:
|
def truncate_message(message: str, max_length=20) -> str:
|
||||||
"""截断消息,使其不超过指定长度"""
|
"""截断消息,使其不超过指定长度"""
|
||||||
return f"{message[:max_length]}..." if len(message) > max_length else message
|
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||||
|
|
@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
|
||||||
return western_count / len(alnum_chars)
|
return western_count / len(alnum_chars)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||||
"""将时间戳转换为人类可读的时间格式
|
"""将时间戳转换为人类可读的时间格式
|
||||||
|
|
|
||||||
|
|
@ -103,14 +103,16 @@ class ImageManager:
|
||||||
invalid_values = ["", "None"]
|
invalid_values = ["", "None"]
|
||||||
|
|
||||||
# 清理 Images 表
|
# 清理 Images 表
|
||||||
deleted_images = Images.delete().where(
|
deleted_images = (
|
||||||
(Images.description >> None) | (Images.description << invalid_values)
|
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
|
||||||
).execute()
|
)
|
||||||
|
|
||||||
# 清理 ImageDescriptions 表
|
# 清理 ImageDescriptions 表
|
||||||
deleted_descriptions = ImageDescriptions.delete().where(
|
deleted_descriptions = (
|
||||||
(ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)
|
ImageDescriptions.delete()
|
||||||
).execute()
|
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
if deleted_images or deleted_descriptions:
|
if deleted_images or deleted_descriptions:
|
||||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
||||||
|
|
|
||||||
|
|
@ -220,7 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
chat_info_stream_id: str,
|
chat_info_stream_id: str,
|
||||||
chat_info_platform: str,
|
chat_info_platform: str,
|
||||||
action_reasoning:str
|
action_reasoning: str,
|
||||||
):
|
):
|
||||||
self.action_id = action_id
|
self.action_id = action_id
|
||||||
self.time = time
|
self.time = time
|
||||||
|
|
|
||||||
|
|
@ -317,10 +317,12 @@ class Expression(BaseModel):
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "expression"
|
table_name = "expression"
|
||||||
|
|
||||||
|
|
||||||
class Jargon(BaseModel):
|
class Jargon(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储俚语的模型
|
用于存储俚语的模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content = TextField()
|
content = TextField()
|
||||||
raw_content = TextField(null=True)
|
raw_content = TextField(null=True)
|
||||||
type = TextField(null=True)
|
type = TextField(null=True)
|
||||||
|
|
@ -336,10 +338,12 @@ class Jargon(BaseModel):
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "jargon"
|
table_name = "jargon"
|
||||||
|
|
||||||
|
|
||||||
class ChatHistory(BaseModel):
|
class ChatHistory(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储聊天历史概括的模型
|
用于存储聊天历史概括的模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
chat_id = TextField(index=True) # 聊天ID
|
chat_id = TextField(index=True) # 聊天ID
|
||||||
start_time = DoubleField() # 起始时间
|
start_time = DoubleField() # 起始时间
|
||||||
end_time = DoubleField() # 结束时间
|
end_time = DoubleField() # 结束时间
|
||||||
|
|
@ -359,6 +363,7 @@ class ThinkingBack(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储记忆检索思考过程的模型
|
用于存储记忆检索思考过程的模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
chat_id = TextField(index=True) # 聊天ID
|
chat_id = TextField(index=True) # 聊天ID
|
||||||
question = TextField() # 提出的问题
|
question = TextField() # 提出的问题
|
||||||
context = TextField(null=True) # 上下文信息
|
context = TextField(null=True) # 上下文信息
|
||||||
|
|
@ -371,6 +376,7 @@ class ThinkingBack(BaseModel):
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "thinking_back"
|
table_name = "thinking_back"
|
||||||
|
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
ChatStreams,
|
ChatStreams,
|
||||||
LLMUsage,
|
LLMUsage,
|
||||||
|
|
@ -387,6 +393,7 @@ MODELS = [
|
||||||
ThinkingBack,
|
ThinkingBack,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_tables():
|
def create_tables():
|
||||||
"""
|
"""
|
||||||
创建所有在模型中定义的数据库表。
|
创建所有在模型中定义的数据库表。
|
||||||
|
|
|
||||||
|
|
@ -311,6 +311,7 @@ class MessageReceiveConfig(ConfigBase):
|
||||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||||
"""过滤正则表达式列表"""
|
"""过滤正则表达式列表"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryConfig(ConfigBase):
|
class MemoryConfig(ConfigBase):
|
||||||
"""记忆配置类"""
|
"""记忆配置类"""
|
||||||
|
|
@ -321,6 +322,7 @@ class MemoryConfig(ConfigBase):
|
||||||
memory_build_frequency: int = 1
|
memory_build_frequency: int = 1
|
||||||
"""记忆构建频率"""
|
"""记忆构建频率"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExpressionConfig(ConfigBase):
|
class ExpressionConfig(ConfigBase):
|
||||||
"""表达配置类"""
|
"""表达配置类"""
|
||||||
|
|
@ -501,6 +503,7 @@ class MoodConfig(ConfigBase):
|
||||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||||
"""情感特征,影响情绪的变化情况"""
|
"""情感特征,影响情绪的变化情况"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceConfig(ConfigBase):
|
class VoiceConfig(ConfigBase):
|
||||||
"""语音识别配置类"""
|
"""语音识别配置类"""
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import difflib
|
||||||
import random
|
import random
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
|
|
||||||
def filter_message_content(content: Optional[str]) -> str:
|
def filter_message_content(content: Optional[str]) -> str:
|
||||||
|
|
@ -20,13 +19,13 @@ def filter_message_content(content: Optional[str]) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||||
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
|
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||||
# 移除@<...>格式的内容
|
# 移除@<...>格式的内容
|
||||||
content = re.sub(r'@<[^>]*>', '', content)
|
content = re.sub(r"@<[^>]*>", "", content)
|
||||||
# 移除[picid:...]格式的图片ID
|
# 移除[picid:...]格式的图片ID
|
||||||
content = re.sub(r'\[picid:[^\]]*\]', '', content)
|
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||||
# 移除[表情包:...]格式的内容
|
# 移除[表情包:...]格式的内容
|
||||||
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
|
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||||
|
|
||||||
return content.strip()
|
return content.strip()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
import traceback
|
import traceback
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
@ -158,8 +157,6 @@ class ExpressionLearner:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
学习并存储表达方式
|
学习并存储表达方式
|
||||||
|
|
@ -195,9 +192,7 @@ class ExpressionLearner:
|
||||||
) in learnt_expressions:
|
) in learnt_expressions:
|
||||||
# 查找是否已存在相似表达方式
|
# 查找是否已存在相似表达方式
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == self.chat_id)
|
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||||
& (Expression.situation == situation)
|
|
||||||
& (Expression.style == style)
|
|
||||||
)
|
)
|
||||||
if query.exists():
|
if query.exists():
|
||||||
# 表达方式完全相同,只更新时间戳
|
# 表达方式完全相同,只更新时间戳
|
||||||
|
|
@ -222,19 +217,17 @@ class ExpressionLearner:
|
||||||
learner.add_style(style, situation)
|
learner.add_style(style, situation)
|
||||||
|
|
||||||
# 学习映射关系
|
# 学习映射关系
|
||||||
success = style_learner_manager.learn_mapping(
|
success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
|
||||||
self.chat_id,
|
|
||||||
up_content,
|
|
||||||
style
|
|
||||||
)
|
|
||||||
if success:
|
if success:
|
||||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
logger.debug(
|
||||||
|
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
|
||||||
|
+ (f" (situation: {situation})" if situation else "")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||||
|
|
||||||
|
|
||||||
# 保存当前聊天室的 style_learner 模型
|
# 保存当前聊天室的 style_learner 模型
|
||||||
if has_new_expressions:
|
if has_new_expressions:
|
||||||
try:
|
try:
|
||||||
|
|
@ -367,9 +360,7 @@ class ExpressionLearner:
|
||||||
|
|
||||||
return matched_expressions
|
return matched_expressions
|
||||||
|
|
||||||
async def learn_expression(
|
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||||
self, num: int = 10
|
|
||||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
|
||||||
"""从指定聊天流学习表达方式
|
"""从指定聊天流学习表达方式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -409,7 +400,6 @@ class ExpressionLearner:
|
||||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||||
# logger.debug(f"学习{type_str}的response: {response}")
|
# logger.debug(f"学习{type_str}的response: {response}")
|
||||||
|
|
||||||
|
|
||||||
# 对表达方式溯源
|
# 对表达方式溯源
|
||||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||||
expressions, random_msg_match_str
|
expressions, random_msg_match_str
|
||||||
|
|
@ -449,7 +439,6 @@ class ExpressionLearner:
|
||||||
|
|
||||||
return filtered_with_up
|
return filtered_with_up
|
||||||
|
|
||||||
|
|
||||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
@ -115,7 +113,9 @@ class ExpressionSelector:
|
||||||
return group_chat_ids
|
return group_chat_ids
|
||||||
return [chat_id]
|
return [chat_id]
|
||||||
|
|
||||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
def get_model_predicted_expressions(
|
||||||
|
self, chat_id: str, target_message: str, total_num: int = 10
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
使用 style_learner 模型预测最合适的表达方式
|
使用 style_learner 模型预测最合适的表达方式
|
||||||
|
|
||||||
|
|
@ -136,7 +136,6 @@ class ExpressionSelector:
|
||||||
# 支持多chat_id合并预测
|
# 支持多chat_id合并预测
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
|
|
||||||
predicted_expressions = []
|
predicted_expressions = []
|
||||||
|
|
||||||
# 为每个相关的chat_id进行预测
|
# 为每个相关的chat_id进行预测
|
||||||
|
|
@ -155,25 +154,31 @@ class ExpressionSelector:
|
||||||
if style_id and situation:
|
if style_id and situation:
|
||||||
# 从数据库查找对应的表达记录
|
# 从数据库查找对应的表达记录
|
||||||
expr_query = Expression.select().where(
|
expr_query = Expression.select().where(
|
||||||
(Expression.chat_id == related_chat_id) &
|
(Expression.chat_id == related_chat_id)
|
||||||
(Expression.situation == situation) &
|
& (Expression.situation == situation)
|
||||||
(Expression.style == best_style)
|
& (Expression.style == best_style)
|
||||||
)
|
)
|
||||||
|
|
||||||
if expr_query.exists():
|
if expr_query.exists():
|
||||||
expr = expr_query.get()
|
expr = expr_query.get()
|
||||||
predicted_expressions.append({
|
predicted_expressions.append(
|
||||||
"id": expr.id,
|
{
|
||||||
"situation": expr.situation,
|
"id": expr.id,
|
||||||
"style": expr.style,
|
"situation": expr.situation,
|
||||||
"last_active_time": expr.last_active_time,
|
"style": expr.style,
|
||||||
"source_id": expr.chat_id,
|
"last_active_time": expr.last_active_time,
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
"source_id": expr.chat_id,
|
||||||
"prediction_score": scores.get(best_style, 0.0),
|
"create_date": expr.create_date
|
||||||
"prediction_input": filtered_target_message
|
if expr.create_date is not None
|
||||||
})
|
else expr.last_active_time,
|
||||||
|
"prediction_score": scores.get(best_style, 0.0),
|
||||||
|
"prediction_input": filtered_target_message,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
logger.warning(
|
||||||
|
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||||
|
|
@ -207,9 +212,7 @@ class ExpressionSelector:
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式
|
# 优化:一次性查询所有相关chat_id的表达方式
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
|
||||||
(Expression.chat_id.in_(related_chat_ids))
|
|
||||||
)
|
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
{
|
{
|
||||||
|
|
@ -236,7 +239,6 @@ class ExpressionSelector:
|
||||||
logger.error(f"随机选择表达方式失败: {e}")
|
logger.error(f"随机选择表达方式失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def select_suitable_expressions(
|
async def select_suitable_expressions(
|
||||||
self,
|
self,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
|
|
@ -425,17 +427,13 @@ class ExpressionSelector:
|
||||||
updates_by_key[key] = expr
|
updates_by_key[key] = expr
|
||||||
for chat_id, situation, style in updates_by_key:
|
for chat_id, situation, style in updates_by_key:
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == chat_id)
|
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||||
& (Expression.situation == situation)
|
|
||||||
& (Expression.style == style)
|
|
||||||
)
|
)
|
||||||
if query.exists():
|
if query.exists():
|
||||||
expr_obj = query.get()
|
expr_obj = query.get()
|
||||||
expr_obj.last_active_time = time.time()
|
expr_obj.last_active_time = time.time()
|
||||||
expr_obj.save()
|
expr_obj.save()
|
||||||
logger.debug(
|
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||||
"表达方式激活: 更新last_active_time in db"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,21 @@ import os
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .online_nb import OnlineNaiveBayes
|
from .online_nb import OnlineNaiveBayes
|
||||||
|
|
||||||
|
|
||||||
class ExpressorModel:
|
class ExpressorModel:
|
||||||
"""
|
"""
|
||||||
直接使用朴素贝叶斯精排(可在线学习)
|
直接使用朴素贝叶斯精排(可在线学习)
|
||||||
支持存储situation字段,不参与计算,仅与style对应
|
支持存储situation字段,不参与计算,仅与style对应
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
alpha: float = 0.5,
|
self,
|
||||||
beta: float = 0.5,
|
alpha: float = 0.5,
|
||||||
gamma: float = 1.0,
|
beta: float = 0.5,
|
||||||
vocab_size: int = 200000,
|
gamma: float = 1.0,
|
||||||
use_jieba: bool = True):
|
vocab_size: int = 200000,
|
||||||
|
use_jieba: bool = True,
|
||||||
|
):
|
||||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||||
|
|
@ -96,25 +99,27 @@ class ExpressorModel:
|
||||||
|
|
||||||
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||||
"""获取所有候选的style和situation信息"""
|
"""获取所有候选的style和situation信息"""
|
||||||
return {cid: (style, self._situations.get(cid))
|
return {cid: (style, self._situations.get(cid)) for cid, style in self._candidates.items()}
|
||||||
for cid, style in self._candidates.items()}
|
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
"""保存模型"""
|
"""保存模型"""
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
pickle.dump({
|
pickle.dump(
|
||||||
"candidates": self._candidates,
|
{
|
||||||
"situations": self._situations,
|
"candidates": self._candidates,
|
||||||
"nb": {
|
"situations": self._situations,
|
||||||
"cls_counts": dict(self.nb.cls_counts),
|
"nb": {
|
||||||
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
"cls_counts": dict(self.nb.cls_counts),
|
||||||
"alpha": self.nb.alpha,
|
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
||||||
"beta": self.nb.beta,
|
"alpha": self.nb.alpha,
|
||||||
"gamma": self.nb.gamma,
|
"beta": self.nb.beta,
|
||||||
"V": self.nb.V,
|
"gamma": self.nb.gamma,
|
||||||
}
|
"V": self.nb.V,
|
||||||
}, f)
|
},
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, path: str):
|
||||||
"""加载模型"""
|
"""加载模型"""
|
||||||
|
|
@ -133,8 +138,10 @@ class ExpressorModel:
|
||||||
self.nb.V = obj["nb"]["V"]
|
self.nb.V = obj["nb"]["V"]
|
||||||
self.nb._logZ.clear()
|
self.nb._logZ.clear()
|
||||||
|
|
||||||
|
|
||||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
outer = defaultdict(lambda: defaultdict(float))
|
outer = defaultdict(lambda: defaultdict(float))
|
||||||
for k, inner in d.items():
|
for k, inner in d.items():
|
||||||
outer[k].update(inner)
|
outer[k].update(inner)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import math
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
|
|
||||||
|
|
||||||
class OnlineNaiveBayes:
|
class OnlineNaiveBayes:
|
||||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
@ -9,9 +10,9 @@ class OnlineNaiveBayes:
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.V = vocab_size
|
self.V = vocab_size
|
||||||
|
|
||||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||||
|
|
||||||
def _invalidate(self, cid: str):
|
def _invalidate(self, cid: str):
|
||||||
if cid in self._logZ:
|
if cid in self._logZ:
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,20 @@ from typing import List, Optional, Set
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jieba
|
import jieba
|
||||||
|
|
||||||
_HAS_JIEBA = True
|
_HAS_JIEBA = True
|
||||||
except Exception:
|
except Exception:
|
||||||
_HAS_JIEBA = False
|
_HAS_JIEBA = False
|
||||||
|
|
||||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||||
# 匹配纯符号的正则表达式
|
# 匹配纯符号的正则表达式
|
||||||
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
|
_SYMBOL_RE = re.compile(r"^[^\w\u4e00-\u9fff]+$")
|
||||||
|
|
||||||
|
|
||||||
def simple_en_tokenize(text: str) -> List[str]:
|
def simple_en_tokenize(text: str) -> List[str]:
|
||||||
return _WORD_RE.findall(text.lower())
|
return _WORD_RE.findall(text.lower())
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
|
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
|
||||||
self.stopwords = stopwords or set()
|
self.stopwords = stopwords or set()
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class StyleLearner:
|
||||||
"beta": 0.5,
|
"beta": 0.5,
|
||||||
"gamma": 0.99, # 衰减因子,支持遗忘
|
"gamma": 0.99, # 衰减因子,支持遗忘
|
||||||
"vocab_size": 200000,
|
"vocab_size": 200000,
|
||||||
"use_jieba": True
|
"use_jieba": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 初始化表达模型
|
# 初始化表达模型
|
||||||
|
|
@ -47,7 +47,7 @@ class StyleLearner:
|
||||||
"total_samples": 0,
|
"total_samples": 0,
|
||||||
"style_counts": defaultdict(int),
|
"style_counts": defaultdict(int),
|
||||||
"last_update": None,
|
"last_update": None,
|
||||||
"style_usage_frequency": defaultdict(int) # 风格使用频率
|
"style_usage_frequency": defaultdict(int), # 风格使用频率
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_style(self, style: str, situation: str = None) -> bool:
|
def add_style(self, style: str, situation: str = None) -> bool:
|
||||||
|
|
@ -80,8 +80,10 @@ class StyleLearner:
|
||||||
# 添加到expressor模型
|
# 添加到expressor模型
|
||||||
self.expressor.add_candidate(style_id, style, situation)
|
self.expressor.add_candidate(style_id, style, situation)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
|
logger.info(
|
||||||
(f", situation: '{situation}'" if situation else ""))
|
f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})"
|
||||||
|
+ (f", situation: '{situation}'" if situation else "")
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -341,7 +343,7 @@ class StyleLearner:
|
||||||
"style_counts": dict(self.learning_stats["style_counts"]),
|
"style_counts": dict(self.learning_stats["style_counts"]),
|
||||||
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
|
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
|
||||||
"last_update": self.learning_stats["last_update"],
|
"last_update": self.learning_stats["last_update"],
|
||||||
"all_styles": list(self.style_to_id.keys())
|
"all_styles": list(self.style_to_id.keys()),
|
||||||
}
|
}
|
||||||
|
|
||||||
def save(self, base_path: str) -> bool:
|
def save(self, base_path: str) -> bool:
|
||||||
|
|
@ -362,7 +364,7 @@ class StyleLearner:
|
||||||
"id_to_style": self.id_to_style,
|
"id_to_style": self.id_to_style,
|
||||||
"id_to_situation": self.id_to_situation,
|
"id_to_situation": self.id_to_situation,
|
||||||
"next_style_id": self.next_style_id,
|
"next_style_id": self.next_style_id,
|
||||||
"learning_stats": self.learning_stats
|
"learning_stats": self.learning_stats,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 先保存expressor模型
|
# 先保存expressor模型
|
||||||
|
|
|
||||||
|
|
@ -3,5 +3,3 @@ from .jargon_miner import extract_and_store_jargon
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_and_store_jargon",
|
"extract_and_store_jargon",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -358,10 +358,7 @@ async def _default_stream_response_handler(
|
||||||
model_dbg = None
|
model_dbg = None
|
||||||
|
|
||||||
# 统一日志格式
|
# 统一日志格式
|
||||||
logger.info(
|
logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (model_dbg or ""))
|
||||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
|
||||||
% (model_dbg or "")
|
|
||||||
)
|
|
||||||
|
|
||||||
return resp, _usage_record
|
return resp, _usage_record
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -404,9 +401,7 @@ def _default_normal_response_parser(
|
||||||
raw_snippet = str(resp)[:300]
|
raw_snippet = str(resp)[:300]
|
||||||
except Exception:
|
except Exception:
|
||||||
raw_snippet = "<unserializable>"
|
raw_snippet = "<unserializable>"
|
||||||
logger.debug(
|
logger.debug(f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}")
|
||||||
f"empty choices: model={model_dbg} id={id_dbg} usage={usage_dbg} raw≈{raw_snippet}"
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# 日志采集失败不应影响控制流
|
# 日志采集失败不应影响控制流
|
||||||
pass
|
pass
|
||||||
|
|
@ -464,10 +459,7 @@ def _default_normal_response_parser(
|
||||||
# print(resp)
|
# print(resp)
|
||||||
_model_name = resp.model
|
_model_name = resp.model
|
||||||
# 统一日志格式
|
# 统一日志格式
|
||||||
logger.info(
|
logger.info("模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整" % (_model_name or ""))
|
||||||
"模型%s因为超过最大max_token限制,可能仅输出部分内容,可视情况调整"
|
|
||||||
% (_model_name or "")
|
|
||||||
)
|
|
||||||
return api_response, _usage_record
|
return api_response, _usage_record
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
||||||
|
|
|
||||||
|
|
@ -328,9 +328,7 @@ class LLMRequest:
|
||||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
|
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}")
|
||||||
f"模型 '{model_info.name}' 返回空回复(可重试)。剩余重试次数: {retry_remain}"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
|
|
||||||
except NetworkConnectionError as e:
|
except NetworkConnectionError as e:
|
||||||
|
|
@ -340,9 +338,7 @@ class LLMRequest:
|
||||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
|
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}")
|
||||||
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}。剩余重试次数: {retry_remain}"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
|
|
||||||
except RespNotOkException as e:
|
except RespNotOkException as e:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from maim_message import MessageServer
|
||||||
from src.common.remote import TelemetryHeartBeatTask
|
from src.common.remote import TelemetryHeartBeatTask
|
||||||
from src.manager.async_task_manager import async_task_manager
|
from src.manager.async_task_manager import async_task_manager
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
|
|
||||||
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
@ -73,6 +74,7 @@ class MainSystem:
|
||||||
|
|
||||||
# 添加记忆遗忘任务
|
# 添加记忆遗忘任务
|
||||||
from src.chat.utils.memory_forget_task import MemoryForgetTask
|
from src.chat.utils.memory_forget_task import MemoryForgetTask
|
||||||
|
|
||||||
await async_task_manager.add_task(MemoryForgetTask())
|
await async_task_manager.add_task(MemoryForgetTask())
|
||||||
|
|
||||||
# 启动API服务器
|
# 启动API服务器
|
||||||
|
|
@ -106,7 +108,6 @@ class MainSystem:
|
||||||
self.app.register_message_handler(chat_bot.message_process)
|
self.app.register_message_handler(chat_bot.message_process)
|
||||||
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
|
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
|
||||||
|
|
||||||
|
|
||||||
# 触发 ON_START 事件
|
# 触发 ON_START 事件
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,7 @@ from .tool_registry import register_memory_retrieval_tool
|
||||||
logger = get_logger("memory_retrieval_tools")
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
async def query_jargon(
|
async def query_jargon(keyword: str, chat_id: str) -> str:
|
||||||
keyword: str,
|
|
||||||
chat_id: str
|
|
||||||
) -> str:
|
|
||||||
"""根据关键词在jargon库中查询
|
"""根据关键词在jargon库中查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -28,25 +25,13 @@ async def query_jargon(
|
||||||
return "关键词为空"
|
return "关键词为空"
|
||||||
|
|
||||||
# 先尝试精确匹配
|
# 先尝试精确匹配
|
||||||
results = search_jargon(
|
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||||
keyword=content,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=False
|
|
||||||
)
|
|
||||||
|
|
||||||
is_fuzzy_match = False
|
is_fuzzy_match = False
|
||||||
|
|
||||||
# 如果精确匹配未找到,尝试模糊搜索
|
# 如果精确匹配未找到,尝试模糊搜索
|
||||||
if not results:
|
if not results:
|
||||||
results = search_jargon(
|
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||||
keyword=content,
|
|
||||||
chat_id=chat_id,
|
|
||||||
limit=10,
|
|
||||||
case_sensitive=False,
|
|
||||||
fuzzy=True
|
|
||||||
)
|
|
||||||
is_fuzzy_match = True
|
is_fuzzy_match = True
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
|
|
@ -86,14 +71,6 @@ def register_tool():
|
||||||
register_memory_retrieval_tool(
|
register_memory_retrieval_tool(
|
||||||
name="query_jargon",
|
name="query_jargon",
|
||||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
||||||
parameters=[
|
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
|
||||||
{
|
execute_func=query_jargon,
|
||||||
"name": "keyword",
|
|
||||||
"type": "string",
|
|
||||||
"description": "关键词(黑话/俚语/缩写)",
|
|
||||||
"required": True
|
|
||||||
}
|
|
||||||
],
|
|
||||||
execute_func=query_jargon
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,7 @@ class MemoryRetrievalTool:
|
||||||
"""记忆检索工具基类"""
|
"""记忆检索工具基类"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
parameters: List[Dict[str, Any]],
|
|
||||||
execute_func: Callable[..., Awaitable[str]]
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化工具
|
初始化工具
|
||||||
|
|
@ -145,10 +141,7 @@ _tool_registry = MemoryRetrievalToolRegistry()
|
||||||
|
|
||||||
|
|
||||||
def register_memory_retrieval_tool(
|
def register_memory_retrieval_tool(
|
||||||
name: str,
|
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||||
description: str,
|
|
||||||
parameters: List[Dict[str, Any]],
|
|
||||||
execute_func: Callable[..., Awaitable[str]]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""注册记忆检索工具的便捷函数
|
"""注册记忆检索工具的便捷函数
|
||||||
|
|
||||||
|
|
@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
|
||||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||||
"""获取工具注册器实例"""
|
"""获取工具注册器实例"""
|
||||||
return _tool_registry
|
return _tool_registry
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ logger = get_logger("frequency_api")
|
||||||
|
|
||||||
|
|
||||||
def get_current_talk_value(chat_id: str) -> float:
|
def get_current_talk_value(chat_id: str) -> float:
|
||||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
return frequency_control_manager.get_or_create_frequency_control(
|
||||||
|
chat_id
|
||||||
|
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
||||||
|
|
||||||
|
|
||||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,7 @@ def get_messages_by_time_in_chat(
|
||||||
limit=limit,
|
limit=limit,
|
||||||
limit_mode=limit_mode,
|
limit_mode=limit_mode,
|
||||||
filter_bot=filter_mai,
|
filter_bot=filter_mai,
|
||||||
filter_command=filter_command
|
filter_command=filter_command,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ class BaseAction(ABC):
|
||||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||||
|
|
||||||
"""NORMAL模式下的激活类型"""
|
"""NORMAL模式下的激活类型"""
|
||||||
self.activation_type = getattr(self.__class__, "activation_type")
|
self.activation_type = self.__class__.activation_type
|
||||||
"""激活类型"""
|
"""激活类型"""
|
||||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||||
"""当激活类型为RANDOM时的概率"""
|
"""当激活类型为RANDOM时的概率"""
|
||||||
|
|
@ -108,16 +108,11 @@ class BaseAction(ABC):
|
||||||
self.is_group = False
|
self.is_group = False
|
||||||
self.target_id = None
|
self.target_id = None
|
||||||
|
|
||||||
|
|
||||||
self.group_id = (
|
self.group_id = (
|
||||||
str(self.action_message.chat_info.group_info.group_id)
|
str(self.action_message.chat_info.group_info.group_id) if self.action_message.chat_info.group_info else None
|
||||||
if self.action_message.chat_info.group_info
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
self.group_name = (
|
self.group_name = (
|
||||||
self.action_message.chat_info.group_info.group_name
|
self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
|
||||||
if self.action_message.chat_info.group_info
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.user_id = str(self.action_message.user_info.user_id)
|
self.user_id = str(self.action_message.user_info.user_id)
|
||||||
|
|
@ -132,7 +127,6 @@ class BaseAction(ABC):
|
||||||
self.target_id = self.user_id
|
self.target_id = self.user_id
|
||||||
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
|
self.log_prefix = f"[{self.user_nickname} 的 私聊]"
|
||||||
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||||
)
|
)
|
||||||
|
|
@ -448,7 +442,6 @@ class BaseAction(ABC):
|
||||||
|
|
||||||
wait_start_time = asyncio.get_event_loop().time()
|
wait_start_time = asyncio.get_event_loop().time()
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
# 检查新消息
|
# 检查新消息
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
|
|
@ -497,7 +490,7 @@ class BaseAction(ABC):
|
||||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
# 获取focus_activation_type和normal_activation_type
|
# 获取focus_activation_type和normal_activation_type
|
||||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
_normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||||
|
|
||||||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
||||||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
||||||
|
|
|
||||||
|
|
@ -346,9 +346,7 @@ class EventsManager:
|
||||||
|
|
||||||
if not isinstance(result, tuple) or len(result) != 5:
|
if not isinstance(result, tuple) or len(result) != 5:
|
||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
annotated = ", ".join(
|
annotated = ", ".join(f"{name}={val!r}" for name, val in zip(expected_fields, result, strict=False))
|
||||||
f"{name}={val!r}" for name, val in zip(expected_fields, result)
|
|
||||||
)
|
|
||||||
actual_desc = f"{len(result)} 个元素 ({annotated})"
|
actual_desc = f"{len(result)} 个元素 ({annotated})"
|
||||||
else:
|
else:
|
||||||
actual_desc = f"非 tuple 类型: {type(result)}"
|
actual_desc = f"非 tuple 类型: {type(result)}"
|
||||||
|
|
@ -380,7 +378,6 @@ class EventsManager:
|
||||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
||||||
return True, None # 发生异常时默认不中断其他处理
|
return True, None # 发生异常时默认不中断其他处理
|
||||||
|
|
||||||
|
|
||||||
def _task_done_callback(
|
def _task_done_callback(
|
||||||
self,
|
self,
|
||||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
||||||
|
|
|
||||||
|
|
@ -189,9 +189,8 @@ class ToolExecutor:
|
||||||
tool_info["content"] = str(content)
|
tool_info["content"] = str(content)
|
||||||
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
|
# 空内容直接跳过(空字符串、全空白字符串、空列表/空元组)
|
||||||
content_check = tool_info["content"]
|
content_check = tool_info["content"]
|
||||||
if (
|
if (isinstance(content_check, str) and not content_check.strip()) or (
|
||||||
(isinstance(content_check, str) and not content_check.strip())
|
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
||||||
or (isinstance(content_check, (list, tuple)) and len(content_check) == 0)
|
|
||||||
):
|
):
|
||||||
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
|
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
11
view_pkl.py
11
view_pkl.py
|
|
@ -8,6 +8,7 @@ import sys
|
||||||
import os
|
import os
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
def view_pkl_file(file_path):
|
def view_pkl_file(file_path):
|
||||||
"""查看 pkl 文件内容"""
|
"""查看 pkl 文件内容"""
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
|
|
@ -15,7 +16,7 @@ def view_pkl_file(file_path):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
|
|
||||||
print(f"📁 文件: {file_path}")
|
print(f"📁 文件: {file_path}")
|
||||||
|
|
@ -44,10 +45,10 @@ def view_pkl_file(file_path):
|
||||||
pprint(data, width=120, depth=10)
|
pprint(data, width=120, depth=10)
|
||||||
|
|
||||||
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
|
# 如果是 expressor 模型,特别显示 token_counts 的详细信息
|
||||||
if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
|
if isinstance(data, dict) and "nb" in data and "token_counts" in data["nb"]:
|
||||||
print("\n" + "="*50)
|
print("\n" + "=" * 50)
|
||||||
print("🔍 详细词汇统计 (token_counts):")
|
print("🔍 详细词汇统计 (token_counts):")
|
||||||
token_counts = data['nb']['token_counts']
|
token_counts = data["nb"]["token_counts"]
|
||||||
for style_id, tokens in token_counts.items():
|
for style_id, tokens in token_counts.items():
|
||||||
print(f"\n📝 {style_id}:")
|
print(f"\n📝 {style_id}:")
|
||||||
if tokens:
|
if tokens:
|
||||||
|
|
@ -63,6 +64,7 @@ def view_pkl_file(file_path):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ 读取文件失败: {e}")
|
print(f"❌ 读取文件失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print("用法: python view_pkl.py <pkl文件路径>")
|
print("用法: python view_pkl.py <pkl文件路径>")
|
||||||
|
|
@ -72,5 +74,6 @@ def main():
|
||||||
file_path = sys.argv[1]
|
file_path = sys.argv[1]
|
||||||
view_pkl_file(file_path)
|
view_pkl_file(file_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import pickle
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def view_token_counts(file_path):
|
def view_token_counts(file_path):
|
||||||
"""查看 expressor.pkl 文件中的词汇统计"""
|
"""查看 expressor.pkl 文件中的词汇统计"""
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
|
|
@ -14,18 +15,18 @@ def view_token_counts(file_path):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
|
|
||||||
print(f"📁 文件: {file_path}")
|
print(f"📁 文件: {file_path}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
if 'nb' not in data or 'token_counts' not in data['nb']:
|
if "nb" not in data or "token_counts" not in data["nb"]:
|
||||||
print("❌ 这不是一个 expressor 模型文件")
|
print("❌ 这不是一个 expressor 模型文件")
|
||||||
return
|
return
|
||||||
|
|
||||||
token_counts = data['nb']['token_counts']
|
token_counts = data["nb"]["token_counts"]
|
||||||
candidates = data.get('candidates', {})
|
candidates = data.get("candidates", {})
|
||||||
|
|
||||||
print(f"🎯 找到 {len(token_counts)} 个风格")
|
print(f"🎯 找到 {len(token_counts)} 个风格")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
@ -41,7 +42,7 @@ def view_token_counts(file_path):
|
||||||
|
|
||||||
print("🔤 词汇统计 (按频率排序):")
|
print("🔤 词汇统计 (按频率排序):")
|
||||||
for i, (word, count) in enumerate(sorted_tokens):
|
for i, (word, count) in enumerate(sorted_tokens):
|
||||||
print(f" {i+1:2d}. '{word}': {count}")
|
print(f" {i + 1:2d}. '{word}': {count}")
|
||||||
else:
|
else:
|
||||||
print(" (无词汇数据)")
|
print(" (无词汇数据)")
|
||||||
|
|
||||||
|
|
@ -50,6 +51,7 @@ def view_token_counts(file_path):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ 读取文件失败: {e}")
|
print(f"❌ 读取文件失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print("用法: python view_tokens.py <expressor.pkl文件路径>")
|
print("用法: python view_tokens.py <expressor.pkl文件路径>")
|
||||||
|
|
@ -59,5 +61,6 @@ def main():
|
||||||
file_path = sys.argv[1]
|
file_path = sys.argv[1]
|
||||||
view_token_counts(file_path)
|
view_token_counts(file_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue