Merge branch 'MaiM-with-u:dev' into dev

pull/1290/head
exynos 2025-10-12 19:08:21 +08:00 committed by GitHub
commit d06f4005a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
121 changed files with 8269 additions and 12267 deletions

1
.envrc
View File

@ -1 +0,0 @@
use flake

3
.gitattributes vendored
View File

@ -1,3 +1,2 @@
*.bat text eol=crlf
*.cmd text eol=crlf
MaiLauncher.bat text eol=crlf working-tree-encoding=GBK
*.cmd text eol=crlf

12
.gitignore vendored
View File

@ -20,6 +20,8 @@ MaiBot-Napcat-Adapter
nonebot-maibot-adapter/
MaiMBot-LPMM
*.zip
run_bot.bat
run_na.bat
run.bat
log_debug/
run_amds.bat
@ -41,16 +43,13 @@ config/bot_config.toml
config/bot_config.toml.bak
config/lpmm_config.toml
config/lpmm_config.toml.bak
src/mais4u/config/s4u_config.toml
src/mais4u/config/old
template/compare/bot_config_template.toml
template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py
CLAUDE.md
s4u.s4u
s4u.s4u1
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@ -321,9 +320,14 @@ run_pet.bat
/plugins/*
!/plugins
!/plugins/hello_world_plugin
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin
!/plugins/deep_think
!/plugins/ChatFrequency/
!/plugins/__init__.py
config.toml
interested_rates.txt
MaiBot.code-workspace
*.lock

21
bot.py
View File

@ -5,16 +5,29 @@ import sys
import time
import platform
import traceback
import shutil
from dotenv import load_dotenv
from pathlib import Path
from rich.traceback import install
if os.path.exists(".env"):
load_dotenv(".env", override=True)
env_path = Path(__file__).parent / ".env"
template_env_path = Path(__file__).parent / "template" / "template.env"
if env_path.exists():
load_dotenv(str(env_path), override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
try:
if template_env_path.exists():
shutil.copyfile(template_env_path, env_path)
print("未找到.env已从 template/template.env 自动创建")
load_dotenv(str(env_path), override=True)
else:
print("未找到.env文件也未找到模板 template/template.env")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
except Exception as e:
print(f"自动创建 .env 失败: {e}")
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging

View File

@ -1,14 +1,27 @@
# Changelog
0.10.4饼 表达方式优化
无了
## [0.11.0] - 2025-9-22
### 🌟 主要功能更改
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
- 麦麦好奇功能,麦麦会自主提出问题
- 添加deepthink插件默认关闭让麦麦可以深度思考一些问题
- 添加表情包管理插件
## [0.10.3] - 2025-9-1x
### 细节功能更改
- 修复配置文件转义问题
- 情绪系统现在可以由配置文件控制开关
- 修复平行动作控制失效的问题
- 添加planner防抖防止短时间快速消耗token
- 修复吞字问题
- 更新依赖表
- 修复负载均衡
- 优化了对gemini和不同模型的支持
## [0.10.3] - 2025-9-22
### 🌟 主要功能更改
- planner支持多动作移除Sub_planner
- 移除激活度系统现在回复完全由planner控制
- 现可自定义planner行为
- 更丰富的聊天行为
- 现可自定义planner行为更优化的聊天频率控制
- 支持发送转发和合并转发
- 关系现在支持多人的信息
- 更好的event系统正式建立
@ -20,6 +33,8 @@
- 优化识图token限制
- 为空回复添加重试机制
- 加入brainchat模式为私聊支持做准备
- 修复qq号格式
## [0.10.2] - 2025-8-31

View File

@ -1,51 +0,0 @@
# Changelog
## [1.0.3] - 2025-3-31
### Added
- 新增了心流相关配置项:
- `heartflow` 配置项,用于控制心流功能
### Removed
- 移除了 `response` 配置项中的 `model_r1_probability``model_v3_probability` 选项
- 移除了次级推理模型相关配置
## [1.0.1] - 2025-3-30
### Added
- 增加了流式输出控制项 `stream`
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
## [1.0.0] - 2025-3-30
### Added
- 修复了错误的版本命名
- 杀掉了所有无关文件
## [0.0.11] - 2025-3-12
### Added
- 新增了 `schedule` 配置项,用于配置日程表生成功能
- 新增了 `response_splitter` 配置项,用于控制回复分割
- 新增了 `experimental` 配置项,用于实验性功能开关
- 新增了 `llm_observation``llm_sub_heartflow` 模型配置
- 新增了 `llm_heartflow` 模型配置
- 在 `personality` 配置项中新增了 `prompt_schedule_gen` 参数
### Changed
- 优化了模型配置的组织结构
- 调整了部分配置项的默认值
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
- 在 `message` 配置项中:
- 新增了 `model_max_output_length` 参数
- 在 `willing` 配置项中新增了 `emoji_response_penalty` 参数
- 将 `personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
### Removed
- 移除了 `min_text_length` 配置项
- 移除了 `cq_code` 配置项
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
## [0.0.5] - 2025-3-11
### Added
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
## [0.0.4] - 2025-3-9
### Added
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。

View File

@ -28,7 +28,7 @@ version = "1.1.1"
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒)
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式,现在支持不良好 | `openai` |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式 | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.cn/v1"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
@ -73,7 +73,7 @@ client_type = "openai"
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
而对于`gemini`需要单独进行配置
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0 # 禁用思考
# thinking_budget = -1 由模型自己决定
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |

View File

@ -1,57 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 0,
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
"type": "path"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs",
"utils": "utils"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

View File

@ -1,39 +0,0 @@
{
description = "MaiMBot Nix Dev Env";
inputs = {
utils.url = "github:numtide/flake-utils";
};
outputs = {
self,
nixpkgs,
utils,
...
}:
utils.lib.eachDefaultSystem (system: let
pkgs = import nixpkgs {inherit system;};
pythonPackages = pkgs.python3Packages;
in {
devShells.default = pkgs.mkShell {
name = "python-venv";
venvDir = "./.venv";
buildInputs = with pythonPackages; [
python
venvShellHook
scipy
numpy
];
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
pip install -r requirements.txt
'';
postShellHook = ''
# allow pip to install wheels
unset SOURCE_DATE_EPOCH
'';
};
});
}

View File

@ -0,0 +1,75 @@
# BetterFrequency 频率控制插件
这是一个用于控制MaiBot聊天频率的插件支持实时调整talk_frequency参数。
## 功能特性
- 💬 **Talk Frequency控制**: 调整机器人的发言频率
- 📊 **状态显示**: 实时查看当前频率控制状态
- ⚡ **实时生效**: 设置后立即生效,无需重启
- 💾 **不保存消息**: 命令执行反馈不会保存到数据库
- 🚀 **简化命令**: 支持完整命令和简化命令两种形式
## 命令列表
### 1. 设置Talk Frequency
```
/chat talk_frequency <数字> # 完整命令
/chat t <数字> # 简化命令
```
- 功能设置当前聊天的talk_frequency调整值
- 参数支持0到1之间的数值
- 示例:
- `/chat talk_frequency 1.0``/chat t 1.0` - 设置发言频率调整为1.0(最高频率)
- `/chat talk_frequency 0.5``/chat t 0.5` - 设置发言频率调整为0.5
- `/chat talk_frequency 0.0``/chat t 0.0` - 设置发言频率调整为0.0(最低频率)
### 2. 显示当前状态
```
/chat show # 完整命令
/chat s # 简化命令
```
- 功能:显示当前聊天的频率控制状态
- 显示内容:
- 当前talk_frequency值
- 可用命令提示(包含简化命令)
## 配置说明
插件配置文件 `config.toml` 包含以下选项:
```toml
[plugin]
name = "better_frequency_plugin"
version = "1.0.0"
enabled = true
[frequency]
default_talk_adjust = 1.0 # 默认talk_frequency调整值
max_adjust_value = 1.0 # 最大调整值
min_adjust_value = 0.0 # 最小调整值
```
## 使用场景
- **提高机器人活跃度**: 设置较高的talk_frequency值接近1.0
- **降低机器人活跃度**: 设置较低的talk_frequency值接近0.0
- **精细调节**: 使用小数进行微调
- **实时监控**: 通过show命令查看当前状态
- **快速操作**: 使用简化命令提高操作效率
## 注意事项
1. 调整值会立即生效,影响当前聊天的机器人行为
2. 命令执行反馈消息不会保存到数据库
3. 支持0到1之间的数值
4. 每个聊天都有独立的频率控制设置
5. 简化命令和完整命令功能完全相同,可根据个人习惯选择
## 技术实现
- 基于MaiCore插件系统开发
- 使用frequency_api进行频率控制操作
- 使用send_api发送反馈消息
- 支持异步操作和错误处理
- 正则表达式支持多种命令格式

View File

@ -0,0 +1,56 @@
{
"manifest_version": 1,
"name": "发言频率控制插件|BetterFrequency Plugin",
"version": "2.0.0",
"description": "控制聊天频率支持设置focus_value和talk_frequency调整值提供完整命令和简化命令",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.3"
},
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
"keywords": ["frequency", "control", "talk_frequency", "plugin", "shortcut"],
"categories": ["Chat", "Frequency", "Control"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "frequency",
"components": [
{
"type": "command",
"name": "set_talk_frequency",
"description": "设置当前聊天的talk_frequency调整值",
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
},
{
"type": "action",
"name": "frequency_adjust",
"description": "调整当前聊天的发言频率",
"pattern": "/chat frequency_adjust <数字> 或 /chat f <数字>"
},
{
"type": "command",
"name": "show_frequency",
"description": "显示当前聊天的频率控制状态",
"pattern": "/chat show 或 /chat s"
}
],
"features": [
"设置talk_frequency调整值",
"调整当前聊天的发言频率",
"显示当前频率控制状态",
"实时频率控制调整",
"命令执行反馈(不保存消息)",
"支持完整命令和简化命令",
"快速操作支持"
]
}
}

View File

@ -0,0 +1,121 @@
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块
from src.plugin_system.apis import frequency_api, send_api, config_api, generator_api
logger = get_logger("frequency_adjust")
class FrequencyAdjustAction(BaseAction):
    """Frequency adjustment action - scales the talk frequency of the current chat.

    Triggered by LLM judgement. ``direction == "increase"`` multiplies the
    current frequency by 1.2 (capped at 1.0); ``direction == "decrease"``
    multiplies it by 0.8 (floored at 0.0). Feedback is rewritten through the
    generator API so the bot phrases it naturally.
    """

    activation_type = ActionActivationType.LLM_JUDGE
    parallel_action = False

    # Basic action info
    action_name = "frequency_adjust"
    action_description = "调整当前聊天的发言频率"

    # Action parameters
    action_parameters = {
        "direction": "调整方向:'increase'(增加)或'decrease'(降低)",
    }

    # Usage scenarios presented to the planner
    bot_name = config_api.get_global_config("bot.nickname")
    action_require = [
        f"当用户提到 {bot_name} 太安静或太活跃时使用",
        f"有人提到 {bot_name} 的发言太多或太少",
        f"需要根据聊天氛围调整 {bot_name} 的活跃度",
    ]

    # Associated message types
    associated_types = ["text"]

    # Scaling ratios and bounds (ratio-based rather than absolute steps)
    _INCREASE_RATIO = 1.2
    _DECREASE_RATIO = 0.8
    _MAX_FREQUENCY = 1.0
    _MIN_FREQUENCY = 0.0

    async def execute(self) -> Tuple[bool, str]:
        """Execute the frequency adjustment.

        Returns:
            Tuple of (success flag, human-readable result message).
        """
        try:
            # 1. Read the action parameter.
            direction = self.action_data.get("direction")
            if not direction:
                # BUG FIX: the old message also mentioned a removed "multiply" parameter.
                error_msg = "缺少必要的参数direction"
                logger.error(f"{self.log_prefix} {error_msg}")
                return False, error_msg

            # 2. Current frequency for this chat.
            current_frequency = frequency_api.get_current_talk_frequency(self.chat_id)

            # 3. Compute the new frequency using a ratio instead of an absolute step.
            if direction == "increase":
                calculated_frequency = current_frequency * self._INCREASE_RATIO
                if calculated_frequency > self._MAX_FREQUENCY:
                    new_frequency = self._MAX_FREQUENCY
                    logger.warning(
                        f"{self.log_prefix} 尝试调整频率超出最大值: "
                        f"current={current_frequency:.2f}, calculated={calculated_frequency:.2f}"
                    )
                    # BUG FIX: the old code told the user the frequency was capped
                    # at the maximum but never wrote the value back; apply it first.
                    frequency_api.set_talk_frequency_adjust(self.chat_id, new_frequency)
                    await self.store_action_info(
                        action_build_into_prompt=True,
                        action_prompt_display=f"你尝试调整发言频率到{calculated_frequency:.2f}但最大值只能为1.0,已设置为最大值",
                        action_done=True,
                    )
                    return True, f"调整发言频率超出限制: {current_frequency:.2f}{new_frequency:.2f}"
                new_frequency = calculated_frequency
                action_desc = "增加"
            elif direction == "decrease":
                # Floor at the minimum allowed frequency.
                new_frequency = max(self._MIN_FREQUENCY, current_frequency * self._DECREASE_RATIO)
                action_desc = "降低"
            else:
                error_msg = f"无效的调整方向: {direction}"
                logger.error(f"{self.log_prefix} {error_msg}")
                return False, error_msg

            # 4. Persist the new frequency.
            frequency_api.set_talk_frequency_adjust(self.chat_id, new_frequency)

            # 5. Send feedback, rewritten by the generator so it sounds natural.
            feedback_msg = f"{action_desc}发言频率:{current_frequency:.2f}{new_frequency:.2f}"
            result_status, data = await generator_api.rewrite_reply(
                chat_stream=self.chat_stream,
                reply_data={
                    "raw_reply": feedback_msg,
                    "reason": "表达自己已经调整了发言频率,不一定要说具体数值,可以有趣一些",
                },
            )
            if result_status:
                for reply_seg in data.reply_set.reply_data:
                    send_data = reply_seg.content
                    await self.send_text(send_data)
                    logger.info(f"{self.log_prefix} {send_data}")

            # 6. Record the action (the over-limit branch already returned above).
            await self.store_action_info(
                action_build_into_prompt=True,
                action_prompt_display=f"{action_desc}了发言频率,从{current_frequency:.2f}调整到{new_frequency:.2f}",
                action_done=True,
            )
            return True, f"成功调整发言频率: {current_frequency:.2f}{new_frequency:.2f}"
        except Exception as e:
            error_msg = f"频率调节失败: {str(e)}"
            logger.error(f"{self.log_prefix} {error_msg}", exc_info=True)
            await self.send_text("频率调节失败")
            return False, error_msg

View File

@ -0,0 +1,137 @@
from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField
)
from src.plugin_system.apis import send_api, frequency_api
from .frequency_adjust_action import FrequencyAdjustAction
class SetTalkFrequencyCommand(BaseCommand):
    """Command handler for ``/chat talk_frequency <n>`` (or ``/chat t <n>``).

    Writes the given value as the talk_frequency adjustment of the current
    chat stream and sends a confirmation that is not stored in the database.
    """

    command_name = "set_talk_frequency"
    command_description = "设置当前聊天的talk_frequency值/chat talk_frequency <数字> 或 /chat t <数字>"
    command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        """Parse the numeric argument, apply it, and acknowledge.

        Returns:
            Tuple of (success, error message or None, intercept flag).
        """
        try:
            # The pattern captures the number in the named group "value".
            groups = self.matched_groups
            if not groups or "value" not in groups:
                return False, "命令格式错误", False

            raw_value = groups["value"]
            if not raw_value:
                return False, "无法获取数值参数", False

            adjust_value = float(raw_value)

            # Resolve the chat stream id the adjustment applies to.
            stream = self.message.chat_stream
            if not stream or not hasattr(stream, "stream_id"):
                return False, "无法获取聊天流信息", False
            stream_id = stream.stream_id

            frequency_api.set_talk_frequency_adjust(stream_id, adjust_value)

            # Feedback is intentionally not persisted to the database.
            await send_api.text_to_stream(
                f"已设置当前聊天的talk_frequency调整值为: {adjust_value}",
                stream_id,
                storage_message=False,
            )
            return True, None, False
        except ValueError:
            message = "数值格式错误,请输入有效的数字"
            await self.send_text(message, storage_message=False)
            return False, message, False
        except Exception as e:
            message = f"设置talk_frequency失败: {str(e)}"
            await self.send_text(message, storage_message=False)
            return False, message, False
class ShowFrequencyCommand(BaseCommand):
    """Command handler for ``/chat show`` (or ``/chat s``): reports frequency state."""

    command_name = "show_frequency"
    command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
    command_pattern = r"^/chat\s+(?:show|s)$"

    async def execute(self) -> Tuple[bool, Optional[str], bool]:
        """Send the current talk_frequency status of this chat (not stored in DB).

        Returns:
            Tuple of (success, error message or None, intercept flag).
        """
        try:
            # Resolve the chat stream id.
            if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
                return False, "无法获取聊天流信息", False
            chat_id = self.message.chat_stream.stream_id

            # Current effective talk frequency.
            # BUG FIX: removed a dead call to get_talk_frequency_adjust()
            # whose result was never used.
            current_talk_frequency = frequency_api.get_current_talk_frequency(chat_id)

            # Build the status report shown to the user.
            status_msg = f"""当前聊天频率控制状态
Talk Frequency (发言频率):
当前值: {current_talk_frequency:.2f}
使用命令:
/chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
/chat show 或 /chat s - 显示当前状态"""

            # Send without persisting to the database.
            await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
            return True, None, False
        except Exception as e:
            error_msg = f"获取频率控制状态失败: {str(e)}"
            # Fall back to the command's built-in send_text for the error.
            await self.send_text(error_msg, storage_message=False)
            return False, error_msg, False
# ===== Plugin registration =====
@register_plugin
class BetterFrequencyPlugin(BasePlugin):
    """BetterFrequency plugin - chat frequency control.

    Registers two commands (set/show talk_frequency) and one LLM-triggered
    action that scales the frequency up or down.
    """

    # Basic plugin info
    plugin_name: str = "better_frequency_plugin"
    enable_plugin: bool = True
    dependencies: List[str] = []
    python_dependencies: List[str] = []
    config_file_name: str = "config.toml"

    # Descriptions of the config sections
    config_section_descriptions = {
        "plugin": "插件基本信息",
        "frequency": "频率控制配置",
    }

    # Config schema definition
    config_schema: dict = {
        "plugin": {
            "name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
            "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
            "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
        },
        "frequency": {
            "default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
            "max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
            "min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
        },
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return (component info, component class) pairs for registration."""
        command_classes = (SetTalkFrequencyCommand, ShowFrequencyCommand)
        components = [(cls.get_command_info(), cls) for cls in command_classes]
        components.append((FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction))
        return components

View File

@ -0,0 +1,34 @@
{
"manifest_version": 1,
"name": "Deep Think插件 (Deep Think Actions)",
"version": "1.0.0",
"description": "可以深度思考",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.11.0"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",
"keywords": ["deep", "think", "action", "built-in"],
"categories": ["Deep Think"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": true,
"plugin_type": "action_provider",
"components": [
{
"type": "action",
"name": "deep_think",
"description": "发送深度思考"
}
]
}
}

View File

@ -0,0 +1,102 @@
from typing import List, Tuple, Type, Any
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
from src.person_info.person_info import Person
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.plugins.built_in.relation.relation import BuildRelationAction
from src.plugin_system.apis import llm_api
logger = get_logger("relation_actions")
class DeepThinkTool(BaseTool):
    """Deep-think tool: runs an extra LLM pass to reason about a question."""

    name = "deep_think"
    description = "深度思考,对某个知识,概念或逻辑问题进行全面且深入的思考,当面临复杂环境或重要问题时,使用此获得更好的解决方案。"
    parameters = [
        ("question", ToolParamType.STRING, "需要思考的问题,越具体越好(从上下文中总结)", True, None),
    ]
    available_for_llm = True

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        """Think deeply about the given question using the "replyer" model.

        BUG FIX: the old docstring ("执行比较两个数的大小") was copy-pasted from an
        example tool and described comparing two numbers.

        Args:
            function_args: Tool arguments; expects a "question" string.

        Returns:
            dict: {"content": <thinking result injected back into the prompt>}
        """
        question: str = function_args.get("question")  # type: ignore
        # BUG FIX: replaced a leftover debug print() with proper logging.
        logger.debug(f"question: {question}")
        prompt = f"""
请你思考以下问题以简洁的一段话回答
{question}
"""
        models = llm_api.get_available_models()
        chat_model_config = models.get("replyer")  # dict-style access
        success, thinking_result, _, _ = await llm_api.generate_with_model(
            prompt, model_config=chat_model_config, request_type="deep_think"
        )
        logger.info(f"{question}: {thinking_result}")
        thinking_result = (
            f"思考结果:{thinking_result}\n"
            "**注意** 因为你进行了深度思考,最后的回复内容可以回复的长一些,更加详细一些,不用太简洁。\n"
        )
        return {"content": thinking_result}
@register_plugin
class DeepThinkPlugin(BasePlugin):
    """Deep-think plugin.

    Built-in plugin exposing the ``deep_think`` tool so the LLM can reason
    about a complex question before answering.

    BUG FIX: the old docstring was copy-pasted from the relation plugin and
    described Reply/NoReply/Emoji actions this plugin does not provide.

    Note: basic plugin metadata is preferentially read from _manifest.json.
    """

    # Basic plugin info
    plugin_name: str = "deep_think"  # internal identifier
    enable_plugin: bool = True
    dependencies: list[str] = []  # plugin dependency list
    python_dependencies: list[str] = []  # python package dependency list
    config_file_name: str = "config.toml"

    # Descriptions of the config sections
    config_section_descriptions = {
        "plugin": "插件启用配置",
        "components": "核心组件启用配置",
    }

    # Config schema definition
    config_schema: dict = {
        "plugin": {
            "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
            "config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
        }
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return the components provided by this plugin."""
        return [(DeepThinkTool.get_tool_info(), DeepThinkTool)]

View File

@ -0,0 +1,53 @@
{
"manifest_version": 1,
"name": "BetterEmoji",
"version": "1.0.0",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.4"
},
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
"keywords": ["emoji", "manage", "plugin"],
"categories": ["Examples", "Tutorial"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "emoji_manage",
"components": [
{
"type": "action",
"name": "hello_greeting",
"description": "向用户发送问候消息"
},
{
"type": "action",
"name": "bye_greeting",
"description": "向用户发送告别消息",
"activation_modes": ["keyword"],
"keywords": ["再见", "bye", "88", "拜拜"]
},
{
"type": "command",
"name": "time",
"description": "查询当前时间",
"pattern": "/time"
}
],
"features": [
"问候和告别功能",
"时间查询命令",
"配置文件示例",
"新手教程代码"
]
}
}

View File

@ -0,0 +1,399 @@
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseCommand,
ComponentInfo,
ConfigField,
ReplyContentType,
emoji_api,
)
from maim_message import Seg
from src.common.logger import get_logger
logger = get_logger("emoji_manage_plugin")
class AddEmojiCommand(BaseCommand):
    """Command handler for ``/emoji add``: registers emojis/images found in the message.

    Every emoji or image segment of the triggering message is registered via
    ``emoji_api``; the summary is then rewritten by the generator API so the
    bot replies in its own voice.
    """

    command_name = "add_emoji"
    command_description = "添加表情包"
    # Loose pattern: matches anywhere so the command text can accompany image segments.
    command_pattern = r".*/emoji add.*"

    async def execute(self) -> Tuple[bool, str, bool]:
        """Register all emojis found in the message.

        Returns:
            Tuple of (at least one success, result message, intercept flag).
        """
        # Collect emoji/image payloads from the incoming message.
        # logger.info(f"查找消息中的表情包: {self.message.message_segment}")
        emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
        if not emoji_base64_list:
            return False, "未在消息中找到表情包或图片", False
        # Register each found emoji, tallying successes and failures.
        success_count = 0
        fail_count = 0
        results = []
        for i, emoji_base64 in enumerate(emoji_base64_list):
            try:
                # emoji_api generates a unique file name automatically.
                result = await emoji_api.register_emoji(emoji_base64)
                if result["success"]:
                    success_count += 1
                    description = result.get("description", "未知描述")
                    emotions = result.get("emotions", [])
                    replaced = result.get("replaced", False)
                    result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
                    if description:
                        result_msg += f"\n描述: {description}"
                    if emotions:
                        result_msg += f"\n情感标签: {', '.join(emotions)}"
                    results.append(result_msg)
                else:
                    fail_count += 1
                    error_msg = result.get("message", "注册失败")
                    results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
            except Exception as e:
                fail_count += 1
                results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
        # Build the summary message.
        total_count = success_count + fail_count
        summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
        # Append per-emoji details when available.
        details_msg = ""
        if results:
            details_msg = "\n" + "\n".join(results)
            final_msg = summary_msg + details_msg
        else:
            final_msg = summary_msg
        # Rewrite the reply through the generator ("expressor") API.
        try:
            from src.plugin_system.apis import generator_api

            # Data for the rewrite request.
            rewrite_data = {
                "raw_reply": summary_msg,
                "reason": f"注册了表情包:{details_msg}\n",
            }
            # Invoke the rewriter.
            result_status, data = await generator_api.rewrite_reply(
                chat_stream=self.message.chat_stream,
                reply_data=rewrite_data,
            )
            if result_status:
                # Send the rewritten reply segments.
                for reply_seg in data.reply_set.reply_data:
                    send_data = reply_seg.content
                    await self.send_text(send_data)
                return success_count > 0, final_msg, success_count > 0
            else:
                # Rewrite failed: fall back to the raw summary.
                await self.send_text(final_msg)
                return success_count > 0, final_msg, success_count > 0
        except Exception as e:
            # Rewriter call crashed: fall back to the raw summary.
            logger.error(f"[add_emoji] 表达器重写失败: {e}")
            await self.send_text(final_msg)
            return success_count > 0, final_msg, success_count > 0

    def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
        """Recursively collect payloads of emoji/image segments.

        Args:
            message_segments: A single ``Seg`` or an iterable of ``Seg`` objects.

        Returns:
            List of payload strings (assumes image seg.data is base64 — TODO confirm).
        """
        emoji_base64_list = []
        # Case 1: a single Seg object.
        if isinstance(message_segments, Seg):
            if message_segments.type == "emoji":
                emoji_base64_list.append(message_segments.data)
            elif message_segments.type == "image":
                # Assumes image data is base64-encoded.
                emoji_base64_list.append(message_segments.data)
            elif message_segments.type == "seglist":
                # Recurse into the nested Seg list.
                emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
            return emoji_base64_list
        # Case 2: an iterable of Seg objects.
        for seg in message_segments:
            if seg.type == "emoji":
                emoji_base64_list.append(seg.data)
            elif seg.type == "image":
                # Assumes image data is base64-encoded.
                emoji_base64_list.append(seg.data)
            elif seg.type == "seglist":
                # Recurse into the nested Seg list.
                emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
        return emoji_base64_list
class ListEmojiCommand(BaseCommand):
    """Command handler for ``/emoji list [count]``: shows emoji statistics and entries."""

    command_name = "emoji_list"
    command_description = "列表表情包"

    # === Command settings (required) ===
    command_pattern = r"^/emoji list(\s+\d+)?$"  # matches "/emoji list" or "/emoji list <count>"

    async def execute(self) -> Tuple[bool, str, bool]:
        """Build and send the emoji listing.

        Returns:
            Tuple of (success, message text, intercept flag).
        """
        from src.plugin_system.apis import emoji_api
        import datetime
        import re

        # Parse the optional count argument: default 10, capped at 50.
        count_match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
        limit = 10
        if count_match and count_match.group(1):
            limit = min(int(count_match.group(1)), 50)

        # Timestamp for the header; format is configurable.
        time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S")  # type: ignore
        time_str = datetime.datetime.now().strftime(time_format)

        # Overall emoji statistics.
        emoji_count = emoji_api.get_count()
        emoji_info = emoji_api.get_info()

        lines = [
            f"📊 表情包统计信息 ({time_str})",
            f"• 总数: {emoji_count} / {emoji_info['max_count']}",
            f"• 可用: {emoji_info['available_emojis']}",
        ]

        # No emojis at all: report and stop.
        if emoji_count == 0:
            lines.append("\n❌ 暂无表情包")
            text = "\n".join(lines)
            await self.send_text(text)
            return True, text, True

        # Fetch the full list; bail out if unavailable.
        all_emojis = await emoji_api.get_all()
        if not all_emojis:
            lines.append("\n❌ 无法获取表情包列表")
            text = "\n".join(lines)
            await self.send_text(text)
            return False, text, True

        # Show at most `limit` entries, truncating long descriptions.
        shown = all_emojis[:limit]
        lines.append(f"\n📋 显示前 {len(shown)} 个表情包:")
        for idx, (_, description, emotion) in enumerate(shown, 1):
            brief = description[:50] + "..." if len(description) > 50 else description
            lines.append(f"{idx}. {brief} [{emotion}]")

        # Mention how many entries were omitted, if any.
        if len(all_emojis) > limit:
            lines.append(f"\n💡 还有 {len(all_emojis) - limit} 个表情包未显示")

        text = "\n".join(lines)
        await self.send_text(text)
        return True, text, True
class DeleteEmojiCommand(BaseCommand):
    """Command handler for ``/emoji delete``: removes emojis matching attached images.

    Each attached emoji/image is hashed (MD5 of the decoded bytes) and the
    matching registered emoji, if any, is deleted via ``emoji_api``.
    """

    command_name = "delete_emoji"
    command_description = "删除表情包"
    # Loose pattern: matches anywhere so the command text can accompany image segments.
    command_pattern = r".*/emoji delete.*"

    async def execute(self) -> Tuple[bool, str, bool]:
        """Delete all emojis matching the images found in the message.

        Returns:
            Tuple of (at least one success, result message, intercept flag).
        """
        # Collect emoji/image payloads from the incoming message.
        logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
        emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
        if not emoji_base64_list:
            return False, "未在消息中找到表情包或图片", False
        # Delete each found emoji, tallying successes and failures.
        success_count = 0
        fail_count = 0
        results = []
        for i, emoji_base64 in enumerate(emoji_base64_list):
            try:
                # The image hash identifies the registered emoji to delete.
                import base64
                import hashlib

                # Ensure the base64 string contains only ASCII characters.
                if isinstance(emoji_base64, str):
                    emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
                else:
                    emoji_base64_clean = str(emoji_base64)
                # Compute the MD5 hash of the decoded image bytes.
                image_bytes = base64.b64decode(emoji_base64_clean)
                emoji_hash = hashlib.md5(image_bytes).hexdigest()
                # Delete through emoji_api.
                result = await emoji_api.delete_emoji(emoji_hash)
                if result["success"]:
                    success_count += 1
                    description = result.get("description", "未知描述")
                    count_before = result.get("count_before", 0)
                    count_after = result.get("count_after", 0)
                    emotions = result.get("emotions", [])
                    result_msg = f"表情包 {i + 1} 删除成功"
                    if description:
                        result_msg += f"\n描述: {description}"
                    if emotions:
                        result_msg += f"\n情感标签: {', '.join(emotions)}"
                    result_msg += f"\n表情包数量: {count_before}{count_after}"
                    results.append(result_msg)
                else:
                    fail_count += 1
                    error_msg = result.get("message", "删除失败")
                    results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
            except Exception as e:
                fail_count += 1
                results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
        # Build the summary message.
        total_count = success_count + fail_count
        summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
        # Append per-emoji details when available.
        details_msg = ""
        if results:
            details_msg = "\n" + "\n".join(results)
            final_msg = summary_msg + details_msg
        else:
            final_msg = summary_msg
        # Rewrite the reply through the generator ("expressor") API.
        try:
            from src.plugin_system.apis import generator_api

            # Data for the rewrite request.
            rewrite_data = {
                "raw_reply": summary_msg,
                "reason": f"删除了表情包:{details_msg}\n",
            }
            # Invoke the rewriter.
            result_status, data = await generator_api.rewrite_reply(
                chat_stream=self.message.chat_stream,
                reply_data=rewrite_data,
            )
            if result_status:
                # Send the rewritten reply segments.
                for reply_seg in data.reply_set.reply_data:
                    send_data = reply_seg.content
                    await self.send_text(send_data)
                return success_count > 0, final_msg, success_count > 0
            else:
                # Rewrite failed: fall back to the raw summary.
                await self.send_text(final_msg)
                return success_count > 0, final_msg, success_count > 0
        except Exception as e:
            # Rewriter call crashed: fall back to the raw summary.
            logger.error(f"[delete_emoji] 表达器重写失败: {e}")
            await self.send_text(final_msg)
            return success_count > 0, final_msg, success_count > 0

    def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
        """Recursively collect payloads of emoji/image segments.

        Args:
            message_segments: A single ``Seg`` or an iterable of ``Seg`` objects.

        Returns:
            List of payload strings (assumes image seg.data is base64 — TODO confirm).
        """
        emoji_base64_list = []
        # Case 1: a single Seg object.
        if isinstance(message_segments, Seg):
            if message_segments.type == "emoji":
                emoji_base64_list.append(message_segments.data)
            elif message_segments.type == "image":
                # Assumes image data is base64-encoded.
                emoji_base64_list.append(message_segments.data)
            elif message_segments.type == "seglist":
                # Recurse into the nested Seg list.
                emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
            return emoji_base64_list
        # Case 2: an iterable of Seg objects.
        for seg in message_segments:
            if seg.type == "emoji":
                emoji_base64_list.append(seg.data)
            elif seg.type == "image":
                # Assumes image data is base64-encoded.
                emoji_base64_list.append(seg.data)
            elif seg.type == "seglist":
                # Recurse into the nested Seg list.
                emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
        return emoji_base64_list
class RandomEmojis(BaseCommand):
    """Command handler for ``/random_emojis``: forwards 5 random emojis as one merged message."""

    command_name = "random_emojis"
    command_description = "发送多张随机表情包"
    command_pattern = r"^/random_emojis$"

    async def execute(self):
        """Pick random emojis and forward them.

        Returns:
            Tuple of (success, result message, intercept flag).
        """
        emojis = await emoji_api.get_random(5)
        if not emojis:
            return False, "未找到表情包", False
        # Each entry is a tuple whose first element is the emoji's base64 payload;
        # a comprehension replaces the old manual append loop.
        emoji_base64_list = [emoji[0] for emoji in emojis]
        return await self.forward_images(emoji_base64_list)

    async def forward_images(self, images: List[str]):
        """Send multiple images to the user as a single merged-forward message."""
        success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
        return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
# ===== Plugin registration =====
@register_plugin
class EmojiManagePlugin(BasePlugin):
    """Emoji management plugin - add, list, delete and randomly forward emojis."""

    # Basic plugin info
    plugin_name: str = "emoji_manage_plugin"  # internal identifier
    enable_plugin: bool = False
    dependencies: List[str] = []  # plugin dependency list
    python_dependencies: List[str] = []  # python package dependency list
    config_file_name: str = "config.toml"  # config file name

    # Descriptions of the config sections
    config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}

    # Config schema definition
    config_schema: dict = {
        "plugin": {
            "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
            "config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
        },
    }

    def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
        """Return the command components provided by this plugin."""
        command_classes = (RandomEmojis, AddEmojiCommand, ListEmojiCommand, DeleteEmojiCommand)
        return [(cls.get_command_info(), cls) for cls in command_classes]

View File

@ -237,8 +237,7 @@ class HelloWorldPlugin(BasePlugin):
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {

View File

@ -1,56 +1,37 @@
[project]
name = "MaiBot"
version = "0.8.1"
version = "0.11.0"
description = "MaiCore 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.12.14",
"apscheduler>=3.11.0",
"aiohttp-cors>=0.8.1",
"colorama>=0.4.6",
"cryptography>=45.0.5",
"customtkinter>=5.2.2",
"dotenv>=0.9.9",
"faiss-cpu>=1.11.0",
"fastapi>=0.116.0",
"google-genai>=1.39.1",
"jieba>=0.42.1",
"json-repair>=0.47.6",
"jsonlines>=4.0.0",
"maim-message>=0.3.8",
"maim-message",
"matplotlib>=3.10.3",
"networkx>=3.4.2",
"numpy>=2.2.6",
"openai>=1.95.0",
"packaging>=25.0",
"pandas>=2.3.1",
"peewee>=3.18.2",
"pillow>=11.3.0",
"psutil>=7.0.0",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
"pymongo>=4.13.2",
"pypinyin>=0.54.0",
"python-dateutil>=2.9.0.post0",
"python-dotenv>=1.1.1",
"python-igraph>=0.11.9",
"quick-algo>=0.1.3",
"reportportal-client>=5.6.5",
"requests>=2.32.4",
"rich>=14.0.0",
"ruff>=0.12.2",
"scikit-learn>=1.7.0",
"scipy>=1.15.3",
"seaborn>=0.13.2",
"setuptools>=80.9.0",
"strawberry-graphql[fastapi]>=0.275.5",
"structlog>=25.4.0",
"toml>=0.10.2",
"tomli>=2.2.1",
"tomli-w>=1.2.0",
"tomlkit>=0.13.3",
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"websockets>=15.0.1",
]

View File

@ -1,271 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.txt -o requirements.lock
aenum==3.1.16
# via reportportal-client
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.14
# via
# -r requirements.txt
# maim-message
# reportportal-client
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.9.0
# via
# httpx
# openai
# starlette
apscheduler==3.11.0
# via -r requirements.txt
attrs==25.3.0
# via
# aiohttp
# jsonlines
certifi==2025.7.9
# via
# httpcore
# httpx
# reportportal-client
# requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.2
# via requests
click==8.2.1
# via uvicorn
colorama==0.4.6
# via
# -r requirements.txt
# click
# tqdm
contourpy==1.3.2
# via matplotlib
cryptography==45.0.5
# via
# -r requirements.txt
# maim-message
customtkinter==5.2.2
# via -r requirements.txt
cycler==0.12.1
# via matplotlib
darkdetect==0.8.0
# via customtkinter
distro==1.9.0
# via openai
dnspython==2.7.0
# via pymongo
dotenv==0.9.9
# via -r requirements.txt
faiss-cpu==1.11.0
# via -r requirements.txt
fastapi==0.116.0
# via
# -r requirements.txt
# maim-message
# strawberry-graphql
fonttools==4.58.5
# via matplotlib
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
graphql-core==3.2.6
# via strawberry-graphql
h11==0.16.0
# via
# httpcore
# uvicorn
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via openai
idna==3.10
# via
# anyio
# httpx
# requests
# yarl
igraph==0.11.9
# via python-igraph
jieba==0.42.1
# via -r requirements.txt
jiter==0.10.0
# via openai
joblib==1.5.1
# via scikit-learn
json-repair==0.47.6
# via -r requirements.txt
jsonlines==4.0.0
# via -r requirements.txt
kiwisolver==1.4.8
# via matplotlib
maim-message==0.3.8
# via -r requirements.txt
markdown-it-py==3.0.0
# via rich
matplotlib==3.10.3
# via
# -r requirements.txt
# seaborn
mdurl==0.1.2
# via markdown-it-py
multidict==6.6.3
# via
# aiohttp
# yarl
networkx==3.5
# via -r requirements.txt
numpy==2.3.1
# via
# -r requirements.txt
# contourpy
# faiss-cpu
# matplotlib
# pandas
# scikit-learn
# scipy
# seaborn
openai==1.95.0
# via -r requirements.txt
packaging==25.0
# via
# -r requirements.txt
# customtkinter
# faiss-cpu
# matplotlib
# strawberry-graphql
pandas==2.3.1
# via
# -r requirements.txt
# seaborn
peewee==3.18.2
# via -r requirements.txt
pillow==11.3.0
# via
# -r requirements.txt
# matplotlib
propcache==0.3.2
# via
# aiohttp
# yarl
psutil==7.0.0
# via -r requirements.txt
pyarrow==20.0.0
# via -r requirements.txt
pycparser==2.22
# via cffi
pydantic==2.11.7
# via
# -r requirements.txt
# fastapi
# maim-message
# openai
pydantic-core==2.33.2
# via pydantic
pygments==2.19.2
# via rich
pymongo==4.13.2
# via -r requirements.txt
pyparsing==3.2.3
# via matplotlib
pypinyin==0.54.0
# via -r requirements.txt
python-dateutil==2.9.0.post0
# via
# -r requirements.txt
# matplotlib
# pandas
# strawberry-graphql
python-dotenv==1.1.1
# via
# -r requirements.txt
# dotenv
python-igraph==0.11.9
# via -r requirements.txt
python-multipart==0.0.20
# via strawberry-graphql
pytz==2025.2
# via pandas
quick-algo==0.1.3
# via -r requirements.txt
reportportal-client==5.6.5
# via -r requirements.txt
requests==2.32.4
# via
# -r requirements.txt
# reportportal-client
rich==14.0.0
# via -r requirements.txt
ruff==0.12.2
# via -r requirements.txt
scikit-learn==1.7.0
# via -r requirements.txt
scipy==1.16.0
# via
# -r requirements.txt
# scikit-learn
seaborn==0.13.2
# via -r requirements.txt
setuptools==80.9.0
# via -r requirements.txt
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# openai
starlette==0.46.2
# via fastapi
strawberry-graphql==0.275.5
# via -r requirements.txt
structlog==25.4.0
# via -r requirements.txt
texttable==1.7.0
# via igraph
threadpoolctl==3.6.0
# via scikit-learn
toml==0.10.2
# via -r requirements.txt
tomli==2.2.1
# via -r requirements.txt
tomli-w==1.2.0
# via -r requirements.txt
tomlkit==0.13.3
# via -r requirements.txt
tqdm==4.67.1
# via
# -r requirements.txt
# openai
typing-extensions==4.14.1
# via
# fastapi
# openai
# pydantic
# pydantic-core
# strawberry-graphql
# typing-inspection
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via
# pandas
# tzlocal
tzlocal==5.3.1
# via apscheduler
urllib3==2.5.0
# via
# -r requirements.txt
# requests
uvicorn==0.35.0
# via
# -r requirements.txt
# maim-message
websockets==15.0.1
# via
# -r requirements.txt
# maim-message
yarl==1.20.1
# via aiohttp

View File

@ -1,49 +1,28 @@
APScheduler
Pillow
aiohttp
aiohttp-cors
colorama
customtkinter
dotenv
faiss-cpu
fastapi
jieba
jsonlines
maim_message
quick_algo
matplotlib
networkx
numpy
openai
pandas
peewee
pyarrow
pydantic
pypinyin
python-dateutil
python-dotenv
python-igraph
pymongo
requests
ruff
scipy
setuptools
toml
tomli
tomli_w
tomlkit
tqdm
urllib3
uvicorn
websockets
strawberry-graphql[fastapi]
packaging
rich
psutil
cryptography
json-repair
reportportal-client
scikit-learn
seaborn
structlog
google.genai
aiohttp>=3.12.14
aiohttp-cors>=0.8.1
colorama>=0.4.6
faiss-cpu>=1.11.0
fastapi>=0.116.0
google-genai>=1.39.1
jieba>=0.42.1
json-repair>=0.47.6
maim-message
matplotlib>=3.10.3
numpy>=2.2.6
openai>=1.95.0
pandas>=2.3.1
peewee>=3.18.2
pillow>=11.3.0
pyarrow>=20.0.0
pydantic>=2.11.7
pypinyin>=0.54.0
python-dotenv>=1.1.1
quick-algo>=0.1.3
rich>=14.0.0
ruff>=0.12.2
setuptools>=80.9.0
structlog>=25.4.0
toml>=0.10.2
tomlkit>=0.13.3
urllib3>=2.5.0
uvicorn>=0.35.0

View File

@ -0,0 +1,389 @@
import argparse
import json
import random
import re
import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple
# 确保可从任意工作目录运行:将项目根目录加入 sys.pathscripts 的上一级)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages
SECONDS_5_MINUTES = 5 * 60
def clean_output_text(text: str) -> str:
    """Strip emoji-pack markers and reply prefixes from an output line.

    Removes ``[表情包:...]`` fragments and the full ``[回复...],说:...@...:``
    reply pattern, then collapses whitespace runs into single spaces.
    """
    if not text:
        return text
    without_emoji = re.sub(r'\[表情包:[^\]]*\]', '', text)
    without_reply = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', without_emoji)
    return re.sub(r'\s+', ' ', without_reply).strip()
def parse_datetime_to_timestamp(value: str) -> float:
    """Parse a date/time string in one of several common formats.

    Supported examples:
        - 2025-09-29
        - 2025-09-29 00:00:00
        - 2025/09/29 00:00
        - 2025-09-29T00:00:00

    Returns:
        POSIX timestamp (float), interpreted in the local timezone.

    Raises:
        ValueError: if the string matches none of the supported formats.
    """
    value = value.strip()
    fmts = [
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    ]
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            return datetime.strptime(value, fmt).timestamp()
        # strptime signals a non-matching format with ValueError; catching
        # only that (instead of the original bare Exception) avoids hiding
        # unrelated bugs.
        except ValueError as e:
            last_err = e
    raise ValueError(f"无法解析时间: {value} ({last_err})")
def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Fetch messages with start_ts < time < end_ts, ascending by time.

    When *platform* is truthy, results are restricted to that
    chat_info_platform.
    """
    query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
    if platform:
        query["chat_info_platform"] = platform
    # With limit=0 the sort clause takes effect; sort ascending by time.
    return find_messages(message_filter=query, sort=[("time", 1)], limit=0)
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by chat_id; each bucket is sorted by ascending time."""
    groups: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        groups.setdefault(message.chat_id, []).append(message)
    # Missing timestamps sort as 0 so they never displace real ones.
    for bucket in groups.values():
        bucket.sort(key=lambda m: m.time or 0)
    return groups
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of adjacent same-user messages into one message.

    The processed_plain_text fields are joined with newlines; every other
    field is copied from the newest (last, largest-time) message.

    Raises:
        ValueError: if *bucket* is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")
    latest = bucket[-1]
    # Collect the non-empty text of every message, oldest first.
    merged_texts: List[str] = []
    for m in bucket:
        text = m.processed_plain_text or ""
        if text:
            merged_texts.append(text)
    merged = DatabaseMessages(
        # All non-text fields come from the newest message.
        message_id=latest.message_id,
        time=latest.time,
        chat_id=latest.chat_id,
        reply_to=latest.reply_to,
        interest_value=latest.interest_value,
        key_words=latest.key_words,
        key_words_lite=latest.key_words_lite,
        is_mentioned=latest.is_mentioned,
        is_at=latest.is_at,
        reply_probability_boost=latest.reply_probability_boost,
        # Joined text; falls back to the newest text when every piece was empty.
        processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
        display_message=latest.display_message,
        priority_mode=latest.priority_mode,
        priority_info=latest.priority_info,
        additional_config=latest.additional_config,
        is_emoji=latest.is_emoji,
        is_picid=latest.is_picid,
        is_command=latest.is_command,
        is_notify=latest.is_notify,
        selected_expressions=latest.selected_expressions,
        # Flattened user / group / chat info, also from the newest message.
        user_id=latest.user_info.user_id,
        user_nickname=latest.user_info.user_nickname,
        user_cardname=latest.user_info.user_cardname,
        user_platform=latest.user_info.platform,
        chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
        chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
        chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
        chat_info_user_id=latest.chat_info.user_info.user_id,
        chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
        chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
        chat_info_user_platform=latest.chat_info.user_info.platform,
        chat_info_stream_id=latest.chat_info.stream_id,
        chat_info_platform=latest.chat_info.platform,
        chat_info_create_time=latest.chat_info.create_time,
        chat_info_last_active_time=latest.chat_info.last_active_time,
    )
    return merged
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of adjacent messages from the same user.

    Consecutive messages by one user_id are folded into a single message as
    long as each follows the previous within 5 minutes. Input must already
    be sorted by ascending time.
    """
    merged: List[DatabaseMessages] = []
    current: List[DatabaseMessages] = []
    for message in messages:
        if current:
            previous = current[-1]
            within_window = (message.time or 0) - (previous.time or 0) <= SECONDS_5_MINUTES
            if message.user_info.user_id == previous.user_info.user_id and within_window:
                current.append(message)
                continue
            # Run broken: flush the accumulated bucket before starting anew.
            merged.append(_merge_bucket_to_message(current))
        current = [message]
    if current:
        merged.append(_merge_bucket_to_message(current))
    return merged
def build_pairs_for_chat(
    original_messages: List[DatabaseMessages],
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
    target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
    """Build (input_str, output_str, message_id) pairs for one chat.

    Each merged message becomes an output; its input is a window of
    min_ctx..max_ctx ORIGINAL (unmerged) messages taken from just before it.
    If target_user_id is given, only that user's merged messages become
    outputs.
    """
    pairs: List[Tuple[str, str, str]] = []
    n_merged = len(merged_messages)
    n_original = len(original_messages)
    if n_merged == 0 or n_original == 0:
        return pairs
    # Map each merged message to the index of the original message with the
    # same timestamp. A merged message carries the time of the LAST message
    # of its bucket (see _merge_bucket_to_message), so the context window
    # below ends right before that last original message. Both lists are
    # sorted ascending by time, so a single forward scan suffices.
    merged_to_original_map = {}
    original_idx = 0
    for merged_idx, merged_msg in enumerate(merged_messages):
        # Advance to the first original message not earlier than merged_msg.
        while (original_idx < n_original and
               original_messages[original_idx].time < merged_msg.time):
            original_idx += 1
        # Only exact timestamp matches produce a mapping; unmapped merged
        # messages are silently skipped in the pairing loop below.
        if (original_idx < n_original and
            original_messages[original_idx].time == merged_msg.time):
            merged_to_original_map[merged_idx] = original_idx
    for merged_idx in range(n_merged):
        merged_msg = merged_messages[merged_idx]
        # Optional filter: only the target user's messages become outputs.
        if target_user_id and merged_msg.user_info.user_id != target_user_id:
            continue
        if merged_idx not in merged_to_original_map:
            continue
        original_idx = merged_to_original_map[merged_idx]
        # Randomly sized context window of original messages before the output.
        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
        start = max(0, original_idx - window)
        context_msgs = original_messages[start:original_idx]
        # Input: rendered context built from the raw (unmerged) messages.
        input_str = build_readable_messages(
            messages=context_msgs,
            timestamp_mode="normal_no_YMD",
            show_actions=False,
            show_pic=True,
        )
        # Output: merged text, stripped of emoji-pack and reply markup.
        output_text = merged_msg.processed_plain_text or ""
        output_text = clean_output_text(output_text)
        output_id = merged_msg.message_id or ""
        pairs.append((input_str, output_text, output_id))
    return pairs
def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    user_id: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Build (input, output, message_id) pairs over the whole time window.

    Messages are fetched WITHOUT a user filter so the input context can
    include everyone's messages; *user_id* only restricts which merged
    messages become outputs.
    """
    grouped = group_by_chat(fetch_messages_between(start_ts, end_ts, platform))
    all_pairs: List[Tuple[str, str, str]] = []
    for chat_messages in grouped.values():
        # Outputs come from merged messages; inputs from the raw ones.
        merged = merge_adjacent_same_user(chat_messages)
        all_pairs.extend(build_pairs_for_chat(chat_messages, merged, min_ctx, max_ctx, user_id))
    return all_pairs
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point; returns a process exit code.

    With no command-line arguments at all, falls back to the interactive
    prompt flow (run_interactive).
    """
    # Default to the real CLI args; empty means interactive mode.
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) == 0:
        return run_interactive()
    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表支持按用户ID筛选消息")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl每行 {input, output}若不指定则打印到stdout",
    )
    args = parser.parse_args(argv)
    start_ts = parse_datetime_to_timestamp(args.start)
    end_ts = parse_datetime_to_timestamp(args.end)
    # Validate the requested window before doing any database work.
    if end_ts <= start_ts:
        raise ValueError("结束时间必须大于起始时间")
    if args.max_ctx < args.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")
    pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
    if args.output:
        # Persist as JSONL: one {input, output, message_id} object per line.
        with open(args.output, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {args.output}")
    else:
        # Print to stdout, one JSON object per line.
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
    return 0
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Prompt on stdin, showing *default* in brackets; empty input keeps it."""
    hint = f"[{default}]" if default not in (None, "") else ""
    raw = input(f"{prompt_text}{' ' + hint if hint else ''}: ").strip()
    return default if (raw == "" and default is not None) else raw
def run_interactive() -> int:
    """Interactive counterpart of main(); returns a process exit code.

    Prompts for the time range, optional platform/user filters, context
    window sizes and an optional output path, then runs build_pairs.
    """
    print("进入交互模式直接回车采用默认值。时间格式例如2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    user_id = _prompt_with_default("用户ID可留空表示不限", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except Exception:
        # Non-numeric input falls back to the defaults instead of aborting.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl可留空打印到控制台", "")
    # Validate the inputs; exit code 2 signals bad user input.
    if not start_str or not end_str:
        print("必须提供起始与结束时间。")
        return 2
    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except Exception as e:  # noqa: BLE001
        print(f"时间解析失败:{e}")
        return 2
    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2
    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2
    # Empty strings mean "no filter".
    platform_val = platform if platform != "" else None
    user_id_val = user_id if user_id != "" else None
    pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
    if output_path:
        # Persist as JSONL: one {input, output, message_id} object per line.
        with open(output_path, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
    print(f"总计 {len(pairs)} 条。")
    return 0
if __name__ == "__main__":
    # Script entry: exit code is main()'s return value.
    sys.exit(main())

View File

@ -0,0 +1,334 @@
import time
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
from typing import List, Tuple
import numpy as np
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable name for *chat_id* via the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id})"
        if stream.group_name:
            return f"{stream.group_name} ({chat_id})"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊 ({chat_id})"
        return f"未知聊天 ({chat_id})"
    except Exception:
        # Any database failure degrades to a placeholder name.
        return f"查询失败 ({chat_id})"
def get_expression_data() -> List[Tuple[float, float, str, str]]:
    """Collect (create_date, count, chat_id, type) rows from Expression.

    Rows whose create_date is NULL are skipped.
    """
    return [
        (expr.create_date, expr.count, expr.chat_id, expr.type)
        for expr in Expression.select()
        if expr.create_date is not None
    ]
def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: "str | None" = None):
    """Draw a scatter plot of expression usage count over creation time.

    Args:
        data: (create_date, count, chat_id, type) tuples; create_date is a
            POSIX timestamp.
        save_path: optional path to save the figure as a PNG.
    """
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Only the first two columns are plotted here (the original also
    # extracted chat_id/type into unused locals; those are dropped).
    create_dates = [item[0] for item in data]
    counts = [item[1] for item in data]
    dates = [datetime.fromtimestamp(ts) for ts in create_dates]

    # Pick tick formatting/locators based on the overall time span.
    time_span = max(dates) - min(dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = '%Y-%m-%d %H:%M'
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    _, ax = plt.subplots(figsize=(12, 8))

    # Color encodes insertion order so the time progression stays visible.
    scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')

    ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
    ax.set_ylabel('使用次数 (Count)', fontsize=12)
    ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')

    # Apply the span-dependent date axis formatting.
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    ax.grid(True, alpha=0.3)

    cbar = plt.colorbar(scatter)
    cbar.set_label('数据点顺序', fontsize=10)

    plt.tight_layout()

    # Console summary of the plotted data.
    print("\n=== 数据统计 ===")
    print(f"总数据点数量: {len(data)}")
    print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
    print(f"平均使用次数: {np.mean(counts):.2f}")
    print(f"中位数使用次数: {np.median(counts):.2f}")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n散点图已保存到: {save_path}")

    plt.show()
def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: "str | None" = None):
    """Scatter plot of usage count over time, one color per chat.

    Args:
        data: (create_date, count, chat_id, type) tuples.
        save_path: optional PNG output path.
    """
    if not data:
        print("没有找到有效的表达式数据")
        return
    # Group the points by chat_id.
    chat_groups = {}
    for item in data:
        chat_id = item[2]
        if chat_id not in chat_groups:
            chat_groups[chat_id] = []
        chat_groups[chat_id].append(item)
    # Pick tick formatting/locators based on the overall time span.
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = '%Y-%m-%d %H:%M'
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)
    fig, ax = plt.subplots(figsize=(14, 10))
    # One distinct color per chat.
    colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
    for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
        create_dates = [item[0] for item in chat_data]
        counts = [item[1] for item in chat_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]
        chat_name = get_chat_name(chat_id)
        # Truncate over-long chat names so the legend stays readable.
        display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
        ax.scatter(dates, counts, alpha=0.7, s=40,
                   c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
                   edgecolors='black', linewidth=0.5)
    ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
    ax.set_ylabel('使用次数 (Count)', fontsize=12)
    ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
    # Apply the span-dependent date axis formatting.
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)
    # Legend outside the axes to the right.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # Per-chat console statistics.
    print(f"\n=== 分组统计 ===")
    print(f"总聊天数量: {len(chat_groups)}")
    for chat_id, chat_data in chat_groups.items():
        chat_name = get_chat_name(chat_id)
        counts = [item[1] for item in chat_data]
        print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n分组散点图已保存到: {save_path}")
    plt.show()
def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: "str | None" = None):
    """Scatter plot of usage count over time, one color per expression type.

    Args:
        data: (create_date, count, chat_id, type) tuples.
        save_path: optional PNG output path.
    """
    if not data:
        print("没有找到有效的表达式数据")
        return
    # Group the points by expression type.
    type_groups = {}
    for item in data:
        expr_type = item[3]
        if expr_type not in type_groups:
            type_groups[expr_type] = []
        type_groups[expr_type].append(item)
    # Pick tick formatting/locators based on the overall time span.
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = '%Y-%m-%d'
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = '%Y-%m-%d %H:%M'
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)
    fig, ax = plt.subplots(figsize=(12, 8))
    # One distinct color per type.
    colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
    for i, (expr_type, type_data) in enumerate(type_groups.items()):
        create_dates = [item[0] for item in type_data]
        counts = [item[1] for item in type_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]
        ax.scatter(dates, counts, alpha=0.7, s=40,
                   c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
                   edgecolors='black', linewidth=0.5)
    ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
    ax.set_ylabel('使用次数 (Count)', fontsize=12)
    ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
    # Apply the span-dependent date axis formatting.
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)
    # Legend outside the axes to the right.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # Per-type console statistics.
    print(f"\n=== 类型统计 ===")
    for expr_type, type_data in type_groups.items():
        counts = [item[1] for item in type_data]
        print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\n类型散点图已保存到: {save_path}")
    plt.show()
def main():
    """Entry point: load expression data and render all three scatter plots."""
    print("开始分析表达式数据...")
    data = get_expression_data()
    if not data:
        print("没有找到有效的表达式数据create_date不为空的数据")
        return
    print(f"找到 {len(data)} 条有效数据")
    # Figures are written into data/temp with timestamped filenames.
    output_dir = os.path.join(project_root, "data", "temp")
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # 1. Basic scatter plot.
    print("\n1. 创建基础散点图...")
    create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
    # 2. Per-chat scatter plot.
    print("\n2. 创建按聊天分组的散点图...")
    create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
    # 3. Per-type scatter plot.
    print("\n3. 创建按类型分组的散点图...")
    create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
    print("\n分析完成!")
if __name__ == "__main__":
    # Script entry: run the full plotting pipeline.
    main()

View File

@ -1,285 +0,0 @@
import time
import sys
import os
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable name for *chat_id* via the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id})"
        if stream.group_name:
            return f"{stream.group_name} ({chat_id})"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊 ({chat_id})"
        return f"未知聊天 ({chat_id})"
    except Exception:
        # Any database failure degrades to a placeholder name.
        return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
    """Render a POSIX timestamp as 'YYYY-MM-DD HH:MM:SS' in local time.

    Returns the literal "未知时间" when the timestamp cannot be converted.
    """
    try:
        moment = datetime.fromtimestamp(timestamp)
        return moment.strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError):
        return "未知时间"
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
    """Count messages per interest_value bucket.

    Messages whose interest_value is None or exactly 0.0 are skipped.
    """
    # (exclusive upper bound, bucket label), in ascending order.
    buckets = [
        (0.010, "0.000-0.010"),
        (0.050, "0.010-0.050"),
        (0.100, "0.050-0.100"),
        (0.500, "0.100-0.500"),
        (1.000, "0.500-1.000"),
        (2.000, "1.000-2.000"),
        (5.000, "2.000-5.000"),
        (10.000, "5.000-10.000"),
    ]
    distribution = {label: 0 for _, label in buckets}
    distribution["10.000+"] = 0
    for msg in messages:
        if msg.interest_value is None or msg.interest_value == 0.0:
            continue
        value = float(msg.interest_value)
        for upper, label in buckets:
            if value < upper:
                distribution[label] += 1
                break
        else:
            # Larger than every bucket's upper bound.
            distribution["10.000+"] += 1
    return distribution
def get_interest_value_stats(messages) -> Dict[str, float]:
    """Return count/min/max/avg/median over non-null, non-zero interest values."""
    values = sorted(
        float(m.interest_value)
        for m in messages
        if m.interest_value is not None and m.interest_value != 0.0
    )
    if not values:
        return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
    n = len(values)
    if n % 2 == 1:
        median = values[n // 2]
    else:
        median = (values[n // 2 - 1] + values[n // 2]) / 2
    return {
        "count": n,
        "min": values[0],
        "max": values[-1],
        "avg": sum(values) / n,
        "median": median,
    }
def get_available_chats() -> List[Tuple[str, str, int]]:
    """Return (chat_id, chat_name, count) for chats with usable data.

    "Usable" means messages whose interest_value is non-null and non-zero.
    Result is sorted by count, descending; empty list on any DB failure.
    """
    try:
        # Count usable messages per distinct chat_id.
        # NOTE(review): this issues one COUNT query per chat (N+1 pattern);
        # fine for small databases, consider a GROUP BY aggregate otherwise.
        chat_counts = {}
        for msg in Messages.select(Messages.chat_id).distinct():
            chat_id = msg.chat_id
            count = (
                Messages.select()
                .where(
                    (Messages.chat_id == chat_id)
                    & (Messages.interest_value.is_null(False))
                    & (Messages.interest_value != 0.0)
                )
                .count()
            )
            if count > 0:
                chat_counts[chat_id] = count
        # Attach a display name to each chat.
        result = []
        for chat_id, count in chat_counts.items():
            chat_name = get_chat_name(chat_id)
            result.append((chat_id, chat_name, count))
        # Most active chats first.
        result.sort(key=lambda x: x[2], reverse=True)
        return result
    except Exception as e:
        print(f"获取聊天列表失败: {e}")
        return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
    """Interactively ask for a time range; returns (start_ts, end_ts).

    Either element may be None, meaning "unbounded on that side".
    """
    print("\n时间范围选择:")
    print("1. 最近1天")
    print("2. 最近3天")
    print("3. 最近7天")
    print("4. 最近30天")
    print("5. 自定义时间范围")
    print("6. 不限制时间")
    choice = input("请选择时间范围 (1-6): ").strip()
    now = time.time()
    if choice == "1":
        return now - 24 * 3600, now
    elif choice == "2":
        return now - 3 * 24 * 3600, now
    elif choice == "3":
        return now - 7 * 24 * 3600, now
    elif choice == "4":
        return now - 30 * 24 * 3600, now
    elif choice == "5":
        # Custom range: parse two "YYYY-MM-DD HH:MM:SS" strings.
        print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
        start_str = input().strip()
        print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
        end_str = input().strip()
        try:
            start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
            end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
            return start_time, end_time
        except ValueError:
            # Malformed input falls back to "no restriction".
            print("时间格式错误,将不限制时间范围")
            return None, None
    else:
        # "6" or any unrecognized choice: no time restriction.
        return None, None
def analyze_interest_values(
    chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
    """Print distribution and summary statistics of interest_value.

    Args:
        chat_id: restrict to one chat; None means all chats.
        start_time: optional lower POSIX-timestamp bound (inclusive).
        end_time: optional upper POSIX-timestamp bound (inclusive).
    """
    # Base filter: interest_value present and non-zero.
    query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
    if chat_id:
        query = query.where(Messages.chat_id == chat_id)
    if start_time:
        query = query.where(Messages.time >= start_time)
    if end_time:
        query = query.where(Messages.time <= end_time)
    messages = list(query)
    if not messages:
        print("没有找到符合条件的消息")
        return
    # Compute bucketed distribution and summary statistics.
    distribution = calculate_interest_value_distribution(messages)
    stats = get_interest_value_stats(messages)
    # Render the report header describing the applied filters.
    print("\n=== Interest Value 分析结果 ===")
    if chat_id:
        print(f"聊天: {get_chat_name(chat_id)}")
    else:
        print("聊天: 全部聊天")
    if start_time and end_time:
        print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
    elif start_time:
        print(f"时间范围: {format_timestamp(start_time)} 之后")
    elif end_time:
        print(f"时间范围: {format_timestamp(end_time)} 之前")
    else:
        print("时间范围: 不限制")
    print("\n基本统计:")
    print(f"有效消息数量: {stats['count']} (排除null和0值)")
    print(f"最小值: {stats['min']:.3f}")
    print(f"最大值: {stats['max']:.3f}")
    print(f"平均值: {stats['avg']:.3f}")
    print(f"中位数: {stats['median']:.3f}")
    # Distribution table: only non-empty buckets are printed.
    print("\nInterest Value 分布:")
    total = stats["count"]
    for range_name, count in distribution.items():
        if count > 0:
            percentage = count / total * 100
            print(f"{range_name}: {count} ({percentage:.2f}%)")
def interactive_menu() -> None:
    """Interactive console menu for interest value analysis.

    Loops until the user enters ``q``; each iteration optionally selects a
    single chat, asks for a time range, and runs ``analyze_interest_values``.
    Performs console I/O only; returns nothing.
    """
    while True:
        print("\n" + "=" * 50)
        print("Interest Value 分析工具")
        print("=" * 50)
        print("1. 分析全部聊天")
        print("2. 选择特定聊天分析")
        print("q. 退出")
        choice = input("\n请选择分析模式 (1-2, q): ").strip()
        if choice.lower() == "q":
            print("再见!")
            break
        chat_id = None
        if choice == "2":
            # Let the user pick one chat among those that have interest data.
            chats = get_available_chats()
            if not chats:
                print("没有找到有interest_value数据的聊天")
                continue
            print(f"\n可用的聊天 (共{len(chats)}个):")
            for i, (_cid, name, count) in enumerate(chats, 1):
                print(f"{i}. {name} ({count}条有效消息)")
            try:
                chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
                if 1 <= chat_choice <= len(chats):
                    chat_id = chats[chat_choice - 1][0]
                else:
                    print("无效选择")
                    continue
            except ValueError:
                print("请输入有效数字")
                continue
        elif choice != "1":
            print("无效选择")
            continue
        # Ask for the time window ((None, None) means unrestricted).
        start_time, end_time = get_time_range_input()
        # Run the analysis and pause so the user can read the output.
        analyze_interest_values(chat_id, start_time, end_time)
        input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

View File

@ -1,397 +0,0 @@
import time
import sys
import os
import re
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams # noqa
def contains_emoji_or_image_tags(text: str) -> bool:
    """Return True if *text* contains a [表情包...] or [图片...] tag."""
    if not text:
        return False
    # A single alternation covers both tag kinds; a tag body is anything
    # up to the closing bracket.
    combined_pattern = r"\[(?:表情包|图片)[^\]]*\]"
    return re.search(combined_pattern, text) is not None
def clean_reply_text(text: str) -> str:
    """Strip reply references such as ``[回复 xxx]`` from *text*."""
    if not text:
        return text
    # Drop every [回复...] span (up to the first closing bracket), then trim
    # surrounding whitespace left behind by the removal.
    without_replies = re.sub(r"\[回复[^\]]*\]", "", text)
    return without_replies.strip()
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable name for *chat_id* via the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id})"
        if stream.group_name:
            return f"{stream.group_name} ({chat_id})"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊 ({chat_id})"
        return f"未知聊天 ({chat_id})"
    except Exception:
        # Any DB/model error degrades to a "query failed" label instead of raising.
        return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
    """Format a UNIX *timestamp* as a local-time 'YYYY-MM-DD HH:MM:SS' string.

    Returns the literal "未知时间" when the timestamp cannot be converted
    (NaN, out-of-range values, platform time_t limits, ...).
    """
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError, OverflowError):
        # OverflowError was previously uncaught: datetime.fromtimestamp raises
        # it for timestamps far outside the platform-supported range, which
        # crashed the caller instead of returning the fallback label.
        return "未知时间"
def calculate_text_length_distribution(messages) -> Dict[str, int]:
    """Bucket the lengths of cleaned ``processed_plain_text`` values.

    Messages whose text is null or contains [表情包]/[图片] tags are ignored;
    reply references are stripped before measuring. Returns an ordered mapping
    from length-range label to message count.
    """
    # (inclusive upper bound, bucket label), scanned in ascending order.
    buckets = [
        (0, "0"),
        (5, "1-5"),
        (10, "6-10"),
        (20, "11-20"),
        (30, "21-30"),
        (50, "31-50"),
        (70, "51-70"),
        (100, "71-100"),
        (150, "101-150"),
        (200, "151-200"),
        (300, "201-300"),
        (500, "301-500"),
        (1000, "501-1000"),
    ]
    overflow_label = "1000+"
    distribution: Dict[str, int] = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0
    for message in messages:
        raw_text = message.processed_plain_text
        if raw_text is None:
            continue
        # Skip messages that carry emoji/image placeholders.
        if contains_emoji_or_image_tags(raw_text):
            continue
        length = len(clean_reply_text(raw_text))
        for upper, label in buckets:
            if length <= upper:
                distribution[label] += 1
                break
        else:
            distribution[overflow_label] += 1
    return distribution
def get_text_length_stats(messages) -> Dict[str, float]:
    """Summarise cleaned text lengths: count, min, max, avg and median.

    Also reports how many messages had a null text (``null_count``) and how
    many were excluded for containing [表情包]/[图片] tags (``excluded_count``).
    """
    lengths: List[int] = []
    null_count = 0
    excluded_count = 0
    for message in messages:
        text = message.processed_plain_text
        if text is None:
            null_count += 1
        elif contains_emoji_or_image_tags(text):
            excluded_count += 1
        else:
            lengths.append(len(clean_reply_text(text)))
    if not lengths:
        # No measurable text at all: zero out the numeric stats.
        return {
            "count": 0,
            "null_count": null_count,
            "excluded_count": excluded_count,
            "min": 0,
            "max": 0,
            "avg": 0,
            "median": 0,
        }
    lengths.sort()
    n = len(lengths)
    if n % 2 == 1:
        median = lengths[n // 2]
    else:
        median = (lengths[n // 2 - 1] + lengths[n // 2]) / 2
    return {
        "count": n,
        "null_count": null_count,
        "excluded_count": excluded_count,
        "min": lengths[0],
        "max": lengths[-1],
        "avg": sum(lengths) / n,
        "median": median,
    }
def get_available_chats() -> List[Tuple[str, str, int]]:
    """List chats as (chat_id, display name, message count) tuples.

    Counts exclude emoji, picture-id and command messages; chats with zero
    countable messages are dropped. Results are sorted by message count,
    descending. Returns an empty list on any query failure.
    """
    try:
        result: List[Tuple[str, str, int]] = []
        for row in Messages.select(Messages.chat_id).distinct():
            cid = row.chat_id
            # NOTE(review): one COUNT query per chat (N+1 pattern); acceptable
            # for an offline analysis tool.
            total = (
                Messages.select()
                .where(
                    (Messages.chat_id == cid)
                    & (Messages.is_emoji != 1)
                    & (Messages.is_picid != 1)
                    & (Messages.is_command != 1)
                )
                .count()
            )
            if total > 0:
                result.append((cid, get_chat_name(cid), total))
        result.sort(key=lambda item: item[2], reverse=True)
        return result
    except Exception as e:
        print(f"获取聊天列表失败: {e}")
        return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
    """Prompt the user for a time range and return (start, end) timestamps.

    Preset choices 1-4 return a trailing window ending now; choice 5 parses a
    custom range. Anything else (including choice 6) means "no restriction"
    and yields (None, None); a malformed custom range also falls back to
    (None, None).
    """
    print("\n时间范围选择:")
    print("1. 最近1天")
    print("2. 最近3天")
    print("3. 最近7天")
    print("4. 最近30天")
    print("5. 自定义时间范围")
    print("6. 不限制时间")
    choice = input("请选择时间范围 (1-6): ").strip()
    now = time.time()
    # Preset trailing windows, keyed by menu choice, in seconds.
    preset_windows = {
        "1": 24 * 3600,
        "2": 3 * 24 * 3600,
        "3": 7 * 24 * 3600,
        "4": 30 * 24 * 3600,
    }
    if choice in preset_windows:
        return now - preset_windows[choice], now
    if choice == "5":
        print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
        start_str = input().strip()
        print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
        end_str = input().strip()
        try:
            start_ts = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
            end_ts = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
            return start_ts, end_ts
        except ValueError:
            print("时间格式错误,将不限制时间范围")
            return None, None
    return None, None
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
    """Return the *top_n* longest cleaned messages.

    Each entry is (chat name, length, formatted time, preview). Messages with
    null text or emoji/image tags are skipped; previews are capped at 100
    characters with an ellipsis.
    """
    candidates: List[Tuple[str, int, str, str]] = []
    for message in messages:
        raw = message.processed_plain_text
        if raw is None or contains_emoji_or_image_tags(raw):
            continue
        cleaned = clean_reply_text(raw)
        preview = cleaned if len(cleaned) <= 100 else cleaned[:100] + "..."
        candidates.append(
            (get_chat_name(message.chat_id), len(cleaned), format_timestamp(message.time), preview)
        )
    # Longest first; ties keep their encounter order (stable sort).
    candidates.sort(key=lambda entry: entry[1], reverse=True)
    return candidates[:top_n]
def analyze_text_lengths(
    chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
    """Analyze processed_plain_text lengths with optional filters.

    Excludes emoji, picture-id and command messages at the query level, then
    prints summary statistics, a length distribution and the longest messages
    to stdout.

    Args:
        chat_id: Restrict to one chat stream; ``None`` means all chats.
        start_time: Inclusive lower bound (UNIX timestamp) or ``None``.
        end_time: Inclusive upper bound (UNIX timestamp) or ``None``.
    """
    # Build the query, excluding special message types up front.
    query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
    if chat_id:
        query = query.where(Messages.chat_id == chat_id)
    if start_time:
        query = query.where(Messages.time >= start_time)
    if end_time:
        query = query.where(Messages.time <= end_time)
    messages = list(query)
    if not messages:
        print("没有找到符合条件的消息")
        return
    # Compute statistics and the top-10 longest messages.
    distribution = calculate_text_length_distribution(messages)
    stats = get_text_length_stats(messages)
    top_longest = get_top_longest_messages(messages, 10)
    # Render the report.
    print("\n=== Processed Plain Text 长度分析结果 ===")
    print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)")
    if chat_id:
        print(f"聊天: {get_chat_name(chat_id)}")
    else:
        print("聊天: 全部聊天")
    if start_time and end_time:
        print(f"时间范围: {format_timestamp(start_time)} 至 {format_timestamp(end_time)}")
    elif start_time:
        print(f"时间范围: {format_timestamp(start_time)} 之后")
    elif end_time:
        print(f"时间范围: {format_timestamp(end_time)} 之前")
    else:
        print("时间范围: 不限制")
    print("\n基本统计:")
    print(f"总消息数量: {len(messages)}")
    print(f"有文本消息数量: {stats['count']}")
    print(f"空文本消息数量: {stats['null_count']}")
    print(f"被排除的消息数量: {stats['excluded_count']}")
    if stats["count"] > 0:
        print(f"最短长度: {stats['min']} 字符")
        print(f"最长长度: {stats['max']} 字符")
        print(f"平均长度: {stats['avg']:.2f} 字符")
        print(f"中位数长度: {stats['median']:.2f} 字符")
    print("\n文本长度分布:")
    total = stats["count"]
    if total > 0:
        for range_name, count in distribution.items():
            if count > 0:
                # Percentage is relative to messages that had measurable text.
                percentage = count / total * 100
                print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
    # Show the longest messages, if any qualified.
    if top_longest:
        print(f"\n最长的 {len(top_longest)} 条消息:")
        for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
            print(f"{i}. [{chat_name}] {time_str}")
            print(f" 长度: {length} 字符")
            print(f" 预览: {preview}")
            print()
def interactive_menu() -> None:
"""Interactive menu for text length analysis"""
while True:
print("\n" + "=" * 50)
print("Processed Plain Text 长度分析工具")
print("=" * 50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == "q":
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到聊天数据")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_text_lengths(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

View File

@ -16,8 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
@ -96,7 +95,6 @@ class BrainChatting:
self.last_read_time = time.time() - 2
self.more_plan = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@ -171,10 +169,8 @@ class BrainChatting:
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
await self._observe(
recent_messages_list=recent_messages_list
)
await self._observe(recent_messages_list=recent_messages_list)
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
@ -233,11 +229,11 @@ class BrainChatting:
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
_reply_text = "" # 初始化reply_text变量避免UnboundLocalError
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
@ -286,7 +282,7 @@ class BrainChatting:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info, _ = await self.action_planner.plan(
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
@ -334,7 +330,7 @@ class BrainChatting:
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
_reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
@ -347,7 +343,7 @@ class BrainChatting:
"taken_time": time.time(),
},
}
reply_text = action_reply_text
_reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@ -401,7 +397,7 @@ class BrainChatting:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
reasoning=reasoning,
action_reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
@ -417,8 +413,8 @@ class BrainChatting:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, "", ""
# 处理动作并获取结果
result = await action_handler.execute()
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.run()
success, action_text = result
command = ""
@ -484,13 +480,12 @@ class BrainChatting:
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
if action_planner_info.action_type == "no_reply":
# 直接处理no_action逻辑,不再通过动作系统
# 直接处理no_reply逻辑,不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
# 存储no_reply信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
@ -498,9 +493,9 @@ class BrainChatting:
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
action_name="no_reply",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type == "reply":
try:
@ -517,7 +512,9 @@ class BrainChatting:
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}

View File

@ -152,10 +152,10 @@ class BrainPlanner:
action_planner_infos = []
try:
action = action_json.get("action", "no_action")
action = action_json.get("action", "no_reply")
reasoning = action_json.get("reason", "未提供原因")
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_action动作需要target_message_id
# 非no_reply动作需要target_message_id
target_message = None
if target_message_id := action_json.get("target_message_id"):
@ -215,12 +215,11 @@ class BrainPlanner:
self,
available_actions: Dict[str, ActionInfo],
loop_start_time: float = 0.0,
) -> Tuple[List[ActionPlannerInfo], Optional["DatabaseMessages"]]:
) -> List[ActionPlannerInfo]:
# sourcery skip: use-named-expression
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
"""
target_message: Optional["DatabaseMessages"] = None
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@ -274,12 +273,7 @@ class BrainPlanner:
loop_start_time=loop_start_time,
)
# 获取target_message如果有非no_action的动作
non_no_actions = [a for a in actions if a.action_type != "no_reply"]
if non_no_actions:
target_message = non_no_actions[0].action_message
return actions, target_message
return actions
async def build_planner_prompt(
self,
@ -307,7 +301,9 @@ class BrainPlanner:
if chat_target_info:
# 构建聊天上下文描述
chat_context_description = f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
chat_context_description = (
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
)
# 构建动作选项块
action_options_block = await self._build_action_options_block(current_available_actions)
@ -487,19 +483,19 @@ class BrainPlanner:
else:
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
# 添加循环开始时间到所有非no_action动作
# 添加循环开始时间到所有非no_reply动作
for action in actions:
action.action_data = action.action_data or {}
action.action_data["loop_start_time"] = loop_start_time
logger.info(
logger.debug(
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
return actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_action"""
"""创建no_reply"""
return [
ActionPlannerInfo(
action_type="no_reply",

View File

@ -1,604 +0,0 @@
import time
import random
import json
import os
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 15 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
def format_create_date(timestamp: float) -> str:
    """Format a UNIX *timestamp* as a local-time 'YYYY-MM-DD HH:MM:SS' string.

    Returns the literal "未知时间" when the timestamp cannot be converted
    (NaN, out-of-range values, platform time_t limits, ...).
    """
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError, OverflowError):
        # OverflowError was previously uncaught: datetime.fromtimestamp raises
        # it for timestamps far outside the platform-supported range, which
        # crashed the caller instead of returning the fallback label.
        return "未知时间"
def init_prompt() -> None:
learn_style_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字不要考虑表情包和图片
2. 不要涉及具体的人名但是可以涉及具体名词
3. 思考有没有特殊的梗一并总结成语言风格
4. 例子仅供参考请严格根据群聊内容总结!!!
注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
例如
"对某件事表示十分惊叹"使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"使用"对对对"
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
"当涉及游戏相关时,夸赞,略带戏谑意味"使用"这么强!"
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
)
self.chat_id = chat_id
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
def can_learn_for_chat(self) -> bool:
"""
检查指定聊天流是否允许学习表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许学习
"""
try:
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
return enable_learning
except Exception as e:
logger.error(f"检查学习权限失败: {e}")
return False
def should_trigger_learning(self) -> bool:
"""
检查是否应该触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
"""
current_time = time.time()
# 获取该聊天流的学习强度
try:
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
except Exception as e:
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
return False
# 检查是否允许学习
if not enable_learning:
return False
# 根据学习强度计算最短学习时间间隔
min_interval = self.min_learning_interval / learning_intensity
# 检查时间间隔
time_diff = current_time - self.last_learning_time
if time_diff < min_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
return False
return True
async def trigger_learning_for_chat(self) -> bool:
"""
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
try:
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
# 更新学习时间
self.last_learning_time = time.time()
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
    def calculate_decay_factor(self, time_diff_days: float) -> float:
        """Return the decay amount for an expression idle for *time_diff_days* days.

        The decay grows quadratically from 0 (just active, no decay) to a cap
        of 0.01 once the expression has been idle for DECAY_DAYS days or more.
        NOTE(review): the original docstring described a 30-day horizon and a
        0.002 value at 7 days, but the module constant DECAY_DAYS is 15 — the
        code below uses DECAY_DAYS; confirm which horizon is intended.
        """
        if time_diff_days <= 0:
            return 0.0  # A just-activated expression does not decay.
        if time_diff_days >= DECAY_DAYS:
            return 0.01  # Long-idle expressions get the maximum decay.
        # Quadratic interpolation over [0, DECAY_DAYS]: y = a * x^2 with a
        # chosen so that y(DECAY_DAYS) = 0.01.
        a = 0.01 / (DECAY_DAYS**2)
        decay = a * (time_diff_days**2)
        return min(0.01, decay)
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
"""
# 检查是否允许在此聊天流中学习(在函数最前面检查)
if not self.can_learn_for_chat():
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
return []
res = await self.learn_expression(num)
if res is None:
return []
learnt_expressions, chat_id = res
chat_stream = get_chat_manager().get_stream(chat_id)
if chat_stream is None:
group_name = f"聊天流 {chat_id}"
elif chat_stream.group_info:
group_name = chat_stream.group_info.group_name
else:
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
learnt_expressions_str = ""
for _chat_id, situation, style in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{group_name} 学习到表达风格:\n{learnt_expressions_str}")
if not learnt_expressions:
logger.info("没有学习到表达风格")
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style})
current_time = time.time()
# 存储到数据库 Expression 表
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
)
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
return learnt_expressions
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
"""
type_str = "语言风格"
prompt = "learn_style_prompt"
current_time = time.time()
# 获取上次学习时间
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
)
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
# print(f"random_msg_str:{random_msg_str}")
prompt: str = await global_prompt_manager.format_prompt(
prompt,
chat_str=random_msg_str,
)
logger.debug(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
return None
logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
return expressions, chat_id
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组
"""
expressions: List[Tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 查找"当"和下一个引号
idx_when = line.find('"')
if idx_when == -1:
continue
idx_quote1 = idx_when + 1
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
if idx_use == -1:
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((chat_id, situation, style))
return expressions
init_prompt()
class ExpressionLearnerManager:
def __init__(self):
self.expression_learners = {}
self._ensure_expression_directories()
self._auto_migrate_json_to_db()
self._migrate_old_data_create_date()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
def _auto_migrate_json_to_db(self):
"""
自动将/data/expression/learnt_style learnt_grammar 下所有expressions.json迁移到数据库
迁移完成后在/data/expression/done.done写入标记文件存在则跳过
然后检查done.done2如果没有就删除所有grammar表达并创建该标记文件
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
done_flag2 = os.path.join(base_dir, "done.done2")
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
logger.debug(f"确保目录存在: {base_dir}")
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
else:
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
logger.info("开始删除所有grammar类型的表达...")
try:
deleted_count = self.delete_all_grammar_expressions()
logger.info(f"grammar表达删除完成共删除 {deleted_count} 个表达")
# 创建done.done2标记文件
with open(done_flag2, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info("已创建done.done2标记文件grammar表达删除标记完成")
except Exception as e:
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
else:
logger.info("grammar表达已删除跳过重复删除")
def _migrate_old_data_create_date(self):
"""
为没有create_date的老数据设置创建日期
使用last_active_time作为create_date的默认值
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:
"""
检查expression库中所有type为"grammar"的表达并全部删除
Returns:
int: 删除的grammar表达数量
"""
try:
# 查询所有type为"grammar"的表达
grammar_expressions = Expression.select().where(Expression.type == "grammar")
grammar_count = grammar_expressions.count()
if grammar_count == 0:
logger.info("expression库中没有找到grammar类型的表达")
return 0
logger.info(f"找到 {grammar_count} 个grammar类型的表达开始删除...")
# 删除所有grammar类型的表达
deleted_count = 0
for expr in grammar_expressions:
try:
expr.delete_instance()
deleted_count += 1
except Exception as e:
logger.error(f"删除grammar表达失败: {e}")
continue
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
return deleted_count
except Exception as e:
logger.error(f"删除grammar表达过程中发生错误: {e}")
return 0
expression_learner_manager = ExpressionLearnerManager()

View File

@ -1,316 +0,0 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-evaluation prompt template with the global prompt manager.

    The template asks the LLM to pick at most {max_num} situation numbers that
    match the current chat, and to answer with a JSON object containing a
    "selected_situations" list (see ExpressionSelector.select_suitable_expressions_llm).
    """
    # NOTE(review): leading whitespace inside this template is part of the runtime
    # string — keep it exactly as-is.
    expression_evaluation_prompt = """
以下是正在进行的聊天内容
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境
{all_situations}
请你分析聊天内容的语境情绪话题类型从上述情境中选择最适合当前聊天情境的最多{max_num}个情境
考虑因素包括
1. 聊天的情绪氛围轻松严肃幽默等
2. 话题类型日常技术游戏情感等
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    """Draw up to ``k`` items from ``population`` without replacement, weighted by ``weights``.

    Returns ``[]`` for empty input or non-positive ``k``, and a shallow copy of
    ``population`` when it already holds ``k`` or fewer items. The caller's
    lists are never mutated.
    """
    if not population or not weights or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # Work on copies so pops below do not touch the caller's data.
    remaining = population.copy()
    remaining_weights = weights.copy()
    picked: List[Dict] = []
    while len(picked) < k and remaining:
        # One weighted draw per round; the winner is removed so it cannot repeat.
        idx = random.choices(range(len(remaining)), weights=remaining_weights)[0]
        picked.append(remaining.pop(idx))
        remaining_weights.pop(idx)
    return picked
class ExpressionSelector:
    """Selects situation-appropriate "style" expressions for a chat.

    Candidates are weighted-sampled from the database (shared across related
    chats per ``expression_groups``) and then filtered by an LLM that only
    sees the situation descriptions, not the full expressions.
    """

    def __init__(self):
        # A small utility model is sufficient for the selection task.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """Return whether the per-chat config allows using expressions.

        Args:
            chat_id: chat stream id.

        Returns:
            bool: True when allowed; False on any configuration error.
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False

    @staticmethod
    def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
        """Parse a 'platform:id:type' string into a chat_id (same hashing as get_stream_id)."""
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            is_group = stream_type == "group"
            # Private chats get an extra "private" component before hashing,
            # mirroring get_stream_id so the ids line up.
            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except Exception:
            return None

    def get_related_chat_ids(self, chat_id: str) -> List[str]:
        """Resolve all chat_ids sharing expressions with ``chat_id`` (itself included), per expression_groups."""
        groups = global_config.expression.expression_groups
        # A group containing "*" means every configured stream shares one global pool.
        global_group_exists = any("*" in group for group in groups)
        if global_group_exists:
            # Global sharing: collect every parseable chat_id from every group.
            all_chat_ids = set()
            for group in groups:
                for stream_config_str in group:
                    if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                        all_chat_ids.add(chat_id_candidate)
            return list(all_chat_ids) if all_chat_ids else [chat_id]
        # Otherwise use per-group sharing: return the first group containing this chat.
        for group in groups:
            group_chat_ids = []
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.append(chat_id_candidate)
            if chat_id in group_chat_ids:
                return group_chat_ids
        return [chat_id]

    def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
        """Weighted-sample up to ``total_num`` "style" expressions across all related chats."""
        # sourcery skip: extract-duplicate-method, move-assign
        related_chat_ids = self.get_related_chat_ids(chat_id)
        # Single query covering all related chat ids at once.
        style_query = Expression.select().where(
            (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
        )
        style_exprs = [
            {
                "id": expr.id,
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                # Legacy rows may lack create_date; fall back to last_active_time.
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query
        ]
        # Sample weighted by usage count so frequently-activated expressions surface more often.
        if style_exprs:
            style_weights = [expr.get("count", 1) for expr in style_exprs]
            selected_style = weighted_sample(style_exprs, style_weights, total_num)
        else:
            selected_style = []
        return selected_style

    def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
        """Bump the activation count of a batch of expressions.

        Entries are deduplicated by (chat_id, type, situation, style) before
        hitting the database; each matching row gets ``count += increment``
        (capped) and a refreshed ``last_active_time``.
        """
        if not expressions_to_update:
            return
        updates_by_key = {}
        for expr in expressions_to_update:
            source_id: str = expr.get("source_id")  # type: ignore
            expr_type: str = expr.get("type", "style")
            situation: str = expr.get("situation")  # type: ignore
            style: str = expr.get("style")  # type: ignore
            if not source_id or not situation or not style:
                logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                continue
            key = (source_id, expr_type, situation, style)
            if key not in updates_by_key:
                updates_by_key[key] = expr
        for chat_id, expr_type, situation, style in updates_by_key:
            query = Expression.select().where(
                (Expression.chat_id == chat_id)
                & (Expression.type == expr_type)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            if query.exists():
                expr_obj = query.get()
                current_count = expr_obj.count
                # Cap at 5.0 so often-used expressions cannot grow unbounded.
                new_count = min(current_count + increment, 5.0)
                expr_obj.count = new_count
                expr_obj.last_active_time = time.time()
                expr_obj.save()
                logger.debug(
                    f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
                )

    async def select_suitable_expressions_llm(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        # sourcery skip: inline-variable, list-comprehension
        """Use the LLM to pick at most ``max_num`` expressions fitting the current chat.

        Args:
            chat_id: chat stream id.
            chat_info: readable recent chat content inserted into the prompt.
            max_num: maximum number of situations the LLM may pick.
            target_message: optional message being replied to, appended to the prompt.

        Returns:
            Tuple of (selected expression dicts, their database ids); both empty
            when expressions are disabled, still accumulating, or on any failure.
        """
        if not self.can_use_expression_for_chat(chat_id):
            logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
            return [], []
        # 1. Weighted-sample a candidate pool of 20 "style" expressions.
        style_exprs = self.get_random_expressions(chat_id, 20)
        if len(style_exprs) < 10:
            logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
            return [], []
        # 2. Build an indexed situation list (1-based, matching the prompt's numbering).
        all_expressions: List[Dict[str, Any]] = []
        all_situations: List[str] = []
        for expr in style_exprs:
            expr = expr.copy()
            all_expressions.append(expr)
            all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return [], []
        all_situations_str = "\n".join(all_situations)
        if target_message:
            target_message_str = f",现在你想要回复消息:{target_message}"
            target_message_extra_block = "4.考虑你要回复的目标消息"
        else:
            target_message_str = ""
            target_message_extra_block = ""
        # 3. Build the prompt: only situations are shown, never the full expressions.
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            max_num=max_num,
            target_message=target_message_str,
            target_message_extra_block=target_message_extra_block,
        )
        # 4. Ask the LLM which situation numbers fit.
        try:
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
            if not content:
                logger.warning("LLM返回空结果")
                return [], []
            # 5. Parse the (possibly malformed) JSON answer.
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return [], []
            selected_indices = result["selected_situations"]
            # Map 1-based indices back to the full expression dicts; ignore out-of-range junk.
            valid_expressions: List[Dict[str, Any]] = []
            selected_ids = []
            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]
                    selected_ids.append(expression["id"])
                    valid_expressions.append(expression)
            # Reward chosen expressions with a small count increment, in one batch.
            if valid_expressions:
                self.update_expressions_count_batch(valid_expressions, 0.006)
            return valid_expressions, selected_ids
        except Exception as e:
            logger.error(f"LLM处理表达方式选择时出错: {e}")
            return [], []
# Register the prompt template at import time.
init_prompt()

# NOTE(review): if construction fails, only the error is logged and the name
# `expression_selector` is left undefined — importers would then hit NameError.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@ -43,4 +43,4 @@ class FrequencyControlManager:
# 创建全局实例
frequency_control_manager = FrequencyControlManager()
frequency_control_manager = FrequencyControlManager()

View File

@ -1,4 +1,5 @@
import asyncio
from multiprocessing import context
import time
import traceback
import random
@ -17,14 +18,15 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.memory_system.question_maker import QuestionMaker
from src.memory_system.questions import global_conflict_tracker
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
from src.mais4u.s4u_config import s4u_config
from src.memory_system.Memory_chest import global_memory_chest
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
@ -98,11 +100,15 @@ class HeartFChatting:
self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 2
self.talk_threshold = global_config.chat.talk_value
self.no_reply_until_call = False
self.is_mute = False
self.last_active_time = time.time() # 记录上一次非noreply时间
self.questioned = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@ -154,16 +160,19 @@ class HeartFChatting:
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
if elapsed < 0.1:
# 不显示小于0.1秒的计时器
continue
formatted_time = f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}" # type: ignore
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f};" # type: ignore
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def _loopbody(self): # sourcery skip: hoist-if-from-if
async def _loopbody(self):
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
@ -174,7 +183,43 @@ class HeartFChatting:
filter_command=True,
)
question_probability = 0
if time.time() - self.last_active_time > 3600:
question_probability = 0.001
elif time.time() - self.last_active_time > 1200:
question_probability = 0.0003
else:
question_probability = 0.0001
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
# print(f"{self.log_prefix} questioned: {self.questioned},len: {len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id))}")
if question_probability > 0 and not self.questioned and len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id)) == 0: #长久没有回复,可以试试主动发言,提问概率随着时间增加
# logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,概率: {question_probability}")
if random.random() < question_probability: # 30%概率主动发言
try:
self.questioned = True
self.last_active_time = time.time()
# print(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
cycle_timers, thinking_id = self.start_cycle()
question_maker = QuestionMaker(self.stream_id)
question, context,conflict_context = await question_maker.make_question()
if question:
logger.info(f"{self.log_prefix} 问题: {question}")
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
await self._lift_question_reply(question,context,thinking_id)
else:
logger.info(f"{self.log_prefix} 无问题")
# self.end_cycle(cycle_timers, thinking_id)
except Exception as e:
logger.error(f"{self.log_prefix} 主动提问失败: {e}")
print(traceback.format_exc())
if len(recent_messages_list) >= 1:
# for message in recent_messages_list:
# print(message.processed_plain_text)
# !处理no_reply_until_call逻辑
if self.no_reply_until_call:
for message in recent_messages_list:
@ -185,6 +230,7 @@ class HeartFChatting:
or time.time() - self.last_read_time > 600
):
self.no_reply_until_call = False
self.last_read_time = time.time()
break
# 没有提到,继续保持沉默
if self.no_reply_until_call:
@ -200,15 +246,22 @@ class HeartFChatting:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# *控制频率用
if mentioned_message:
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
elif random.random() < global_config.chat.talk_value * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust():
elif (
random.random()
< global_config.chat.get_talk_value(self.stream_id)
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
):
await self._observe(recent_messages_list=recent_messages_list)
else:
# 没有提到继续保持沉默等待5秒防止频繁触发
await asyncio.sleep(10)
return True
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
return True
return True
@ -272,12 +325,14 @@ class HeartFChatting:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
if s4u_config.enable_s4u:
await send_typing()
start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@ -322,7 +377,7 @@ class HeartFChatting:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info, _ = await self.action_planner.plan(
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
@ -344,6 +399,10 @@ class HeartFChatting:
)
)
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
@ -361,21 +420,26 @@ class HeartFChatting:
action_success = False
action_reply_text = ""
excute_result_str = ""
for result in results:
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
action_reply_text = result["result"]
elif result["action_type"] == "reply":
if result["success"]:
reply_loop_info = result["loop_info"]
reply_text_from_reply = result["reply_text"]
reply_text_from_reply = result["result"]
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
self.action_planner.add_plan_excute_log(result=excute_result_str)
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
@ -405,12 +469,12 @@ class HeartFChatting:
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
"""S4U内容暂时保留"""
if s4u_config.enable_s4u:
await stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
"""S4U内容暂时保留"""
end_time = time.time()
if end_time - start_time < global_config.chat.planner_smooth:
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0.1)
return True
async def _main_chat_loop(self):
@ -435,7 +499,7 @@ class HeartFChatting:
async def _handle_action(
self,
action: str,
reasoning: str,
action_reasoning: str,
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
@ -446,11 +510,11 @@ class HeartFChatting:
参数:
action: 动作类型
reasoning: 决策理由
action_reasoning: 决策理由
action_data: 动作数据包含不同动作需要的参数
cycle_timers: 计时器字典
thinking_id: 思考ID
action_message: 消息数据
返回:
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
"""
@ -460,33 +524,101 @@ class HeartFChatting:
action_handler = self.action_manager.create_action(
action_name=action,
action_data=action_data,
reasoning=reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=self.chat_stream,
log_prefix=self.log_prefix,
action_reasoning=action_reasoning,
action_message=action_message,
)
except Exception as e:
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
traceback.print_exc()
return False, "", ""
return False, ""
if not action_handler:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, "", ""
# 处理动作并获取结果
# 处理动作并获取结果(固定记录一次动作信息)
result = await action_handler.execute()
success, action_text = result
command = ""
return success, action_text, command
return success, action_text
except Exception as e:
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, "", ""
return False, ""
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
new_msg = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=1,
)
reply_action_info = ActionPlannerInfo(
action_type="reply",
reasoning= "",
action_data={},
action_message=new_msg[0],
available_actions=None,
loop_start_time=time.time(),
action_reasoning=reason)
self.action_planner.add_plan_log(reasoning=f"你对问题\"{question}\"感到好奇,想要和群友讨论", actions=[reply_action_info])
success, llm_response = await generator_api.rewrite_reply(
chat_stream=self.chat_stream,
reply_data={
"raw_reply": f"我对这个问题感到好奇:{question}",
"reason": reason,
},
)
if not success or not llm_response or not llm_response.reply_set:
logger.info("主动提问发言失败")
self.action_planner.add_plan_excute_log(result="主动回复生成失败")
return {"action_type": "reply", "success": False, "result": "主动回复生成失败", "loop_info": None}
if success:
for reply_seg in llm_response.reply_set.reply_data:
send_data = reply_seg.content
await send_api.text_to_stream(
text=send_data,
stream_id=self.stream_id,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reply_text": llm_response.reply_set.reply_data[0].content},
action_name="reply",
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": [reply_action_info],
},
"loop_action_info": {
"action_taken": True,
"reply_text": llm_response.reply_set.reply_data[0].content,
"command": "",
"taken_time": time.time(),
},
}
self.last_active_time = time.time()
self.action_planner.add_plan_excute_log(result=f"你提问:{question}")
return {
"action_type": "reply",
"success": True,
"result": f"你提问:{question}",
"loop_info": loop_info,
}
async def _send_response(
self,
@ -498,7 +630,7 @@ class HeartFChatting:
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 4)
need_reply = new_message_count >= random.randint(2, 3)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
@ -543,59 +675,83 @@ class HeartFChatting:
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
# 直接当场执行no_reply逻辑
if action_planner_info.action_type == "no_reply":
# 直接处理no_action逻辑,不再通过动作系统
# 直接处理no_reply逻辑,不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
action_data={},
action_name="no_reply",
action_reasoning=reason,
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type == "wait_time":
action_planner_info.action_data = action_planner_info.action_data or {}
logger.info(f"{self.log_prefix} 等待{action_planner_info.action_data['time']}秒后回复")
await asyncio.sleep(action_planner_info.action_data["time"])
return {"action_type": "wait_time", "success": True, "reply_text": "", "command": ""}
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call":
# 直接当场执行no_reply_until_call逻辑
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
reason = action_planner_info.reasoning or "选择不回复"
self.no_reply_until_call = True
return {"action_type": "no_reply_until_call", "success": True, "reply_text": "", "command": ""}
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="no_reply_until_call",
action_reasoning=reason,
)
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
elif action_planner_info.action_type == "reply":
try:
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=action_planner_info.reasoning or "",
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
)
# 直接当场执行reply逻辑
self.questioned = False
# 刷新主动发言状态
reason = action_planner_info.reasoning or "选择回复"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="reply",
action_reasoning=reason,
)
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=reason,
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
)
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
if not success or not llm_response or not llm_response.reply_set:
if action_planner_info.action_message:
logger.info(
f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败"
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
@ -606,30 +762,30 @@ class HeartFChatting:
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"result": f"你回复内容{reply_text}",
"loop_info": loop_info,
}
# 其他动作
else:
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
success, result = await self._handle_action(
action = action_planner_info.action_type,
action_reasoning = action_planner_info.action_reasoning or "",
action_data = action_planner_info.action_data or {},
cycle_timers = cycle_timers,
thinking_id = thinking_id,
action_message= action_planner_info.action_message,
)
self.last_active_time = time.time()
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
"result": result,
}
except Exception as e:
@ -638,7 +794,7 @@ class HeartFChatting:
return {
"action_type": action_planner_info.action_type,
"success": False,
"reply_text": "",
"result": "",
"loop_info": None,
"error": str(e),
}

View File

@ -1,17 +1,14 @@
import asyncio
import re
import traceback
from typing import Tuple, TYPE_CHECKING
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.chat_message_builder import replace_user_references
from src.common.logger import get_logger
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person
from src.common.database.database_model import Images
@ -75,11 +72,7 @@ class HeartFCMessageReceiver:
await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message))
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
@ -107,7 +100,7 @@ class HeartFCMessageReceiver:
replace_bot_name=True,
)
# if not processed_plain_text:
# print(message)
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore

File diff suppressed because it is too large Load Diff

View File

@ -1,241 +0,0 @@
import json
import random
from json_repair import repair_json
from typing import List, Tuple
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
    """Extract the "keywords" list from a (possibly malformed) JSON string.

    Args:
        json_str: JSON-formatted string, typically raw LLM output.

    Returns:
        List[str]: the keywords, or [] when parsing fails or the payload is
        not a JSON object carrying a "keywords" field.
    """
    try:
        # repair_json fixes common LLM JSON glitches (trailing commas, bad quotes, ...).
        fixed_json = repair_json(json_str)
        # repair_json may return either a string or an already-parsed object.
        result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
        # Guard: the repaired payload can legally be a list/scalar, which has no
        # .get(); previously that raised AttributeError and logged a misleading
        # "parse failed" error. Treat it as "no keywords" instead.
        if not isinstance(result, dict):
            return []
        return result.get("keywords", [])
    except Exception as e:
        logger.error(f"解析关键词JSON失败: {e}")
        return []
def init_prompt():
    """Register the memory-selection prompt template with the global prompt manager.

    The template shows the chat history, the target message and the numbered
    candidate memories, and asks for a JSON object with a comma-separated
    "memory_ids" string (parsed in MemoryActivator._select_memories_with_llm).
    """
    # --- Group Chat Prompt ---
    # NOTE(review): leading whitespace inside this template is part of the
    # runtime string — keep it exactly as-is.
    memory_activator_prompt = """
你需要根据以下信息来挑选合适的记忆编号
以下是一段聊天记录请根据这些信息和下方的记忆挑选和群聊内容有关的记忆编号
聊天记录:
{obs_info_text}
你想要回复的消息:
{target_message}
记忆
{memory_info}
请输出一个json格式包含以下字段
{{
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
}}
不要输出其他多余内容只输出json格式就好
"""
    Prompt(memory_activator_prompt, "memory_activator_prompt")
class MemoryActivator:
    """Activates relevant long-term memories for the current chat.

    Pipeline: extract keywords from recent messages, query the hippocampus
    for related memories, then (when there are more than two candidates) ask
    an LLM to keep only the ones that fit the conversation.
    """

    def __init__(self):
        # Both extraction and selection use the small utility model set.
        self.key_words_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="memory.activator",
        )
        # LLM used to choose among candidate memories.
        self.memory_selection_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="memory.selection",
        )

    async def activate_memory_with_chat_history(
        self, target_message, chat_history: List[DatabaseMessages]
    ) -> List[Tuple[str, str]]:
        """Return (keyword, content) memory pairs relevant to the recent chat.

        Args:
            target_message: the message being replied to.
            chat_history: recent messages whose key_words drive retrieval.

        Returns:
            List[Tuple[str, str]]: selected (keyword, content) pairs; empty
            when memory is disabled, no keywords are found, or nothing matches.
        """
        # Memory system disabled → nothing to do.
        if not global_config.memory.enable_memory:
            return []
        keywords_list = set()
        for msg in chat_history:
            keywords = parse_keywords_string(msg.key_words)
            if keywords:
                if len(keywords_list) < 30:
                    # Cap at 30 keywords to bound the hippocampus query.
                    keywords_list.update(keywords)
                    logger.debug(f"提取关键词: {keywords_list}")
                else:
                    break
        if not keywords_list:
            logger.debug("没有提取到关键词,返回空记忆列表")
            return []
        # Query the hippocampus for memories related to the extracted keywords.
        related_memory = await hippocampus_manager.get_memory_from_topic(
            valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
        )
        logger.debug(f"获取到的记忆: {related_memory}")
        if not related_memory:
            logger.debug("海马体没有返回相关记忆")
            return []
        used_ids = set()
        candidate_memories = []
        # Keep only memories whose content mentions an extracted keyword, and
        # give each one a unique random two-digit id for the LLM prompt.
        for memory in related_memory:
            keyword, content = memory
            found = any(kw in content for kw in keywords_list)
            if found:
                # Re-roll until an unused two-digit id is found.
                while True:
                    random_id = "{:02d}".format(random.randint(0, 99))
                    if random_id not in used_ids:
                        used_ids.add(random_id)
                        break
                candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
        if not candidate_memories:
            logger.info("没有找到相关的候选记忆")
            return []
        # With at most two candidates, skip the LLM pass entirely.
        if len(candidate_memories) <= 2:
            logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
            return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
        return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)

    async def _select_memories_with_llm(
        self, target_message, chat_history: List[DatabaseMessages], candidate_memories
    ) -> List[Tuple[str, str]]:
        """Ask the LLM which candidate memories fit the conversation.

        Args:
            target_message: the message being replied to.
            chat_history: recent chat messages rendered as prompt context.
            candidate_memories: dicts with memory_id / keyword / content.

        Returns:
            List[Tuple[str, str]]: chosen (keyword, content) pairs; on error
            the first three candidates are returned as a fallback.
        """
        try:
            # Render the chat history into a readable transcript for the prompt.
            obs_info_text = build_readable_messages(
                chat_history,
                replace_bot_name=True,
                timestamp_mode="relative",
                read_mark=0.0,
                show_actions=True,
            )
            # Render each candidate as "记忆编号 NN: [关键词: kw] content".
            memory_lines = []
            for memory in candidate_memories:
                memory_id = memory["memory_id"]
                keyword = memory["keyword"]
                content = memory["content"]
                # content may be a list of fragments; join them for display.
                if isinstance(content, list):
                    content_str = " | ".join(str(item) for item in content)
                else:
                    content_str = str(content)
                memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
            memory_info = "\n".join(memory_lines)
            prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
            formatted_prompt = prompt_template.format(
                obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
            )
            response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
                formatted_prompt, temperature=0.3, max_tokens=150
            )
            if global_config.debug.show_prompt:
                logger.info(f"记忆选择 prompt: {formatted_prompt}")
                logger.info(f"LLM 记忆选择响应: {response}")
            else:
                logger.debug(f"记忆选择 prompt: {formatted_prompt}")
                logger.debug(f"LLM 记忆选择响应: {response}")
            # Parse the comma-separated memory ids out of the JSON reply.
            try:
                fixed_json = repair_json(response)
                result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
                if memory_ids_str := result.get("memory_ids", ""):
                    memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
                    # Drop empty or implausibly long ids (valid ids are two digits).
                    valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
                    selected_memory_ids = valid_memory_ids
                else:
                    selected_memory_ids = []
            except Exception as e:
                logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
                selected_memory_ids = []
            # Keep only candidates whose id the LLM actually returned.
            selected_memories = []
            memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
            selected_memories = [
                memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
            ]
            logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
            logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
            return [(mem["keyword"], mem["content"]) for mem in selected_memories]
        except Exception as e:
            logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
            # Fallback: return up to the first three candidates.
            return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
# Register the prompt template at import time.
init_prompt()

View File

@ -3,19 +3,18 @@ import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo, Seg
from maim_message import UserInfo, Seg, GroupInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.person_info.person_info import Person
# 定义日志配置
@ -27,7 +26,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否包含过滤词
Args:
@ -40,14 +39,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
chat_name = group_info.group_name if group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否匹配过滤正则表达式
Args:
@ -61,10 +60,10 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
# 检查text是否为None或空字符串
if text is None or not text:
return False
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
chat_name = group_info.group_name if group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
@ -78,8 +77,6 @@ class ChatBot:
self.mood_manager = mood_manager # 获取情绪管理器单例
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
self.s4u_message_processor = S4UMessageProcessor()
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
@ -153,35 +150,10 @@ class ChatBot:
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("notice消息")
# print(message)
print(message)
return True
async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
_ = Person.register_person(
platform=message.message_info.platform, # type: ignore
user_id=message.message_info.user_info.user_id, # type: ignore
nickname=user_info.user_nickname, # type: ignore
)
await self.s4u_message_processor.process_message(message)
return
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
@ -219,11 +191,6 @@ class ChatBot:
# 确保所有任务已启动
await self._ensure_started()
platform = message_data["message_info"].get("platform")
if platform == "amaidesu_default":
await self.do_s4u(message_data)
return
if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str(
@ -251,6 +218,21 @@ class ChatBot:
# return
pass
# 处理消息内容,生成纯文本
await message.process()
# 过滤检查
if _check_ban_words(
message.processed_plain_text,
user_info, # type: ignore
group_info,
) or _check_ban_regex(
message.raw_message, # type: ignore
user_info, # type: ignore
group_info,
):
return
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
@ -261,21 +243,10 @@ class ChatBot:
message.update_chat_stream(chat)
# 处理消息内容,生成纯文本
await message.process()
# if await self.check_ban_content(message):
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
# return
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return
# 命令处理 - 使用新插件系统检查并处理命令
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)

View File

@ -138,6 +138,7 @@ class MessageRecv(Message):
这个方法必须在创建实例后显式调用因为它包含异步操作
"""
# print(f"self.message_segment: {self.message_segment}")
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
@ -208,129 +209,6 @@ class MessageRecv(Message):
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
    """Incoming S4U (live-stream) message.

    Extends MessageRecv with live-stream-specific segment types handled in
    _process_single_segment: gifts, superchats, screen info, priority info
    and voice-completion notices.
    """

    def __init__(self, message_dict: dict[str, Any]):
        super().__init__(message_dict)

        # Gift state — populated when a "gift" segment is processed.
        self.is_gift = False
        self.is_fake_gift = False
        self.is_superchat = False
        self.gift_info = None  # raw "name:count" payload of the gift segment
        self.gift_name = None
        self.gift_count: Optional[int] = None  # parsed with int() in _process_single_segment

        # Superchat state — populated when a "superchat" segment is processed.
        self.superchat_info = None  # raw "price:message" payload
        self.superchat_price = None
        self.superchat_message_text = None

        self.is_screen = False
        self.is_internal = False

        # Message id carried by a "voice_done" segment, if any.
        self.voice_done = None

        self.chat_info = None

    async def process(self) -> None:
        # Flatten all segments into a single plain-text representation.
        self.processed_plain_text = await self._process_message_segments(self.message_segment)

    async def _process_single_segment(self, segment: Seg) -> str:
        """Process a single message segment.

        Args:
            segment: the message segment to process

        Returns:
            str: text produced for this segment; "" for segments that only
            update instance state (e.g. gift / priority_info / voice_done)
        """
        try:
            if segment.type == "text":
                self.is_voice = False
                self.is_picid = False
                self.is_emoji = False
                return segment.data  # type: ignore
            elif segment.type == "image":
                self.is_voice = False
                # Base64-encoded image data arrives as a plain string.
                if isinstance(segment.data, str):
                    self.has_picid = True
                    self.is_picid = True
                    self.is_emoji = False
                    image_manager = get_image_manager()
                    # print(f"segment.data: {segment.data}")
                    _, processed_text = await image_manager.process_image(segment.data)
                    return processed_text
                return "[发了一张图片,网卡了加载不出来]"
            elif segment.type == "emoji":
                self.has_emoji = True
                self.is_emoji = True
                self.is_picid = False
                if isinstance(segment.data, str):
                    return await get_image_manager().get_emoji_description(segment.data)
                return "[发了一个表情包,网卡了加载不出来]"
            elif segment.type == "voice":
                self.has_picid = False
                self.is_picid = False
                self.is_emoji = False
                self.is_voice = True
                if isinstance(segment.data, str):
                    return await get_voice_text(segment.data)
                return "[发了一段语音,网卡了加载不出来]"
            elif segment.type == "mention_bot":
                self.is_voice = False
                self.is_picid = False
                self.is_emoji = False
                # segment.data carries the mention weight as a number-like string.
                self.is_mentioned = float(segment.data)  # type: ignore
                return ""
            elif segment.type == "priority_info":
                self.is_voice = False
                self.is_picid = False
                self.is_emoji = False
                if isinstance(segment.data, dict):
                    # Store the priority payload; switches this message into priority mode.
                    self.priority_mode = "priority"
                    self.priority_info = segment.data
                    """
                    {
                        'message_type': 'vip', # vip or normal
                        'message_priority': 1.0, # 优先级大为优先float
                    }
                    """
                return ""
            elif segment.type == "gift":
                self.is_voice = False
                self.is_gift = True
                # Gift payload format is "name:count"; split on the first colon only.
                name, count = segment.data.split(":", 1)  # type: ignore
                self.gift_info = segment.data
                self.gift_name = name.strip()
                self.gift_count = int(count.strip())
                return ""
            elif segment.type == "voice_done":
                msg_id = segment.data
                logger.info(f"voice_done: {msg_id}")
                self.voice_done = msg_id
                return ""
            elif segment.type == "superchat":
                self.is_superchat = True
                self.superchat_info = segment.data
                # Superchat payload format is "price:message"; split on the first colon only.
                price, message_text = segment.data.split(":", 1)  # type: ignore
                self.superchat_price = price.strip()
                self.superchat_message_text = message_text.strip()
                self.processed_plain_text = str(self.superchat_message_text)
                self.processed_plain_text += (
                    f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
                )
                return self.processed_plain_text
            elif segment.type == "screen":
                self.is_screen = True
                self.screen_info = segment.data
                return "屏幕信息"
            else:
                # Unknown segment types contribute no text.
                return ""
        except Exception as e:
            logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
            return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""

View File

@ -32,7 +32,7 @@ class ActionManager:
self,
action_name: str,
action_data: dict,
reasoning: str,
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
@ -46,7 +46,7 @@ class ActionManager:
Args:
action_name: 动作名称
action_data: 动作数据
reasoning: 执行理由
action_reasoning: 执行理由
cycle_timers: 计时器字典
thinking_id: 思考ID
chat_stream: 聊天流
@ -77,7 +77,7 @@ class ActionManager:
# 创建动作实例
instance = component_class(
action_data=action_data,
reasoning=reasoning,
action_reasoning=action_reasoning,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
chat_stream=chat_stream,

View File

@ -3,7 +3,7 @@ import time
import traceback
import random
import re
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
@ -44,14 +44,13 @@ def init_prompt():
**聊天内容**
{chat_content_block}
**动作记录**
{actions_before_now_block}
**可用的action**
**可选的action**
reply
动作描述
1.你可以选择呼叫了你的名字但是你没有做出回应的消息进行回复
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
3.不要回复你自己发送的消息
4.不要单独对表情包进行回复
{{
"action": "reply",
"target_message_id":"想要回复的消息id",
@ -76,7 +75,11 @@ no_reply_until_call
{action_options_text}
请选择合适的action并说明触发action的消息id和选择该action的原因消息id格式:m+数字
**你之前的action执行和思考记录**
{actions_before_now_block}
请选择**可选的**且符合使用条件的action并说明触发action的消息id(消息id格式:m+数字)
不要回复你自己发送的消息
先输出你的选择思考理由再输出你选择的action理由是一段平文本不要分点精简
**动作选择要求**
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
@ -99,9 +102,7 @@ no_reply_until_call
"target_message_id":"触发动作的消息id",
//对应参数
}}
```
""",
```""",
"planner_prompt",
)
@ -109,7 +110,7 @@ no_reply_until_call
"""
{action_name}
动作描述{action_description}
使用条件
使用条件{parallel_text}
{action_require}
{{
"action": "{action_name}",{action_parameters},
@ -133,6 +134,9 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
def find_message_by_id(
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
) -> Optional["DatabaseMessages"]:
@ -157,15 +161,16 @@ class ActionPlanner:
action_json: dict,
message_id_list: List[Tuple[str, "DatabaseMessages"]],
current_available_actions: List[Tuple[str, ActionInfo]],
extracted_reasoning: str = "",
) -> List[ActionPlannerInfo]:
"""解析单个action JSON并返回ActionPlannerInfo列表"""
action_planner_infos = []
try:
action = action_json.get("action", "no_action")
action = action_json.get("action", "no_reply")
reasoning = action_json.get("reason", "未提供原因")
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_action动作需要target_message_id
# 非no_reply动作需要target_message_id
target_message = None
if target_message_id := action_json.get("target_message_id"):
@ -202,6 +207,7 @@ class ActionPlanner:
action_data=action_data,
action_message=target_message,
available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None,
)
)
@ -216,6 +222,7 @@ class ActionPlanner:
action_data={},
action_message=None,
available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None,
)
)
@ -225,12 +232,11 @@ class ActionPlanner:
self,
available_actions: Dict[str, ActionInfo],
loop_start_time: float = 0.0,
) -> Tuple[List[ActionPlannerInfo], Optional["DatabaseMessages"]]:
) -> List[ActionPlannerInfo]:
# sourcery skip: use-named-expression
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
"""
target_message: Optional["DatabaseMessages"] = None
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@ -276,7 +282,7 @@ class ActionPlanner:
)
# 调用LLM获取决策
actions = await self._execute_main_planner(
reasoning, actions = await self._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
@ -284,12 +290,34 @@ class ActionPlanner:
loop_start_time=loop_start_time,
)
# 获取target_message如果有非no_action的动作
non_no_actions = [a for a in actions if a.action_type != "no_reply"]
if non_no_actions:
target_message = non_no_actions[0].action_message
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
self.add_plan_log(reasoning, actions)
return actions
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
    """Record one planning decision (reasoning + chosen actions), keeping only the 20 most recent entries."""
    entry = (reasoning, time.time(), actions)
    self.plan_log.append(entry)
    if len(self.plan_log) > 20:
        del self.plan_log[0]
def add_plan_excute_log(self, result: str):
    """Record an action-execution result in the plan log, trimming to the 20 most recent entries."""
    self.plan_log.append(("", time.time(), result))
    if len(self.plan_log) > 20:
        del self.plan_log[0]
def get_plan_log_str(self) -> str:
    """Render the recent plan log as human-readable lines.

    Each entry becomes "HH:MM:SS:...": planner entries show the reasoning
    plus the action types chosen; execution-result entries show the raw
    result text.

    Returns:
        str: one line per entry, each terminated with "\\n" ("" if empty).
    """
    lines = []
    for reasoning, timestamp, content in self.plan_log:
        # Format the timestamp once; the loop variable no longer shadows the `time` module.
        time_str = datetime.fromtimestamp(timestamp).strftime("%H:%M:%S")
        if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
            action_types = ",".join(action.action_type for action in content)
            lines.append(f"{time_str}:{reasoning}|你使用了{action_types}\n")
        else:
            lines.append(f"{time_str}:{content}\n")
    return "".join(lines)
return actions, target_message
async def build_planner_prompt(
self,
@ -302,18 +330,8 @@ class ActionPlanner:
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
# 获取最近执行过的动作
actions_before_now = get_actions_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=time.time() - 600,
timestamp_end=time.time(),
limit=6,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
if actions_before_now_block:
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
else:
actions_before_now_block = ""
actions_before_now_block=self.get_plan_log_str()
# 构建聊天上下文描述
chat_context_description = "你现在正在一个群聊中"
@ -343,7 +361,6 @@ class ActionPlanner:
interest=interest,
plan_style=global_config.personality.plan_style,
)
return prompt, message_id_list
except Exception as e:
@ -421,6 +438,11 @@ class ActionPlanner:
for require_item in action_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
if not action_info.parallel_action:
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
else:
parallel_text = ""
# 获取动作提示模板并填充
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
@ -429,6 +451,7 @@ class ActionPlanner:
action_description=action_info.description,
action_parameters=param_text,
action_require=require_text,
parallel_text=parallel_text,
)
action_options_block += using_action_prompt
@ -442,7 +465,7 @@ class ActionPlanner:
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> List[ActionPlannerInfo]:
) -> Tuple[str,List[ActionPlannerInfo]]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
@ -451,8 +474,8 @@ class ActionPlanner:
# 调用LLM
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
@ -467,7 +490,7 @@ class ActionPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return [
return f"LLM 请求失败,模型出现问题: {req_e}",[
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
@ -478,38 +501,41 @@ class ActionPlanner:
]
# 解析LLM响应
extracted_reasoning = ""
if llm_content:
try:
if json_objects := self._extract_json_from_markdown(llm_content):
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
if json_objects:
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning))
else:
# 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
extracted_reasoning = "LLM没有返回可用动作"
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
traceback.print_exc()
else:
extracted_reasoning = "规划器没有获得LLM响应"
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
# 添加循环开始时间到所有非no_action动作
# 添加循环开始时间到所有非no_reply动作
for action in actions:
action.action_data = action.action_data or {}
action.action_data["loop_start_time"] = loop_start_time
logger.info(
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return actions
return extracted_reasoning,actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_action"""
"""创建no_reply"""
return [
ActionPlannerInfo(
action_type="no_reply",
@ -520,15 +546,26 @@ class ActionPlanner:
)
]
def _extract_json_from_markdown(self, content: str) -> List[dict]:
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
# sourcery skip: for-append-to-extend
"""从Markdown格式的内容中提取JSON对象"""
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
reasoning_content = ""
# 使用正则表达式查找```json包裹的JSON内容
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, content, re.DOTALL)
# 提取JSON之前的内容作为推理文本
if matches:
# 找到第一个```json的位置
first_json_pos = content.find("```json")
if first_json_pos > 0:
reasoning_content = content[:first_json_pos].strip()
# 清理推理内容中的注释标记
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
reasoning_content = reasoning_content.strip()
for match in matches:
try:
# 清理可能的注释和格式问题
@ -546,7 +583,7 @@ class ActionPlanner:
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
continue
return json_objects
return json_objects, reasoning_content
init_prompt()

View File

@ -6,7 +6,8 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.memory_system.Memory_chest import global_memory_chest
from src.memory_system.questions import global_conflict_tracker
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
@ -19,16 +20,17 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import global_prompt_manager
from src.mood.mood_manager import mood_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.chat.memory_system.memory_activator import MemoryActivator
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person, is_person_known
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
@ -43,6 +45,7 @@ init_rewrite_prompt()
logger = get_logger("replyer")
class DefaultReplyer:
def __init__(
self,
@ -69,6 +72,7 @@ class DefaultReplyer:
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@ -102,6 +106,7 @@ class DefaultReplyer:
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
reply_time_point=reply_time_point,
)
llm_response.prompt = prompt
llm_response.selected_expressions = selected_expressions
@ -216,30 +221,6 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
async def build_relation_info(self, chat_content: str, sender: str, person_list: List[Person]):
if not global_config.relationship.enable_relationship:
return ""
if not sender:
return ""
if sender == global_config.bot.nickname:
return ""
# 获取用户ID
person = Person(person_name=sender)
if not is_person_known(person_name=sender):
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
sender_relation = await person.build_relationship(chat_content)
others_relation = ""
for person in person_list:
person_relation = await person.build_relationship()
others_relation += person_relation
return f"{sender_relation}\n{others_relation}"
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@ -257,8 +238,8 @@ class DefaultReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)
@ -277,44 +258,42 @@ class DefaultReplyer:
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
)
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
expression_habits_block += f"{style_habits_str}\n"
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_mood_state_prompt(self) -> str:
    """Build the mood-state prompt fragment; empty string when the mood feature is disabled."""
    if global_config.mood.enable_mood:
        chat_mood = mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id)
        current_mood = await chat_mood.get_mood()
        return f"你现在的心情是:{current_mood}"
    return ""
async def build_memory_block(self) -> str:
"""构建记忆块
"""
# if not global_config.memory.enable_memory:
# return ""
# async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
# """构建记忆块
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
else:
return ""
# Args:
# chat_history: 聊天历史记录
# target: 目标消息内容
# Returns:
# str: 记忆信息字符串
# """
# if not global_config.memory.enable_memory:
# return ""
# instant_memory = None
# running_memories = await self.memory_activator.activate_memory_with_chat_history(
# target_message=target, chat_history=chat_history
# )
# if not running_memories:
# return ""
# memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
# for running_memory in running_memories:
# keywords, content = running_memory
# memory_str += f"- {keywords}{content}\n"
# if instant_memory:
# memory_str += f"- {instant_memory}\n"
# return memory_str
async def build_question_block(self) -> str:
    """Build the open-questions prompt block for the current chat stream.

    Returns:
        str: a header plus a bulleted list of pending questions from the
        conflict tracker, or "" when there are none.
    """
    # if not global_config.question.enable_question:
    #     return ""
    questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
    # Join instead of += in a loop — single pass, no quadratic concatenation.
    questions_str = "".join(f"- {question.question}\n" for question in questions)
    if not questions_str:
        return ""
    return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -344,7 +323,7 @@ class DefaultReplyer:
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
tool_info_str += f"- 【{tool_name}: {content}\n"
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
@ -380,6 +359,64 @@ class DefaultReplyer:
target = parts[1].strip()
return sender, target
def _replace_picids_with_descriptions(self, text: str) -> str:
    """Replace every ``[picid:xxx]`` marker in *text* with ``[图片:<description>]``.

    Args:
        text: text possibly containing picid markers

    Returns:
        The text with each marker substituted by its image description.
    """
    return re.sub(
        r"\[picid:([^\]]+)\]",
        lambda match: f"[图片:{translate_pid_to_description(match.group(1))}]",
        text,
    )
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
    """Classify a raw target message that may contain ``[picid:xxx]`` markers.

    Args:
        target: the raw message content (original picid format)

    Returns:
        Tuple[bool, bool, str, str]: (only pictures?, has text?, picture part
        rendered as "[图片:描述]", text part with all markers removed)
    """
    if not target or not target.strip():
        return False, False, "", ""

    # Capture the ids directly instead of slicing "[picid:...]" by hand.
    marker_re = re.compile(r"\[picid:([^\]]+)\]")
    pic_ids = marker_re.findall(target)
    text_part = marker_re.sub("", target).strip()

    has_text = bool(text_part)
    only_pictures = bool(pic_ids) and not text_part

    rendered = []
    for pic_id in pic_ids:
        description = translate_pid_to_description(pic_id)
        logger.info(f"图片ID: {pic_id}, 描述: {description}")
        # A bare "[图片]" description is used as-is; otherwise wrap as [图片:描述].
        if description == "[图片]":
            rendered.append(description)
        else:
            rendered.append(f"[图片:{description}]")

    return only_pictures, has_text, "".join(rendered), text_part
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@ -438,11 +475,10 @@ class DefaultReplyer:
duration = end_time - start_time
return name, result, duration
def build_s4u_chat_history_prompts(
def build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
Args:
message_list_before_now: 历史消息列表
@ -498,7 +534,6 @@ class DefaultReplyer:
--------------------------------
"""
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
@ -516,51 +551,6 @@ class DefaultReplyer:
return core_dialogue_prompt, all_dialogue_prompt
def build_mai_think_context(
    self,
    chat_id: str,
    memory_block: str,
    relation_info: str,
    time_block: str,
    chat_target_1: str,
    chat_target_2: str,
    mood_prompt: str,
    identity_block: str,
    sender: str,
    target: str,
    chat_info: str,
) -> Any:
    """Populate and return the mai_think context for a chat.

    Args:
        chat_id: chat id used to look up the mai_think instance
        memory_block: memory block content
        relation_info: relationship info
        time_block: time block content
        chat_target_1: primary chat-target description
        chat_target_2: secondary chat-target description
        mood_prompt: mood prompt text
        identity_block: identity block content
        sender: sender name
        target: target message content
        chat_info: chat info text

    Returns:
        Any: the mai_think instance with its context attributes set
    """
    mai_think = mai_thinking_manager.get_mai_think(chat_id)
    context = {
        "memory_block": memory_block,
        "relation_info_block": relation_info,
        "time_block": time_block,
        "chat_target": chat_target_1,
        "chat_target_2": chat_target_2,
        "chat_info": chat_info,
        "mood_state": mood_prompt,
        "identity": identity_block,
        "sender": sender,
        "target": target,
    }
    for attr_name, value in context.items():
        setattr(mai_think, attr_name, value)
    return mai_think
async def build_actions_prompt(
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str:
@ -615,6 +605,7 @@ class DefaultReplyer:
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[str, List[int]]:
"""
构建回复器上下文
@ -649,23 +640,23 @@ class DefaultReplyer:
sender = person_name
target = reply_message.processed_plain_text
mood_prompt: str = ""
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
timestamp=reply_time_point,
limit=global_config.chat.max_context_size * 1,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
timestamp=reply_time_point,
limit=int(global_config.chat.max_context_size * 0.33),
)
@ -686,8 +677,8 @@ class DefaultReplyer:
if person.is_known:
person_list_short.append(person)
for person in person_list_short:
print(person.person_name)
# for person in person_list_short:
# print(person.person_name)
chat_talking_prompt_short = build_readable_messages(
message_list_before_short,
@ -702,16 +693,15 @@ class DefaultReplyer:
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
# self._time_and_run_task(
# self.build_relation_info(chat_talking_prompt_short, sender, person_list_short), "relation_info"
# ),
# self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task(self.build_memory_block(), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
self._time_and_run_task(self.build_question_block(), "question_block"),
)
# 任务名称中英文映射
@ -719,10 +709,13 @@ class DefaultReplyer:
"expression_habits": "选取表达方式",
"relation_info": "感受关系",
# "memory_block": "回忆",
"memory_block": "记忆",
"tool_info": "使用工具",
"prompt_info": "获取知识",
"actions_info": "动作信息",
"personality_prompt": "人格信息",
"mood_state_prompt": "情绪状态",
"question_block": "问题",
}
# 处理结果
@ -747,11 +740,14 @@ class DefaultReplyer:
selected_expressions: List[int]
# relation_info: str = results_dict["relation_info"]
# memory_block: str = results_dict["memory_block"]
memory_block: str = results_dict["memory_block"]
tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
personality_prompt: str = results_dict["personality_prompt"]
question_block: str = results_dict["question_block"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
mood_state_prompt: str = results_dict["mood_state_prompt"]
if extra_info:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
@ -763,63 +759,49 @@ class DefaultReplyer:
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
if sender:
if is_group_chat:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意"
)
else: # private chat
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意"
)
# 使用预先分析的内容类型结果
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
elif has_text:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意"
else:
# 其他情况(空内容等)
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
else:
reply_target_block = ""
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
message_list_before_now_long, user_id, sender
)
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
return await global_prompt_manager.format_prompt(
"replyer_self_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
# memory_block=memory_block,
# relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
target=target,
reason=reply_reason,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
), selected_expressions
else:
return await global_prompt_manager.format_prompt(
"replyer_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
# memory_block=memory_block,
# relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
sender_name=sender,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
), selected_expressions
return await global_prompt_manager.format_prompt(
"replyer_prompt",
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
memory_block=memory_block,
knowledge_prompt=prompt_info,
mood_state=mood_state_prompt,
# memory_block=memory_block,
# relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
sender_name=sender,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
question_block=question_block,
), selected_expressions
async def build_prompt_rewrite_context(
self,
@ -833,14 +815,12 @@ class DefaultReplyer:
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 添加情绪状态获取
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
else:
mood_prompt = ""
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
@ -858,7 +838,6 @@ class DefaultReplyer:
# 并行执行2个构建任务
(expression_habits_block, _), personality_prompt = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target),
# self.build_relation_info(chat_talking_prompt_half, sender, []),
self.build_personality_prompt(),
)
@ -871,18 +850,39 @@ class DefaultReplyer:
)
if sender and target:
# 使用预先分析的内容类型结果
if is_group_chat:
if sender:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else:
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else: # private chat
if sender:
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
@ -918,7 +918,6 @@ class DefaultReplyer:
reply_target_block=reply_target_block,
raw_reply=raw_reply,
reason=reason,
mood_state=mood_prompt, # 添加情绪状态参数
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
@ -966,13 +965,13 @@ class DefaultReplyer:
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
logger.debug(f"replyer生成内容: {content}")
logger.info(f"使用 {model_name} 生成回复内容: {content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):
@ -1059,6 +1058,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx)
break
return selected

View File

@ -6,7 +6,7 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
@ -24,10 +24,12 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
# from src.chat.memory_system.memory_activator import MemoryActivator
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.mood.mood_manager import mood_manager
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
@ -69,6 +71,7 @@ class PrivateReplyer:
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@ -253,8 +256,8 @@ class PrivateReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)
@ -280,38 +283,22 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
# async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
# """构建记忆块
async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示"""
if not global_config.mood.enable_mood:
return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
# Args:
# chat_history: 聊天历史记录
# target: 目标消息内容
# Returns:
# str: 记忆信息字符串
# """
# if not global_config.memory.enable_memory:
# return ""
# instant_memory = None
# running_memories = await self.memory_activator.activate_memory_with_chat_history(
# target_message=target, chat_history=chat_history
# )
# if not running_memories:
# return ""
# memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
# for running_memory in running_memories:
# keywords, content = running_memory
# memory_str += f"- {keywords}{content}\n"
# if instant_memory:
# memory_str += f"- {instant_memory}\n"
# return memory_str
async def build_memory_block(self) -> str:
"""构建记忆块
"""
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
else:
return ""
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -376,6 +363,64 @@ class PrivateReplyer:
target = parts[1].strip()
return sender, target
def _replace_picids_with_descriptions(self, text: str) -> str:
"""将文本中的[picid:xxx]替换为具体的图片描述
Args:
text: 包含picid标记的文本
Returns:
替换后的文本
"""
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
pic_id = match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
return re.sub(pic_pattern, replace_pic_id, text)
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
"""分析target内容类型基于原始picid格式
Args:
target: 目标消息内容包含[picid:xxx]格式
Returns:
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
"""
if not target or not target.strip():
return False, False, "", ""
# 检查是否只包含picid标记
picid_pattern = r"\[picid:[^\]]+\]"
picid_matches = re.findall(picid_pattern, target)
# 移除所有picid标记后检查是否还有文字内容
text_without_picids = re.sub(picid_pattern, "", target).strip()
has_only_pics = len(picid_matches) > 0 and not text_without_picids
has_text = bool(text_without_picids)
# 提取图片部分(转换为[图片:描述]格式)
pic_part = ""
if picid_matches:
pic_descriptions = []
for picid_match in picid_matches:
pic_id = picid_match[7:-1] # 提取picid:xxx中的xxx部分从第7个字符开始
description = translate_pid_to_description(pic_id)
logger.debug(f"图片ID: {pic_id}, 描述: {description}")
# 如果description已经是[图片]格式,直接使用;否则包装为[图片:描述]格式
if description == "[图片]":
pic_descriptions.append(description)
else:
pic_descriptions.append(f"[图片:{description}]")
pic_part = "".join(pic_descriptions)
return has_only_pics, has_text, pic_part, text_without_picids
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@ -521,13 +566,15 @@ class PrivateReplyer:
sender = person_name
target = reply_message.processed_plain_text
mood_prompt: str = ""
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
@ -566,8 +613,8 @@ class PrivateReplyer:
if person.is_known:
person_list_short.append(person)
for person in person_list_short:
print(person.person_name)
# for person in person_list_short:
# print(person.person_name)
chat_talking_prompt_short = build_readable_messages(
message_list_before_short,
@ -585,6 +632,7 @@ class PrivateReplyer:
self._time_and_run_task(
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
),
self._time_and_run_task(self.build_memory_block(), "memory_block"),
# self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@ -592,17 +640,19 @@ class PrivateReplyer:
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
)
# 任务名称中英文映射
task_name_mapping = {
"expression_habits": "选取表达方式",
"relation_info": "感受关系",
# "memory_block": "回忆",
"memory_block": "回忆",
"tool_info": "使用工具",
"prompt_info": "获取知识",
"actions_info": "动作信息",
"personality_prompt": "人格信息",
"mood_state_prompt": "情绪状态",
}
# 处理结果
@ -626,11 +676,12 @@ class PrivateReplyer:
expression_habits_block: str
selected_expressions: List[int]
relation_info: str = results_dict["relation_info"]
# memory_block: str = results_dict["memory_block"]
memory_block: str = results_dict["memory_block"]
tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
personality_prompt: str = results_dict["personality_prompt"]
mood_state_prompt: str = results_dict["mood_state_prompt"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
if extra_info:
@ -642,9 +693,19 @@ class PrivateReplyer:
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
reply_target_block = (
f"现在对方说的:{target}。引起了你的注意"
)
# 使用预先分析的内容类型结果
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在对方发送的图片:{pic_part}。引起了你的注意"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在对方发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
elif has_text:
# 只包含文字
reply_target_block = f"现在对方说的:{text_part}。引起了你的注意"
else:
# 其他情况(空内容等)
reply_target_block = f"现在对方说的:{target}。引起了你的注意"
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
return await global_prompt_manager.format_prompt(
@ -652,12 +713,12 @@ class PrivateReplyer:
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
# memory_block=memory_block,
mood_state=mood_state_prompt,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
mood_state=mood_prompt,
dialogue_prompt=dialogue_prompt,
time_block=time_block,
target=target,
@ -673,12 +734,12 @@ class PrivateReplyer:
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
# memory_block=memory_block,
mood_state=mood_state_prompt,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
action_descriptions=actions_info,
mood_state=mood_prompt,
dialogue_prompt=dialogue_prompt,
time_block=time_block,
reply_target_block=reply_target_block,
@ -700,14 +761,14 @@ class PrivateReplyer:
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
# 在picid替换之前分析内容类型防止prompt注入
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
# 添加情绪状态获取
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
else:
mood_prompt = ""
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
@ -738,18 +799,39 @@ class PrivateReplyer:
)
if sender and target:
# 使用预先分析的内容类型结果
if is_group_chat:
if sender:
reply_target_block = (
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
else:
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else: # private chat
if sender:
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
@ -785,7 +867,6 @@ class PrivateReplyer:
reply_target_block=reply_target_block,
raw_reply=raw_reply,
reason=reason,
mood_state=mood_prompt, # 添加情绪状态参数
reply_style=global_config.personality.reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
@ -839,7 +920,7 @@ class PrivateReplyer:
prompt
)
logger.debug(f"replyer生成内容: {content}")
logger.info(f"使用 {model_name} 生成回复内容: {content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):

View File

@ -1,7 +1,5 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
# from src.memory_system.memory_activator import MemoryActivator
def init_lpmm_prompt():
@ -20,5 +18,3 @@ If you need to use the search tool, please directly call the function "lpmm_sear
""",
name="lpmm_get_knowledge_prompt",
)

View File

@ -13,7 +13,7 @@ def init_replyer_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}
{expression_habits_block}{memory_block}{question_block}
你正在qq群里聊天下面是群里正在聊的内容:
{time_block}
@ -22,40 +22,19 @@ def init_replyer_prompt():
{reply_target_block}
{identity}
你正在群里聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些
你正在群里聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些{mood_state}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""",
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出一句回复内容就好
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
现在你说""",
"replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}
你正在qq群里聊天下面是群里正在聊的内容:
{time_block}
{background_dialogue_prompt}
你现在想补充说明你刚刚自己的发言内容{target}原因是{reason}
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复注意保持上下文的连贯性
{identity}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
""",
"replyer_self_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}
{expression_habits_block}{memory_block}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
@ -63,7 +42,7 @@ def init_replyer_prompt():
{reply_target_block}
{identity}
你正在和{sender_name}聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些
你正在和{sender_name}聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些{mood_state}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
@ -74,19 +53,19 @@ def init_replyer_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}
{expression_habits_block}{memory_block}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
你现在想补充说明你刚刚自己的发言内容{target}原因是{reason}
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复注意保持上下文的连贯性
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复注意保持上下文的连贯性{mood_state}
{identity}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )
""",
"private_replyer_self_prompt",
)

View File

@ -1,7 +1,5 @@
from src.chat.utils.prompt_builder import Prompt
# from src.chat.memory_system.memory_activator import MemoryActivator
# from src.memory_system.memory_activator import MemoryActivator
def init_rewrite_prompt():
@ -14,13 +12,11 @@ def init_rewrite_prompt():
"""
{expression_habits_block}
{chat_target}
{time_block}
{chat_info}
{identity}
你现在的心情是{mood_state}
你正在{chat_target_2},{reply_target_block}
你想要对上述的发言进行回复回复的具体内容原句{raw_reply}
现在请你对这句内容进行改写请你参考上述内容进行改写原句{raw_reply}
原因是{reason}
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息
你需要使用合适的语法和句法参考聊天内容组织一条日常且口语化的回复请你修改你想表达的原句符合你的表达风格和语言习惯
@ -28,8 +24,8 @@ def init_rewrite_prompt():
你可以完全重组回复保留最基本的表达含义就好但重组后保持语意通顺
{keywords_reaction_prompt}
{moderation_prompt}
不要输出多余内容(包括前后缀冒号和引表情包emoji,at或 @等 )只输出一条回复就好
现在你说
不要输出多余内容(包括冒号和引号表情包emoji,at或 @等 )只输出一条回复就好
改写后的回复
""",
"default_expressor_prompt",
)
)

View File

@ -2,7 +2,7 @@ import time
import random
import re
from typing import List, Dict, Any, Tuple, Optional, Callable
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
from rich.traceback import install
from src.config.config import global_config
@ -124,6 +124,7 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
# print(f"get_raw_msg_by_timestamp_with_chat: {chat_id}, {timestamp_start}, {timestamp_end}, {limit}, {limit_mode}, {filter_bot}, {filter_command}")
return find_messages(
message_filter=filter_query,
sort=sort_order,
@ -215,6 +216,7 @@ def get_actions_by_timestamp_with_chat(
chat_id=action.chat_id,
chat_info_stream_id=action.chat_info_stream_id,
chat_info_platform=action.chat_info_platform,
action_reasoning=action.action_reasoning,
)
for action in actions
]
@ -417,12 +419,6 @@ def _build_readable_messages_internal(
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
# 向下兼容
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
if show_pic:
content = process_pic_ids(content)
@ -564,14 +560,12 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
output_lines = []
current_time = time.time()
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
# sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True)
for action in actions:
action_time = action.time or current_time
action_name = action.action_name or "未知动作"
# action_reason = action.get(action_data")
if action_name in ["no_action", "no_action"]:
if action_name in ["no_reply", "no_reply"]:
continue
action_prompt_display = action.action_prompt_display or "无具体内容"
@ -593,6 +587,7 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}"
output_lines.append(line)
return "\n".join(output_lines)
@ -628,6 +623,7 @@ def build_readable_messages_with_id(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
remove_emoji_stickers: bool = False,
) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
"""
将消息列表转换为可读的文本格式并返回原始(时间戳, 昵称, 内容)列表
@ -644,6 +640,7 @@ def build_readable_messages_with_id(
show_pic=show_pic,
read_mark=read_mark,
message_id_list=message_id_list,
remove_emoji_stickers=remove_emoji_stickers,
)
return formatted_string, message_id_list
@ -658,6 +655,7 @@ def build_readable_messages(
show_actions: bool = False,
show_pic: bool = True,
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
remove_emoji_stickers: bool = False,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式
@ -672,13 +670,40 @@ def build_readable_messages(
read_mark: 已读标记时间戳
truncate: 是否截断长消息
show_actions: 是否显示动作记录
remove_emoji_stickers: 是否移除表情包并过滤空消息
"""
# WIP HERE and BELOW ----------------------------------------------
# 创建messages的深拷贝避免修改原始列表
if not messages:
return ""
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
# 如果启用移除表情包,先过滤消息
if remove_emoji_stickers:
filtered_messages = []
for msg in messages:
# 获取消息内容
content = msg.processed_plain_text
# 移除表情包
emoji_pattern = r"\[表情包:[^\]]+\]"
content = re.sub(emoji_pattern, "", content)
# 如果移除表情包后内容不为空,则保留消息
if content.strip():
filtered_messages.append(msg)
messages = filtered_messages
copy_messages: List[MessageAndActionModel] = []
for msg in messages:
if remove_emoji_stickers:
# 创建 MessageAndActionModel 但移除表情包
model = MessageAndActionModel.from_DatabaseMessages(msg)
# 移除表情包
if model.processed_plain_text:
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
copy_messages.append(model)
else:
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
if show_actions and copy_messages:
# 获取所有消息的时间范围
@ -862,17 +887,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
content = process_pic_ids(content)
# if not all([platform, user_id, timestamp is not None]):
# continue
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
@ -909,6 +926,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)
@ -937,3 +955,45 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
"""
构建简化版消息字符串只包含processed_plain_text内容不考虑用户名和时间戳
Args:
messages: 消息列表
Returns:
只包含消息内容的字符串
"""
if not messages:
return ""
output_lines = []
for msg in messages:
# 获取纯文本内容
content = msg.processed_plain_text or ""
# 处理图片ID
pic_pattern = r"\[picid:[^\]]+\]"
def replace_pic_id(match):
return "[图片]"
content = re.sub(pic_pattern, replace_pic_id, content)
# 处理用户引用格式,移除回复和@标记
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
content = re.sub(reply_pattern, "回复[某人]", content)
at_pattern = r"@<[^:<>]+:[^:<>]+>"
content = re.sub(at_pattern, "@[某人]", content)
# 清理并添加到输出
content = content.strip()
if content:
output_lines.append(content)
return "\n".join(output_lines)

View File

@ -151,7 +151,7 @@ class Prompt(str):
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \{\} 替换为临时标记""" # type: ignore
"""处理模板中的转义花括号,替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)

View File

@ -383,10 +383,6 @@ def calculate_typing_time(
- 在所有输入结束后额外加上回车时间0.3
- 如果is_emoji为True将使用固定1秒的输入时间
"""
# # 将0-1的唤醒度映射到-1到1
# mood_arousal = mood_manager.current_mood.arousal
# # 映射到0.5到2倍的速度系数
# typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数

View File

@ -623,3 +623,41 @@ def image_path_to_base64(image_path: str) -> str:
return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}")
def base64_to_image(image_base64: str, output_path: str) -> bool:
"""将base64编码的图片保存为文件
Args:
image_base64: 图片的base64编码
output_path: 输出文件路径
Returns:
bool: 是否成功保存
Raises:
ValueError: 当base64编码无效时
IOError: 当保存文件失败时
"""
try:
# 确保base64字符串只包含ASCII字符
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
# 解码base64
image_bytes = base64.b64decode(image_base64)
# 确保输出目录存在
output_dir = os.path.dirname(output_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# 保存文件
with open(output_path, "wb") as f:
f.write(image_bytes)
return True
except Exception as e:
logger.error(f"保存base64图片失败: {e}")
return False

View File

@ -220,6 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
chat_id: str,
chat_info_stream_id: str,
chat_info_platform: str,
action_reasoning:str
):
self.action_id = action_id
self.time = time
@ -234,3 +235,4 @@ class DatabaseActionRecords(BaseDataModel):
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.action_reasoning = action_reasoning

View File

@ -24,3 +24,4 @@ class ActionPlannerInfo(BaseDataModel):
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
loop_start_time: Optional[float] = None
action_reasoning: Optional[str] = None

View File

@ -16,4 +16,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional["ReplySetModel"] = None
reply_set: Optional["ReplySetModel"] = None

View File

@ -1,64 +1,9 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
install(extra_lines=3)
_client = None
_db = None
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串
if uri.startswith(("mongodb://", "mongodb+srv://")):
return MongoClient(uri)
else:
raise ValueError(
"Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
"For MongoDB Atlas, use 'mongodb+srv://' format. "
"See: https://www.mongodb.com/docs/manual/reference/connection-string/"
)
if username and password:
# 如果有用户名和密码,使用认证连接
return MongoClient(host, port, username=username, password=password, authSource=auth_source)
# 否则使用无认证连接
return MongoClient(host, port)
def get_db():
"""获取数据库连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
return _db
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
def __getattr__(self, name):
return getattr(get_db(), name)
def __getitem__(self, key):
return get_db()[key] # type: ignore
# 全局数据库访问点
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))

View File

@ -185,6 +185,8 @@ class ActionRecords(BaseModel):
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
action_reasoning = TextField(null=True)
action_name = TextField()
action_data = TextField()
action_done = BooleanField(default=False)
@ -301,46 +303,47 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
count = FloatField()
# new mode fields
context = TextField(null=True)
up_content = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)
type = TextField()
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
class Meta:
table_name = "expression"
class GraphNodes(BaseModel):
class MemoryChest(BaseModel):
"""
用于存储记忆图节点的模型
用于存储记忆仓库的模型
"""
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
weight = FloatField(default=0.0) # 节点权重
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
title = TextField() # 标题
content = TextField() # 内容
chat_id = TextField(null=True) # 聊天ID
locked = BooleanField(default=False) # 是否锁定
class Meta:
table_name = "graph_nodes"
table_name = "memory_chest"
class GraphEdges(BaseModel):
class MemoryConflict(BaseModel):
"""
用于存储记忆图边的模型
用于存储记忆整合过程中冲突内容的模型
"""
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
conflict_content = TextField() # 冲突内容
answer = TextField(null=True) # 回答内容
create_time = FloatField() # 创建时间
update_time = FloatField() # 更新时间
context = TextField(null=True) # 上下文
chat_id = TextField(null=True) # 聊天ID
raise_time = FloatField(null=True) # 触发次数
class Meta:
table_name = "graph_edges"
table_name = "memory_conflicts"
def create_tables():
@ -359,9 +362,9 @@ def create_tables():
OnlineTime,
PersonInfo,
Expression,
GraphNodes, # 添加图节点表
GraphEdges, # 添加图边表
ActionRecords, # 添加 ActionRecords 到初始化列表
MemoryChest,
MemoryConflict, # 添加记忆冲突表
]
)
@ -386,9 +389,9 @@ def initialize_database(sync_constraints=False):
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords, # 添加 ActionRecords 到初始化列表
MemoryChest,
MemoryConflict,
]
try:
@ -483,9 +486,9 @@ def sync_field_constraints():
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords,
MemoryChest,
MemoryConflict,
]
try:
@ -667,9 +670,9 @@ def check_field_constraints():
OnlineTime,
PersonInfo,
Expression,
GraphNodes,
GraphEdges,
ActionRecords,
MemoryChest,
MemoryConflict,
]
inconsistencies = {}
@ -725,11 +728,14 @@ def check_field_constraints():
logger.exception(f"检查字段约束时出错: {e}")
return inconsistencies
def fix_image_id():
"""
修复表情包的 image_id 字段
"""
import uuid
try:
with db:
for img in Images.select():
@ -740,6 +746,7 @@ def fix_image_id():
except Exception as e:
logger.exception(f"修复 image_id 时出错: {e}")
# 模块加载时调用初始化函数
initialize_database(sync_constraints=True)
fix_image_id()
fix_image_id()

View File

@ -363,8 +363,8 @@ MODULE_COLORS = {
"planner": "\033[36m",
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"hfc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"bc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"sub_heartflow": "\033[38;5;207m", # 粉紫色
"subheartflow_manager": "\033[38;5;201m", # 深粉色
"background_tasks": "\033[38;5;240m", # 灰色
@ -372,8 +372,6 @@ MODULE_COLORS = {
"chat_stream": "\033[38;5;51m", # 亮青色
"message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块
"memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统
"plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色
@ -408,7 +406,7 @@ MODULE_COLORS = {
"tts_action": "\033[38;5;58m", # 深黄色
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
# Action组件
"no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"reply_action": "\033[38;5;46m", # 亮绿色
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
@ -421,9 +419,7 @@ MODULE_COLORS = {
"model_utils": "\033[38;5;164m", # 紫红色
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
"relationship_builder": "\033[38;5;93m", # 浅蓝色
# s4u
"context_web_api": "\033[38;5;240m", # 深灰色
"S4U_chat": "\033[92m", # 深灰色
"conflict_tracker": "\033[38;5;82m", # 柔和的粉色,不显眼但保持粉色系
}
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名

View File

@ -81,7 +81,8 @@ def find_messages(
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not Messages.is_command)
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
query = query.where(~Messages.is_command)
if limit > 0:
if limit_mode == "earliest":

View File

@ -18,7 +18,6 @@ from src.config.official_configs import (
ExpressionConfig,
ChatConfig,
EmojiConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
@ -31,6 +30,8 @@ from src.config.official_configs import (
RelationshipConfig,
ToolConfig,
VoiceConfig,
MoodConfig,
MemoryConfig,
DebugConfig,
)
@ -54,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.3-snapshot.4"
MMC_VERSION = "0.11.0-snapshot.3"
def get_key_comment(toml_table, key):
@ -173,13 +174,8 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
_update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
@ -345,7 +341,6 @@ class Config(ConfigBase):
message_receive: MessageReceiveConfig
emoji: EmojiConfig
expression: ExpressionConfig
mood: MoodConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_post_process: ResponsePostProcessConfig
@ -355,7 +350,9 @@ class Config(ConfigBase):
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
memory: MemoryConfig
debug: DebugConfig
mood: MoodConfig
voice: VoiceConfig

View File

@ -2,6 +2,7 @@ import re
from dataclasses import dataclass, field
from typing import Literal, Optional
import time
from src.config.config_base import ConfigBase
@ -38,21 +39,18 @@ class PersonalityConfig(ConfigBase):
personality: str
"""人格"""
emotion_style: str
"""情感特征"""
reply_style: str = ""
"""表达风格"""
interest: str = ""
"""兴趣"""
plan_style: str = ""
"""说话规则,行为风格"""
visual_style: str = ""
"""图片提示词"""
private_plan_style: str = ""
"""私聊说话规则,行为风格"""
@ -81,48 +79,212 @@ class ChatConfig(ConfigBase):
mentioned_bot_reply: bool = True
"""是否启用提及必回复"""
auto_chat_value: float = 1
"""自动聊天,越小,麦麦主动聊天的概率越低"""
at_bot_inevitable_reply: float = 1
"""@bot 必然回复1为100%回复0为不额外增幅"""
talk_frequency: float = 0.5
"""回复频率阈值"""
planner_smooth: float = 3
"""规划器平滑增大数值会减小planner负荷略微降低反应速度推荐2-50为关闭必须大于等于0"""
talk_value: float = 1
"""思考频率"""
# 合并后的时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多"""
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
talk_value_rules: list[dict] = field(default_factory=lambda: [])
"""
统一的活跃度和专注度配置
格式[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
全局配置示例
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
特定聊天流配置示例
思考频率规则列表支持按聊天流/按日内时段配置
规则格式{ target="platform:id:type" "", time="HH:MM-HH:MM", value=0.5 }
示例:
[
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
["", "09:00-22:59", 1.0], # 全局规则:白天正常
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
说明
- 当第一个元素为空字符串""表示全局默认配置
- 当第一个元素为"platform:id:type"格式时表示特定聊天流配置
- 后续元素是"时间,频率"格式表示从该时间开始使用该频率直到下一个时间点
- 优先级特定聊天流配置 > 全局配置 > 默认值
注意
- talk_frequency_adjust 控制回复频率数值越高回复越频繁
- focus_value_adjust 控制专注思考能力数值越低越容易专注消耗token也越多
匹配优先级: 先匹配指定 chat 流规则再匹配全局规则(\"\").
时间区间支持跨夜例如 "23:00-02:00"
"""
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
"""
自动聊天频率规则列表支持按聊天流/按日内时段配置
规则格式{ target="platform:id:type" "", time="HH:MM-HH:MM", value=0.5 }
示例:
[
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
["", "09:00-22:59", 1.0], # 全局规则:白天正常
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
]
匹配优先级: 先匹配指定 chat 流规则再匹配全局规则(\"\").
时间区间支持跨夜例如 "23:00-02:00"
"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
import hashlib
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except (ValueError, IndexError):
return None
def _now_minutes(self) -> int:
"""返回本地时间的分钟数(0-1439)。"""
lt = time.localtime()
return lt.tm_hour * 60 + lt.tm_min
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
try:
start_str, end_str = [s.strip() for s in range_str.split("-")]
sh, sm = [int(x) for x in start_str.split(":")]
eh, em = [int(x) for x in end_str.split(":")]
return sh * 60 + sm, eh * 60 + em
except Exception:
return None
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
"""
判断 now_min 是否在 [start_min, end_min] 区间内
支持跨夜如果 start > end则表示跨越午夜
"""
if start_min <= end_min:
return start_min <= now_min <= end_min
# 跨夜:例如 23:00-02:00
return now_min >= start_min or now_min <= end_min
def get_talk_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 talk_value未匹配则回退到基础值。"""
if not self.talk_value_rules:
return self.talk_value
now_min = self._now_minutes()
# 1) 先尝试匹配指定 chat 的规则
if chat_id:
for rule in self.talk_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", "")
time_range = rule.get("time", "")
value = rule.get("value", None)
if not isinstance(time_range, str):
continue
# 跳过全局
if target == "":
continue
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
if config_chat_id is None or config_chat_id != chat_id:
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 2) 再匹配全局规则("")
for rule in self.talk_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", None)
time_range = rule.get("time", "")
value = rule.get("value", None)
if target != "" or not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 3) 未命中规则返回基础值
return self.talk_value
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 auto_chat_value未匹配则回退到基础值。"""
if not self.auto_chat_value_rules:
return self.auto_chat_value
now_min = self._now_minutes()
# 1) 先尝试匹配指定 chat 的规则
if chat_id:
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", "")
time_range = rule.get("time", "")
value = rule.get("value", None)
if not isinstance(time_range, str):
continue
# 跳过全局
if target == "":
continue
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
if config_chat_id is None or config_chat_id != chat_id:
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 2) 再匹配全局规则("")
for rule in self.auto_chat_value_rules:
if not isinstance(rule, dict):
continue
target = rule.get("target", None)
time_range = rule.get("time", "")
value = rule.get("value", None)
if target != "" or not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
except Exception:
continue
# 3) 未命中规则返回基础值
return self.auto_chat_value
@dataclass
class MessageReceiveConfig(ConfigBase):
@ -134,11 +296,23 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class MemoryConfig(ConfigBase):
    """记忆配置类"""

    # NOTE(review): the Chinese bare-string field "docstrings" below may be
    # harvested by the config tooling (see get_key_comment in config.py) when
    # generating TOML comments — left untranslated deliberately; confirm.

    # Upper bound on the number of stored memories.
    max_memory_number: int = 100
    """记忆最大数量"""

    # Relative frequency multiplier for memory building runs.
    memory_build_frequency: int = 1
    """记忆构建频率"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: str = "classic"
"""表达方式模式可选classic经典模式exp_model 表达模型模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""
表达学习配置列表支持按聊天流配置
@ -299,6 +473,19 @@ class ToolConfig(ConfigBase):
"""是否在聊天中启用工具"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
enable_mood: bool = True
"""是否启用情绪系统"""
mood_update_threshold: float = 1
"""情绪更新阈值,越高,更新越慢"""
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@ -333,17 +520,6 @@ class EmojiConfig(ConfigBase):
"""表情包过滤要求"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
enable_mood: bool = False
"""是否启用情绪系统"""
mood_update_threshold: float = 1.0
"""情绪更新阈值,越高,更新越慢"""
@dataclass
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""

View File

@ -0,0 +1,580 @@
import time
import json
import os
import re
from datetime import datetime
from typing import List, Optional, Tuple
import traceback
import difflib
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager
from json_repair import repair_json
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def calculate_similarity(text1: str, text2: str) -> float:
    """Return a similarity ratio in [0, 1] between two strings.

    Thin wrapper over :class:`difflib.SequenceMatcher` (1.0 means identical).
    """
    matcher = difflib.SequenceMatcher(a=text1, b=text2)
    return matcher.ratio()
def format_create_date(timestamp: float) -> str:
    """Format a Unix timestamp as "YYYY-MM-DD HH:MM:SS" local time.

    Returns the placeholder "未知时间" when the timestamp cannot be converted.
    """
    try:
        moment = datetime.fromtimestamp(timestamp)
    except (ValueError, OSError):
        return "未知时间"
    return moment.strftime("%Y-%m-%d %H:%M:%S")
def init_prompt() -> None:
    """Register the expression-learning prompt templates with the prompt manager.

    Registers two templates:
    - "learn_style_prompt": asks the LLM to summarize the language styles of
      everyone except "SELF" from a chat log.
    - "match_expression_context_prompt": asks the LLM to map each learned
      expression pair back to its original sentence in the chat log.
    """
    # NOTE: prompt text below is runtime data sent to the LLM — do not edit casually.
    learn_style_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字不要考虑表情包和图片
2. 不要涉及具体的人名但是可以涉及具体名词
3. 思考有没有特殊的梗一并总结成语言风格
4. 例子仅供参考请严格根据群聊内容总结!!!
注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
例如
"对某件事表示十分惊叹"使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"使用"对对对"
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
"当涉及游戏相关时,夸赞,略带戏谑意味"使用"这么强!"
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性
现在请你概括
"""
    Prompt(learn_style_prompt, "learn_style_prompt")

    match_expression_context_prompt = """
**聊天内容**
{chat_str}
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context
如果找不到原句就不输出该句的匹配结果
以json格式输出
格式如下
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
...
现在请你输出匹配结果
"""
    Prompt(match_expression_context_prompt, "match_expression_context_prompt")
class ExpressionLearner:
    """Learns recurring expression styles ("when X, say Y") for one chat stream.

    One instance per chat. A learning round samples messages received since the
    previous round, asks an LLM to summarize language styles, traces each style
    back to the original sentence and its preceding message, stores the result
    in the ``Expression`` table and trains the per-chat style_learner model.
    """

    def __init__(self, chat_id: str) -> None:
        # LLM used for both style summarization and context matching.
        self.express_learn_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="expression.learner"
        )
        # Embedding model handle; not referenced elsewhere in this class.
        self.embedding_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
        )
        self.chat_id = chat_id
        self.chat_stream = get_chat_manager().get_stream(chat_id)
        self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
        # Timestamp of this chat's last learning round.
        self.last_learning_time: float = time.time()
        # Learning parameters from config; a higher intensity lowers both
        # trigger thresholds below.
        _, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
            self.chat_id
        )
        self.min_messages_for_learning = 30 / self.learning_intensity  # messages required to trigger learning
        self.min_learning_interval = 300 / self.learning_intensity  # seconds between learning rounds

    def should_trigger_learning(self) -> bool:
        """Return True when this chat has accumulated enough fresh activity to learn.

        Requires learning to be enabled, at least ``min_learning_interval``
        seconds since the last round, and at least ``min_messages_for_learning``
        new messages in this chat since then.
        """
        # Learning disabled for this chat.
        if not self.enable_learning:
            return False
        # Enforce the minimum interval between rounds.
        time_diff = time.time() - self.last_learning_time
        if time_diff < self.min_learning_interval:
            return False
        # Count only this chat's messages since the last round.
        recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_learning_time,
            timestamp_end=time.time(),
        )
        if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
            return False
        return True

    async def trigger_learning_for_chat(self):
        """Run one learning round for this chat if the trigger conditions hold.

        Errors are logged and swallowed so a failed round never breaks the
        caller's loop; ``last_learning_time`` advances on every attempt.
        """
        if not self.should_trigger_learning():
            return
        try:
            logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
            # Learn language styles from up to 25 recent messages.
            learnt_style = await self.learn_and_store(num=25)
            # Mark the round done regardless of how many styles were found.
            self.last_learning_time = time.time()
            if learnt_style:
                logger.info(f"聊天流 {self.chat_name} 表达学习完成")
            else:
                logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
        except Exception as e:
            logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
            traceback.print_exc()
        return

    async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str, str]]:
        """Learn expressions, persist them, and train the style model.

        Returns the learned (situation, style, context, up_content) tuples, or
        [] when nothing was learned. (Annotation corrected: elements are
        4-tuples, as produced by :meth:`learn_expression`.)
        """
        learnt_expressions = await self.learn_expression(num)
        if learnt_expressions is None:
            logger.info("没有学习到表达风格")
            return []
        # Log a readable summary of what was learned.
        learnt_expressions_str = ""
        for (
            situation,
            style,
            _context,
            _up_content,
        ) in learnt_expressions:
            learnt_expressions_str += f"{situation}->{style}\n"
        logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
        current_time = time.time()
        # Persist to the Expression table and train the per-chat style_learner.
        has_new_expressions = False  # whether any previously-unseen expression was stored
        learner = style_learner_manager.get_learner(self.chat_id)  # per-chat model instance
        for (
            situation,
            style,
            context,
            up_content,
        ) in learnt_expressions:
            # Deduplicate: an identical (chat, situation, style) row only gets
            # its last_active_time refreshed — and skips model training below.
            query = Expression.select().where(
                (Expression.chat_id == self.chat_id)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            if query.exists():
                expr_obj = query.get()
                expr_obj.last_active_time = current_time
                expr_obj.save()
                continue
            else:
                Expression.create(
                    situation=situation,
                    style=style,
                    last_active_time=current_time,
                    chat_id=self.chat_id,
                    create_date=current_time,  # set creation date explicitly
                    context=context,
                    up_content=up_content,
                )
                has_new_expressions = True
            # Train the style model (up_content and style are always present here).
            try:
                learner.add_style(style, situation)
                # Learn the preceding-message -> style mapping.
                success = style_learner_manager.learn_mapping(
                    self.chat_id,
                    up_content,
                    style
                )
                if success:
                    logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
                else:
                    logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
            except Exception as e:
                logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
        # Save this chat's model once per round if anything new was added.
        if has_new_expressions:
            try:
                logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
                save_success = learner.save(style_learner_manager.model_save_path)
                if save_success:
                    logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
                else:
                    logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
            except Exception as e:
                logger.error(f"StyleLearner 模型保存异常: {e}")
        return learnt_expressions

    async def match_expression_context(
        self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
    ) -> List[Tuple[str, str, str]]:
        """Trace each (situation, style) pair back to its source sentence.

        Numbers the pairs, asks the LLM to match each one to the original
        sentence in *random_msg_match_str*, and returns unique
        (situation, style, context) tuples for the pairs that were matched.
        """
        # Number each pair so the LLM can reference them by index.
        numbered_pairs = []
        for i, (situation, style) in enumerate(expression_pairs, 1):
            numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
        expression_pairs_str = "\n".join(numbered_pairs)
        prompt = "match_expression_context_prompt"
        prompt = await global_prompt_manager.format_prompt(
            prompt,
            expression_pairs=expression_pairs_str,
            chat_str=random_msg_match_str,
        )
        response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
        # Parse the JSON response. The model may emit a proper array, a bare
        # object, several comma-separated objects, or malformed JSON, so fall
        # back progressively to repair_json.
        match_responses = []
        try:
            response = response.strip()
            # Already a standard JSON array?
            if response.startswith("[") and response.endswith("]"):
                match_responses = json.loads(response)
            else:
                # Try to parse object-style output directly.
                try:
                    # Comma-separated objects: wrap them into an array first.
                    if response.startswith("{") and not response.startswith("["):
                        response = "[" + response + "]"
                        match_responses = json.loads(response)
                    else:
                        # Last resort: let repair_json fix the payload.
                        repaired_content = repair_json(response)
                        # repaired_content may be a string, dict, or list —
                        # normalize everything to a list of dicts.
                        if isinstance(repaired_content, str):
                            try:
                                parsed_data = json.loads(repaired_content)
                                if isinstance(parsed_data, dict):
                                    # Single object: wrap into a list.
                                    match_responses = [parsed_data]
                                elif isinstance(parsed_data, list):
                                    match_responses = parsed_data
                                else:
                                    match_responses = []
                            except json.JSONDecodeError:
                                match_responses = []
                        elif isinstance(repaired_content, dict):
                            # Single object: wrap into a list.
                            match_responses = [repaired_content]
                        elif isinstance(repaired_content, list):
                            match_responses = repaired_content
                        else:
                            match_responses = []
                except json.JSONDecodeError:
                    # Direct parse failed — try repair_json on the raw text.
                    repaired_content = repair_json(response)
                    if isinstance(repaired_content, str):
                        parsed_data = json.loads(repaired_content)
                        match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
                    else:
                        match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
        except (json.JSONDecodeError, Exception) as e:
            logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
            return []
        # Ensure match_responses is a list before iterating.
        if not isinstance(match_responses, list):
            if isinstance(match_responses, dict):
                match_responses = [match_responses]
            else:
                logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
                return []
        matched_expressions = []
        used_pair_indices = set()  # indices of expression_pairs already consumed
        logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
        logger.debug(f"match_responses 内容: {match_responses}")
        for match_response in match_responses:
            try:
                # Each entry must be a dict carrying the pair index.
                if not isinstance(match_response, dict):
                    logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
                    continue
                if "expression_pair" not in match_response:
                    logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
                    continue
                pair_index = int(match_response["expression_pair"]) - 1  # convert to 0-based index
                # Accept only valid, not-yet-used indices (de-duplicates pairs).
                if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
                    situation, style = expression_pairs[pair_index]
                    context = match_response.get("context", "")
                    matched_expressions.append((situation, style, context))
                    used_pair_indices.add(pair_index)  # mark index as consumed
                    logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
                elif pair_index in used_pair_indices:
                    logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
            except (ValueError, KeyError, IndexError, TypeError) as e:
                logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
                continue
        return matched_expressions

    async def learn_expression(
        self, num: int = 10
    ) -> Optional[List[Tuple[str, str, str, str]]]:
        """Learn expression styles from messages since the last learning round.

        Args:
            num: maximum number of messages to sample.

        Returns:
            (situation, style, context, up_content) tuples, or None when there
            is nothing to learn or every candidate failed source tracing.
        """
        current_time = time.time()
        # Messages received since the previous round.
        random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_learning_time,
            timestamp_end=current_time,
            limit=num,
        )
        if not random_msg or random_msg == []:
            return None
        # Anonymized rendering used for learning.
        random_msg_str: str = await build_anonymous_messages(random_msg)
        # Bare rendering used for tracing styles back to source sentences.
        random_msg_match_str: str = await build_bare_messages(random_msg)
        prompt: str = await global_prompt_manager.format_prompt(
            "learn_style_prompt",
            chat_str=random_msg_str,
        )
        try:
            response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
        except Exception as e:
            logger.error(f"学习表达方式失败,模型生成出错: {e}")
            return None
        expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
        # Trace each learned pair back to its original sentence.
        matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
            expressions, random_msg_match_str
        )
        # Build stripped per-message lines, keeping original index mapping.
        bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
        # Attach each expression's preceding message (up_content); drop pairs
        # whose source sentence or preceding message cannot be determined.
        filtered_with_up: List[Tuple[str, str, str, str]] = []  # (situation, style, context, up_content)
        for situation, style, context in matched_expressions:
            # Find the first line at least 85% similar to the matched context.
            pos = None
            for i, (_, c) in enumerate(bare_lines):
                similarity = calculate_similarity(c, context)
                if similarity >= 0.85:  # 85% similarity threshold
                    pos = i
                    break
            if pos is None or pos == 0:
                # No source line found, or it is the first message and has no
                # preceding message: drop this expression.
                continue
            # Drop when the matched source line is empty.
            target_content = bare_lines[pos][1]
            if not target_content:
                continue
            prev_original_idx = bare_lines[pos - 1][0]
            up_content = self._filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
            if not up_content:
                # Preceding message is empty after filtering: drop.
                continue
            filtered_with_up.append((situation, style, context, up_content))
        if not filtered_with_up:
            return None
        return filtered_with_up

    def parse_expression_response(self, response: str) -> List[Tuple[str, str]]:
        """Parse the LLM style summary into (situation, style) pairs.

        Each line is expected to contain a quoted situation followed by
        使用"style"; lines missing either quoted segment are skipped.
        (Annotation corrected: elements are 2-tuples.)
        """
        expressions: List[Tuple[str, str]] = []
        for line in response.splitlines():
            line = line.strip()
            if not line:
                continue
            # Locate the first quoted segment (the situation).
            idx_when = line.find('"')
            if idx_when == -1:
                continue
            idx_quote1 = idx_when + 1
            idx_quote2 = line.find('"', idx_quote1 + 1)
            if idx_quote2 == -1:
                continue
            situation = line[idx_quote1 + 1 : idx_quote2]
            # Locate 使用"..." after the situation (the style).
            idx_use = line.find('使用"', idx_quote2)
            if idx_use == -1:
                continue
            idx_quote3 = idx_use + 2
            idx_quote4 = line.find('"', idx_quote3 + 1)
            if idx_quote4 == -1:
                continue
            style = line[idx_quote3 + 1 : idx_quote4]
            expressions.append((situation, style))
        return expressions

    def _filter_message_content(self, content: str) -> str:
        """Strip reply/@/image/sticker markup from a message, returning plain text.

        Args:
            content: raw ``processed_plain_text`` of a message.

        Returns:
            The content with structural markers removed and whitespace trimmed.
        """
        if not content:
            return ""
        # Remove "[回复...],说:" reply prefixes.
        content = re.sub(r'\[回复.*?\],说:\s*', '', content)
        # Remove @<...> mentions.
        content = re.sub(r'@<[^>]*>', '', content)
        # Remove [picid:...] image ids.
        content = re.sub(r'\[picid:[^\]]*\]', '', content)
        # Remove [表情包:...] sticker markers.
        content = re.sub(r'\[表情包:[^\]]*\]', '', content)
        return content.strip()

    def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
        """Map each message to (original_index, filtered_text).

        Entries with empty text are kept so positions stay aligned with
        *messages* (index mapping must not shift).
        """
        bare_lines: List[Tuple[int, str]] = []
        for idx, msg in enumerate(messages):
            content = msg.processed_plain_text or ""
            content = self._filter_message_content(content)
            # Record even empty content to avoid positional drift.
            bare_lines.append((idx, content))
        return bare_lines
# Register prompt templates at import time, before any learner runs.
init_prompt()
class ExpressionLearnerManager:
    """Registry of per-chat :class:`ExpressionLearner` instances."""

    def __init__(self):
        # chat_id -> ExpressionLearner
        self.expression_learners = {}
        self._ensure_expression_directories()

    def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
        """Return the learner for *chat_id*, creating it on first request."""
        learner = self.expression_learners.get(chat_id)
        if learner is None:
            learner = ExpressionLearner(chat_id)
            self.expression_learners[chat_id] = learner
        return learner

    def _ensure_expression_directories(self):
        """Create the on-disk layout used to persist learned expressions."""
        base_dir = os.path.join("data", "expression")
        for directory in (
            base_dir,
            os.path.join(base_dir, "learnt_style"),
            os.path.join(base_dir, "learnt_grammar"),
        ):
            try:
                os.makedirs(directory, exist_ok=True)
                logger.debug(f"确保目录存在: {directory}")
            except Exception as e:
                logger.error(f"创建目录失败 {directory}: {e}")
# Module-level singleton shared by the rest of the chat pipeline.
expression_learner_manager = ExpressionLearnerManager()

View File

@ -0,0 +1,462 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-selection prompt template with the prompt manager.

    The "expression_evaluation_prompt" asks the LLM to pick, from a numbered
    list of candidate situations, the ones that best fit the current chat
    context, answering with a JSON object of selected indices.
    """
    # NOTE: prompt text below is runtime data sent to the LLM — do not edit casually.
    expression_evaluation_prompt = """
以下是正在进行的聊天内容
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境
{all_situations}
请你分析聊天内容的语境情绪话题类型从上述情境中选择最适合当前聊天情境的最多{max_num}个情境
考虑因素包括
1. 聊天的情绪氛围轻松严肃幽默等
2. 话题类型日常技术游戏情感等
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Uniformly sample up to *k* distinct items from *population*.

    Despite the historical name there is no weighting: every element has the
    same selection probability. Returns a shallow copy of the whole list when
    it has at most *k* elements, and [] for an empty population or k <= 0.

    Args:
        population: candidate items (not mutated).
        k: maximum number of items to draw.

    Returns:
        A new list of at most *k* items from *population*, without duplicates.
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # random.sample performs uniform sampling without replacement — the same
    # distribution as the previous manual pop-loop, but in C-backed stdlib code
    # and without copying/mutating a scratch list.
    return random.sample(population, k)
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
    def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
        """Predict the best-fitting expressions for *target_message* via style_learner.

        Queries every chat stream related to *chat_id* (per expression_groups),
        collects the single best predicted style of each, resolves it to its
        ``Expression`` database row, and returns up to *total_num* rows sorted
        by prediction score (highest first). Falls back to random selection
        when the whole prediction step fails.
        """
        try:
            # Merge predictions across all related chat streams.
            related_chat_ids = self.get_related_chat_ids(chat_id)
            predicted_expressions = []
            for related_chat_id in related_chat_ids:
                try:
                    # Best style for this chat according to its trained model.
                    best_style, scores = style_learner_manager.predict_style(
                        related_chat_id, target_message, top_k=total_num
                    )
                    if best_style and scores:
                        # Resolve the predicted style's metadata.
                        learner = style_learner_manager.get_learner(related_chat_id)
                        style_id, situation = learner.get_style_info(best_style)
                        if style_id and situation:
                            # Look up the matching Expression row for this style.
                            expr_query = Expression.select().where(
                                (Expression.chat_id == related_chat_id) &
                                (Expression.situation == situation) &
                                (Expression.style == best_style)
                            )
                            if expr_query.exists():
                                expr = expr_query.get()
                                predicted_expressions.append({
                                    "id": expr.id,
                                    "situation": expr.situation,
                                    "style": expr.style,
                                    "last_active_time": expr.last_active_time,
                                    "source_id": expr.chat_id,
                                    # Older rows may lack create_date; fall back to last_active_time.
                                    "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                                    "prediction_score": scores.get(best_style, 0.0),
                                    "prediction_input": target_message
                                })
                            else:
                                logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
                except Exception as e:
                    # One chat failing must not abort the others.
                    logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
                    continue
            # Highest score first; keep only the top total_num entries.
            predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
            selected_expressions = predicted_expressions[:total_num]
            logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式")
            return selected_expressions
        except Exception as e:
            logger.error(f"模型预测表达方式失败: {e}")
            # Fallback: uniform random selection from stored expressions.
            return self._random_expressions(chat_id, total_num)
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
Args:
chat_id: 聊天室ID
total_num: 需要选择的数量
Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表
"""
try:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
]
# 随机抽样
if style_exprs:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")
return []
async def select_suitable_expressions(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
根据配置模式选择适合的表达方式
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 获取配置模式
expression_mode = global_config.expression.mode
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_model_only(chat_id, target_message, max_num)
elif expression_mode == "classic":
# classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
else:
logger.warning(f"未知的表达模式: {expression_mode}回退到classic模式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
async def _select_expressions_model_only(
self,
chat_id: str,
target_message: str,
max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
exp_model模式直接使用模型预测不经过LLM
Args:
chat_id: 聊天流ID
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 使用模型预测最合适的表达方式
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids
except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], []
    async def _select_expressions_classic(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """classic mode: random pre-sampling followed by LLM-based selection.

        Randomly samples 20 candidate expressions, shows the LLM only their
        situations (indexed from 1), and keeps the entries whose indices the
        LLM returns.

        Args:
            chat_id: chat stream ID
            chat_info: chat context text shown to the LLM
            max_num: maximum number of expressions to select
            target_message: message being replied to, if any

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: selected expression records
            and their database IDs; both empty on failure or while the learned
            pool is still too small (< 10 candidates).
        """
        try:
            # 1. Randomly sample a candidate pool of expressions
            style_exprs = self._random_expressions(chat_id, 20)

            if len(style_exprs) < 10:
                # Not enough learned expressions yet to be worth an LLM call
                logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
                return [], []

            # 2. Build an indexed situation list (1-based, matching the prompt)
            all_expressions: List[Dict[str, Any]] = []
            all_situations: List[str] = []

            for expr in style_exprs:
                expr = expr.copy()
                all_expressions.append(expr)
                all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")

            if not all_expressions:
                logger.warning("没有找到可用的表达方式")
                return [], []

            all_situations_str = "\n".join(all_situations)

            if target_message:
                target_message_str = f",现在你想要回复消息:{target_message}"
                target_message_extra_block = "4.考虑你要回复的目标消息"
            else:
                target_message_str = ""
                target_message_extra_block = ""

            # 3. Build the prompt: only situations are shown, not full expressions
            prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
                bot_name=global_config.bot.nickname,
                chat_observe_info=chat_info,
                all_situations=all_situations_str,
                max_num=max_num,
                target_message=target_message_str,
                target_message_extra_block=target_message_extra_block,
            )

            # 4. Query the LLM
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

            if not content:
                logger.warning("LLM返回空结果")
                return [], []

            # 5. Parse the (possibly malformed) JSON answer
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)

            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return [], []

            selected_indices = result["selected_situations"]

            # Map the LLM's 1-based indices back to expression records
            valid_expressions: List[Dict[str, Any]] = []
            selected_ids = []

            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]  # indices are 1-based
                    selected_ids.append(expression["id"])
                    valid_expressions.append(expression)

            # Refresh last_active_time on every selected expression
            if valid_expressions:
                self.update_expressions_last_active_time(valid_expressions)

            logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}")

            return valid_expressions, selected_ids

        except Exception as e:
            logger.error(f"classic模式处理表达方式选择时出错: {e}")
            return [], []
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
"""对一批表达方式更新last_active_time"""
if not expressions_to_update:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
expr_obj.last_active_time = time.time()
expr_obj.save()
logger.debug(
"表达方式激活: 更新last_active_time in db"
)
# Register this module's prompts with the global prompt manager at import time
init_prompt()

# Module-level singleton used by the rest of the codebase.
# NOTE(review): if construction raises, the error is logged but the name
# `expression_selector` is never bound, so importers would hit a NameError —
# confirm this failure mode is intended.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@ -0,0 +1,141 @@
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
    """
    Naive-Bayes re-ranker over candidate styles, trainable online.

    Also stores an optional ``situation`` string per candidate; situations do
    not participate in scoring, they are merely carried alongside the style.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        # Tokenizer with no stopwords; jieba segmentation is optional
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> text (style)
        self._situations: Dict[str, str] = {}  # cid -> situation (not used for scoring)

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
        """Register a candidate style text and, optionally, its situation."""
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation

        # Make sure the NB model has counters for this candidate
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: Optional[List[str]] = None):
        """Register many (cid, text) pairs; situations[i] pairs with items[i] when given."""
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
        """Score every candidate against *text* with naive Bayes.

        Args:
            text: query text to tokenize and score.
            k: when a positive int, only the top-k scores are returned.

        Returns:
            (best_cid, scores); (None, {}) when the text tokenizes to nothing,
            there are no candidates, or scoring produced nothing.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}
        if not self._candidates:
            return None, {}

        # Score all candidates in one batch
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)

        if not scores:
            return None, {}

        if k is not None and k > 0:
            # Keep only the k highest-scoring candidates
            sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
            limited_scores = dict(sorted_scores[:k])
            best = sorted_scores[0][0] if sorted_scores else None
            return best, limited_scores
        else:
            # No limit requested: return every score
            best = max(scores.items(), key=lambda x: x[1])[0]
            return best, scores

    def update_positive(self, text: str, cid: str):
        """Online positive-feedback update: count *text*'s tokens toward *cid*."""
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float):
        """Decay all learned counts by *factor* (soft forgetting)."""
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Return the situation stored for *cid*, if any."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Return the style text stored for *cid*, if any."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """Return (style, situation) for *cid*; either may be None."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return {cid: (style, situation)} for every registered candidate."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str):
        """Pickle candidates, situations and NB counters to *path*."""
        # Fix: os.makedirs("") raises FileNotFoundError when *path* has no
        # directory component, so only create the parent when there is one.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    "cls_counts": dict(self.nb.cls_counts),
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str):
        """Restore a model previously written by :meth:`save`."""
        with open(path, "rb") as f:
            obj = pickle.load(f)

        # Candidate texts
        self._candidates = obj["candidates"]

        # Situations are absent in old snapshots; default to empty (back-compat)
        self._situations = obj.get("situations", {})

        # Naive-Bayes counters.  Fix: cls_counts was pickled as a plain dict;
        # restore it as defaultdict(float) so missing cids read as 0.0 again.
        self.nb.cls_counts = defaultdict(float, obj["nb"]["cls_counts"])
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        # Invalidate the cached normalizers; they depend on the restored counts
        self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """Rebuild a nested defaultdict(float) structure from plain dicts.

    Used when unpickling: restores the defaultdict behaviour the online
    naive-Bayes counters rely on (missing cids/terms read as 0.0).

    Args:
        d: plain {cid: {term: count}} mapping.

    Returns:
        A defaultdict whose missing outer keys yield a fresh defaultdict(float).
    """
    # `defaultdict` is already imported at module level; the previous
    # function-local re-import was redundant.
    outer = defaultdict(lambda: defaultdict(float))
    for cid, inner in d.items():
        outer[cid].update(inner)
    return outer

View File

@ -0,0 +1,60 @@
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
class OnlineNaiveBayes:
    """Multinomial naive Bayes with online positive updates and decay.

    Counts are stored as floats so they can be decayed (soft forgetting).
    ``alpha`` smooths token likelihoods, ``beta`` smooths the class prior,
    ``vocab_size`` is the assumed vocabulary size V used in the normalizer.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        self.alpha = alpha  # token-likelihood smoothing
        self.beta = beta    # class-prior smoothing
        self.gamma = gamma  # default decay factor
        self.V = vocab_size  # assumed vocabulary size
        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))  # cid -> term -> count
        self._logZ: Dict[str, float] = {}  # cache log(∑counts + Vα)

    def _invalidate(self, cid: str):
        """Drop the cached normalizer for *cid* after its counts change."""
        if cid in self._logZ:
            del self._logZ[cid]

    def _logZ_c(self, cid: str) -> float:
        """Return (and cache) log of the smoothed token-count normalizer for *cid*."""
        if cid not in self._logZ:
            Z = self.cls_counts[cid] + self.V * self.alpha
            # max() guards log(0) when counts are zero and alpha is 0
            self._logZ[cid] = math.log(max(Z, 1e-12))
        return self._logZ[cid]

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """Return log-posterior scores of the query term-frequencies *tf* for each cid."""
        total_cls = sum(self.cls_counts.values())
        n_cls = max(1, len(self.cls_counts))
        denom_prior = math.log(total_cls + self.beta * n_cls)
        out: Dict[str, float] = {}
        for cid in cids:
            # Smoothed log prior
            prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
            s = prior
            logZ = self._logZ_c(cid)
            tc = self.token_counts[cid]
            for term, qtf in tf.items():
                # Smoothed log likelihood, weighted by the query term frequency
                num = tc.get(term, 0.0) + self.alpha
                s += qtf * (math.log(num) - logZ)
            out[cid] = s
        return out

    def update_positive(self, tf: Counter, cid: str):
        """Add the query term counts *tf* to class *cid* (positive feedback)."""
        inc = 0.0
        tc = self.token_counts[cid]
        for term, c in tf.items():
            tc[term] += float(c)
            inc += float(c)
        self.cls_counts[cid] += inc
        self._invalidate(cid)

    def decay(self, factor: Optional[float] = None):
        """Multiply all counts by *factor* (or self.gamma); factor >= 1 is a no-op.

        Fix: the parameter was annotated ``float = None`` (invalid implicit
        Optional); it is now ``Optional[float]``.
        """
        g = self.gamma if factor is None else factor
        if g >= 1.0:
            return
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= g
            for term in list(self.token_counts[cid].keys()):
                self.token_counts[cid][term] *= g
            self._invalidate(cid)

View File

@ -0,0 +1,28 @@
import re
from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
# Runs of ASCII letters, digits and underscores
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")


def simple_en_tokenize(text: str) -> List[str]:
    """Lowercase *text* and return every ASCII word/number token, in order."""
    lowered = text.lower()
    return [match.group(0) for match in _WORD_RE.finditer(lowered)]
class Tokenizer:
    """Thin tokenizer wrapper.

    Uses jieba word segmentation when both requested and importable,
    otherwise falls back to a simple ASCII word regex; filters stopwords
    in either mode.
    """

    def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
        self.stopwords = stopwords or set()
        # jieba is only used when the caller wants it AND the import succeeded
        self.use_jieba = use_jieba and _HAS_JIEBA

    def tokenize(self, text: str) -> List[str]:
        """Split *text* into lowercase tokens, dropping stopwords."""
        stripped = (text or "").strip()
        if not stripped:
            return []
        if self.use_jieba:
            raw = [piece.strip().lower() for piece in jieba.cut(stripped) if piece.strip()]
        else:
            raw = simple_en_tokenize(stripped)
        return [tok for tok in raw if tok not in self.stopwords]

View File

@ -0,0 +1,628 @@
"""
多聊天室表达风格学习系统
支持为每个chat_id维护独立的表达模型学习从up_content到style的映射
"""
import os
import pickle
import traceback
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from .expressor_model.model import ExpressorModel
logger = get_logger("style_learner")
class StyleLearner:
    """
    Per-chat expression style learner.

    Learns a mapping from up_content to style via an online naive-Bayes
    expressor model, and manages a dynamic style set (at most 2000 styles
    per chat).
    """

    def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
        self.chat_id = chat_id
        self.model_config = model_config or {
            "alpha": 0.5,
            "beta": 0.5,
            "gamma": 0.99,  # decay factor: enables gradual forgetting
            "vocab_size": 200000,
            "use_jieba": True
        }

        # Underlying expressor model (online naive Bayes)
        self.expressor = ExpressorModel(**self.model_config)

        # Dynamic style bookkeeping
        self.max_styles = 2000  # hard cap per chat_id
        self.style_to_id: Dict[str, str] = {}  # style text -> style_id
        self.id_to_style: Dict[str, str] = {}  # style_id -> style text
        self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
        self.next_style_id = 0  # next free numeric suffix for style ids

        # Learning statistics
        self.learning_stats = {
            "total_samples": 0,
            "style_counts": defaultdict(int),
            "last_update": None,
            "style_usage_frequency": defaultdict(int)  # style text -> usage count
        }

    def add_style(self, style: str, situation: Optional[str] = None) -> bool:
        """Register a new style (idempotent; fails once the cap is reached).

        Args:
            style: style text.
            situation: optional situation text paired with the style.
                (Fix: parameter was annotated ``str = None``, now Optional.)

        Returns:
            bool: True when the style exists after the call.
        """
        try:
            # Already registered — treat as success
            if style in self.style_to_id:
                logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
                return True

            # Respect the per-chat cap
            if len(self.style_to_id) >= self.max_styles:
                logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
                return False

            # Allocate a fresh style_id
            style_id = f"style_{self.next_style_id}"
            self.next_style_id += 1

            # Register the mappings
            self.style_to_id[style] = style_id
            self.id_to_style[style_id] = style
            if situation:
                self.id_to_situation[style_id] = situation

            # Register the candidate with the expressor model
            self.expressor.add_candidate(style_id, style, situation)

            logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
                       (f", situation: '{situation}'" if situation else ""))
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
            return False

    def remove_style(self, style: str) -> bool:
        """Delete a style and rebuild the expressor without it.

        Args:
            style: style text to delete.

        Returns:
            bool: True when the style was present and removed.
        """
        try:
            if style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
                return False

            style_id = self.style_to_id[style]

            # Drop all mappings for this style
            del self.style_to_id[style]
            del self.id_to_style[style_id]
            if style_id in self.id_to_situation:
                del self.id_to_situation[style_id]

            # The expressor has no per-candidate delete; rebuild it instead
            self._rebuild_expressor()

            logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
            return False

    def update_style(self, old_style: str, new_style: str) -> bool:
        """Rename a style while keeping its learned statistics and situation.

        Args:
            old_style: existing style text.
            new_style: replacement style text.

        Returns:
            bool: True on success.
        """
        try:
            if old_style not in self.style_to_id:
                logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
                return False

            if new_style in self.style_to_id and new_style != old_style:
                logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
                return False

            style_id = self.style_to_id[old_style]

            # Re-point the mappings at the new text, keeping the same id
            del self.style_to_id[old_style]
            self.style_to_id[new_style] = style_id
            self.id_to_style[style_id] = new_style

            # Update the expressor candidate, preserving the original situation
            situation = self.id_to_situation.get(style_id)
            self.expressor.add_candidate(style_id, new_style, situation)

            logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
            return False

    def add_styles_batch(self, styles: List[str], situations: Optional[List[str]] = None) -> int:
        """Register many styles at once.

        Args:
            styles: style texts.
            situations: optional situation texts, paired by index.
                (Fix: parameter was annotated ``List[str] = None``, now Optional.)

        Returns:
            int: number of styles successfully added.
        """
        success_count = 0
        for i, style in enumerate(styles):
            situation = situations[i] if situations and i < len(situations) else None
            if self.add_style(style, situation):
                success_count += 1

        logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
        return success_count

    def get_all_styles(self) -> List[str]:
        """Return every registered style text."""
        return list(self.style_to_id.keys())

    def get_style_count(self) -> int:
        """Return the current number of registered styles."""
        return len(self.style_to_id)

    def get_situation(self, style: str) -> Optional[str]:
        """Return the situation paired with *style*, or None when unknown."""
        if style not in self.style_to_id:
            return None
        style_id = self.style_to_id[style]
        return self.id_to_situation.get(style_id)

    def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
        """Return (style_id, situation) for *style*; (None, None) when unknown."""
        if style not in self.style_to_id:
            return None, None
        style_id = self.style_to_id[style]
        situation = self.id_to_situation.get(style_id)
        return style_id, situation

    def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return {style: (style_id, situation)} for every registered style."""
        result = {}
        for style, style_id in self.style_to_id.items():
            situation = self.id_to_situation.get(style_id)
            result[style] = (style_id, situation)
        return result

    def _rebuild_expressor(self):
        """Recreate the expressor from the surviving styles (used after deletion).

        NOTE(review): rebuilding discards the naive-Bayes counts learned for
        the remaining styles — confirm this reset is intended.
        """
        try:
            self.expressor = ExpressorModel(**self.model_config)
            # Re-register every remaining style and its situation
            for style_id, style_text in self.id_to_style.items():
                situation = self.id_to_situation.get(style_id)
                self.expressor.add_candidate(style_id, style_text, situation)
            logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
        except Exception as e:
            logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")

    def learn_mapping(self, up_content: str, style: str) -> bool:
        """Learn one up_content -> style sample (auto-registers unknown styles).

        Args:
            up_content: input content.
            style: target style text.

        Returns:
            bool: True when the sample was learned.
        """
        try:
            # Register the style first if it is new
            if style not in self.style_to_id:
                if not self.add_style(style):
                    logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
                    return False

            style_id = self.style_to_id[style]

            # Positive-feedback online update
            self.expressor.update_positive(up_content, style_id)

            # Update statistics
            self.learning_stats["total_samples"] += 1
            self.learning_stats["style_counts"][style_id] += 1
            self.learning_stats["style_usage_frequency"][style] += 1
            # Fix: asyncio.get_event_loop() is deprecated outside a running
            # loop and can raise in this synchronous context; use wall-clock
            # time for the last-update marker instead.
            import time
            self.learning_stats["last_update"] = time.time()

            logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
            traceback.print_exc()
            return False

    def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """Predict the best style for *up_content*.

        Args:
            up_content: input content.
            top_k: number of candidates to score.

        Returns:
            Tuple of (best style text or None, {style text: score}).
        """
        try:
            best_style_id, scores = self.expressor.predict(up_content, k=top_k)

            if best_style_id is None:
                return None, {}

            # Map the winning id back to its style text
            best_style = self.id_to_style.get(best_style_id)

            # Translate every scored id to its style text (skip stale ids)
            style_scores = {}
            for sid, score in scores.items():
                style_text = self.id_to_style.get(sid)
                if style_text:
                    style_scores[style_text] = score

            return best_style, style_scores

        except Exception as e:
            logger.error(f"[{self.chat_id}] 预测style失败: {e}")
            traceback.print_exc()
            return None, {}

    def decay_learning(self, factor: Optional[float] = None) -> None:
        """Decay the learned knowledge (forgetting).

        Args:
            factor: decay factor; None uses the configured gamma.
        """
        self.expressor.decay(factor)
        logger.debug(f"[{self.chat_id}] 执行知识衰减")

    def get_stats(self) -> Dict:
        """Return a snapshot of the learning statistics for this chat."""
        return {
            "chat_id": self.chat_id,
            "total_samples": self.learning_stats["total_samples"],
            "style_count": len(self.style_to_id),
            "max_styles": self.max_styles,
            "style_counts": dict(self.learning_stats["style_counts"]),
            "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
            "last_update": self.learning_stats["last_update"],
            "all_styles": list(self.style_to_id.keys())
        }

    def save(self, base_path: str) -> bool:
        """Persist the learner under *base_path*.

        Writes {base_path}/{chat_id}_style_model.pkl (bookkeeping + stats)
        and {base_path}/{chat_id}_expressor.pkl (expressor model).

        Returns:
            bool: True on success.
        """
        try:
            os.makedirs(base_path, exist_ok=True)
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")

            # Bookkeeping and statistics
            save_data = {
                "model_config": self.model_config,
                "style_to_id": self.style_to_id,
                "id_to_style": self.id_to_style,
                "id_to_situation": self.id_to_situation,
                "next_style_id": self.next_style_id,
                "max_styles": self.max_styles,
                "learning_stats": self.learning_stats
            }

            # The expressor model is pickled separately
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
            self.expressor.save(expressor_path)

            with open(file_path, "wb") as f:
                pickle.dump(save_data, f)

            logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
            return False

    def load(self, base_path: str) -> bool:
        """Load a learner previously saved under *base_path*.

        Returns:
            bool: True when both pickle files existed and loaded cleanly.
        """
        try:
            file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
            expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")

            if not os.path.exists(file_path) or not os.path.exists(expressor_path):
                logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
                return False

            # Bookkeeping and statistics
            with open(file_path, "rb") as f:
                save_data = pickle.load(f)

            # Restore configuration and state
            self.model_config = save_data["model_config"]
            self.style_to_id = save_data["style_to_id"]
            self.id_to_style = save_data["id_to_style"]
            self.id_to_situation = save_data.get("id_to_situation", {})  # back-compat with old snapshots
            self.next_style_id = save_data["next_style_id"]
            self.max_styles = save_data.get("max_styles", 2000)
            self.learning_stats = save_data["learning_stats"]

            # Recreate the expressor and load its pickle
            self.expressor = ExpressorModel(**self.model_config)
            self.expressor.load(expressor_path)

            logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
            return True

        except Exception as e:
            logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
            return False
class StyleLearnerManager:
    """
    Registry of per-chat StyleLearner instances.

    Lazily creates one learner per chat_id (each with its own dynamic style
    set, capped at 2000 styles) and offers batch persistence plus an optional
    periodic auto-save task.
    """

    def __init__(self, model_save_path: str = "data/style_models"):
        self.model_save_path = model_save_path
        self.learners: Dict[str, StyleLearner] = {}

        # Auto-save settings
        self.auto_save_interval = 300  # seconds (5 minutes)
        self._auto_save_task: Optional[asyncio.Task] = None

        logger.info("StyleLearnerManager 已初始化")

    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
        """Return the learner for *chat_id*, creating (and loading) it on first use.

        Args:
            chat_id: chat stream ID.
            model_config: optional model configuration for a newly created learner.
        """
        existing = self.learners.get(chat_id)
        if existing is not None:
            return existing

        fresh = StyleLearner(chat_id, model_config)
        # Restore any previously persisted state for this chat
        fresh.load(self.model_save_path)
        self.learners[chat_id] = fresh
        logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
        return fresh

    def add_style(self, chat_id: str, style: str) -> bool:
        """Register *style* for *chat_id*; returns True on success."""
        return self.get_learner(chat_id).add_style(style)

    def remove_style(self, chat_id: str, style: str) -> bool:
        """Delete *style* from *chat_id*'s learner; returns True on success."""
        return self.get_learner(chat_id).remove_style(style)

    def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
        """Rename *old_style* to *new_style* for *chat_id*; returns True on success."""
        return self.get_learner(chat_id).update_style(old_style, new_style)

    def get_chat_styles(self, chat_id: str) -> List[str]:
        """Return every style registered for *chat_id*."""
        return self.get_learner(chat_id).get_all_styles()

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """Learn one up_content -> style sample for *chat_id*; returns True on success."""
        return self.get_learner(chat_id).learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """Predict the best style for *up_content* in *chat_id*.

        Returns:
            Tuple of (best style text or None, {style text: score}).
        """
        return self.get_learner(chat_id).predict_style(up_content, top_k)

    def decay_all_learners(self, factor: Optional[float] = None) -> None:
        """Apply decay (forgetting) to every active learner."""
        for instance in self.learners.values():
            instance.decay_learning(factor)
        logger.info("已对所有学习器执行衰减")

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return {chat_id: stats snapshot} for every active learner."""
        return {cid: instance.get_stats() for cid, instance in self.learners.items()}

    def save_all_models(self) -> bool:
        """Persist every learner; True only when all of them saved cleanly."""
        success_count = 0
        for instance in self.learners.values():
            if instance.save(self.model_save_path):
                success_count += 1
        logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
        return success_count == len(self.learners)

    def load_all_models(self) -> int:
        """Load every persisted model found under the save path.

        Returns:
            int: number of models successfully loaded.
        """
        if not os.path.exists(self.model_save_path):
            return 0

        loaded_count = 0
        suffix = "_style_model.pkl"
        for filename in os.listdir(self.model_save_path):
            if not filename.endswith(suffix):
                continue
            cid = filename.replace(suffix, "")
            instance = StyleLearner(cid)
            if instance.load(self.model_save_path):
                self.learners[cid] = instance
                loaded_count += 1

        logger.info(f"已加载 {loaded_count} 个模型")
        return loaded_count

    async def start_auto_save(self) -> None:
        """Spawn the periodic auto-save task if it is not already running."""
        task = self._auto_save_task
        if task is None or task.done():
            self._auto_save_task = asyncio.create_task(self._auto_save_loop())
            logger.info("已启动自动保存任务")

    async def stop_auto_save(self) -> None:
        """Cancel the auto-save task and wait for it to wind down."""
        task = self._auto_save_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            logger.info("已停止自动保存任务")

    async def _auto_save_loop(self) -> None:
        """Sleep/save loop driving periodic persistence; exits on cancellation."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                self.save_all_models()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"自动保存失败: {e}")
# Global manager instance shared across the codebase (created at import time)
style_learner_manager = StyleLearnerManager()

View File

@ -85,4 +85,4 @@ class ModelAttemptFailed(Exception):
self.original_exception = original_exception
def __str__(self):
return self.message
return self.message

View File

@ -72,8 +72,8 @@ class BaseClient(ABC):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
@ -117,6 +117,7 @@ class BaseClient(ABC):
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""

View File

@ -1,7 +1,7 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
from google import genai
from google.genai.types import (
@ -17,6 +17,7 @@ from google.genai.types import (
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HttpOptions,
HarmCategory,
HarmBlockThreshold,
)
@ -182,6 +183,14 @@ def _process_delta(
if delta.text:
fc_delta_buffer.write(delta.text)
# 处理 thoughtGemini 的特殊字段)
for c in getattr(delta, "candidates", []):
if c.content and getattr(c.content, "parts", None):
for p in c.content.parts:
if getattr(p, "thought", False) and getattr(p, "text", None):
# 把 thought 写入 buffer避免 resp.content 永远为空
fc_delta_buffer.write(p.text)
if delta.function_calls: # 为什么不用hasattr呢是因为这个属性一定有即使是个空的
for call in delta.function_calls:
try:
@ -203,6 +212,7 @@ def _process_delta(
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
@ -227,6 +237,21 @@ def _build_stream_api_resp(
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
# 检查是否因为 max_tokens 截断
reason = None
if last_resp and getattr(last_resp, "candidates", None):
c0 = last_resp.candidates[0]
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
if str(reason).endswith("MAX_TOKENS"):
if resp.content and resp.content.strip():
logger.warning(
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
)
else:
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
if not resp.content and not resp.tool_calls:
raise EmptyResponseException()
@ -245,12 +270,14 @@ async def _default_stream_response_handler(
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
async for chunk in resp_stream:
last_resp = chunk # 保存最后一个响应
# 检查是否有中断量
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
@ -269,10 +296,12 @@ async def _default_stream_response_handler(
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
chunk.usage_metadata.total_token_count or 0,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_tool_calls_buffer,
last_resp=last_resp,
), _usage_record
except Exception:
# 确保缓冲区被关闭
@ -332,6 +361,35 @@ def _default_normal_response_parser(
api_response.raw_data = resp
# 检查是否因为 max_tokens 截断
try:
if resp.candidates:
c0 = resp.candidates[0]
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
if reason and "MAX_TOKENS" in str(reason):
# 检查第二个及之后的 parts 是否有内容
has_real_output = False
if getattr(c0, "content", None) and getattr(c0.content, "parts", None):
for p in c0.content.parts[1:]: # 跳过第一个 thought
if getattr(p, "text", None) and p.text.strip():
has_real_output = True
break
if not has_real_output and getattr(resp, "text", None):
has_real_output = True
if has_real_output:
logger.warning(
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
)
else:
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
return api_response, _usage_record
except Exception as e:
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
# 最终的、唯一的空响应检查
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException("响应中既无文本内容也无工具调用")
@ -345,17 +403,45 @@ class GeminiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# 增加传入参数处理
http_options_kwargs: Dict[str, Any] = {}
# 秒转换为毫秒传入
if api_provider.timeout is not None:
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
# 传入并处理地址和版本(必须为Gemini格式)
if api_provider.base_url:
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
if len(parts) == 2 and parts[1].startswith("v"):
http_options_kwargs["base_url"] = f"{parts[0]}/"
http_options_kwargs["api_version"] = parts[1]
else:
http_options_kwargs["base_url"] = api_provider.base_url
http_options_kwargs["api_version"] = None
self.client = genai.Client(
http_options=HttpOptions(**http_options_kwargs),
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
@staticmethod
def clamp_thinking_budget(tb: int, model_id: str) -> int:
def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
"""
按模型限制思考预算范围仅支持指定的模型支持带数字后缀的新版本
"""
limits = None
# 参数传入处理
tb = THINKING_BUDGET_AUTO
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
)
# 优先尝试精确匹配
if model_id in THINKING_BUDGET_LIMITS:
limits = THINKING_BUDGET_LIMITS[model_id]
@ -368,20 +454,29 @@ class GeminiClient(BaseClient):
limits = THINKING_BUDGET_LIMITS[key]
break
# 特殊值处理
# 预算值处理
if tb == THINKING_BUDGET_AUTO:
return THINKING_BUDGET_AUTO
if tb == THINKING_BUDGET_DISABLED:
if limits and limits.get("can_disable", False):
return THINKING_BUDGET_DISABLED
return limits["min"] if limits else THINKING_BUDGET_AUTO
if limits:
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
return limits["min"]
return THINKING_BUDGET_AUTO
# 已知模型裁剪到范围
# 已知模型范围裁剪 + 提示
if limits:
return max(limits["min"], min(tb, limits["max"]))
if tb < limits["min"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
return limits["min"]
if tb > limits["max"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
return limits["max"]
return tb
# 未知模型,返回动态模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
# 未知模型 → 默认自动模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容")
return THINKING_BUDGET_AUTO
async def get_response(
@ -389,8 +484,8 @@ class GeminiClient(BaseClient):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.4,
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 0.4,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
@ -429,16 +524,8 @@ class GeminiClient(BaseClient):
messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
tb = THINKING_BUDGET_AUTO
# 空处理
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
# 裁剪到模型支持的范围
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
# 解析并裁剪 thinking_budget
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
@ -497,15 +584,20 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
# 重封装 ClientError ServerError RespNotOkException
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
# 工具调用相关错误
raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None
except EmptyResponseException as e:
# 保持原始异常,便于区分“空响应”和网络异常
raise e
except Exception as e:
# 其他未预料的错误,才归为网络连接类
raise NetworkConnectionError() from e
if usage_record:
@ -561,41 +653,51 @@ class GeminiClient(BaseClient):
return response
def get_audio_transcriptions(
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = 2048,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: 音频文件的Base64编码字符串
:param max_tokens: 最大输出token数默认2048
:param extra_params: 额外参数可选
:return: 转录响应
"""
# 解析并裁剪 thinking_budget
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
# 构造 prompt + 音频输入
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
contents = [
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
]
generation_config_dict = {
"max_output_tokens": 2048,
"max_output_tokens": max_tokens,
"response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
),
thinking_budget=tb,
),
"safety_settings": gemini_safe_settings,
}
generate_content_config = GenerateContentConfig(**generation_config_dict)
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
try:
raw_response: GenerateContentResponse = self.client.models.generate_content(
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_info.model_identifier,
contents=[
Content(
role="user",
parts=[
Part.from_text(text=prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
],
)
],
contents=contents,
config=generate_content_config,
)
resp, usage_record = _default_normal_response_parser(raw_response)

View File

@ -403,8 +403,8 @@ class OpenaiClient(BaseClient):
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
max_tokens: Optional[int] = 1024,
temperature: Optional[float] = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
@ -488,6 +488,9 @@ class OpenaiClient(BaseClient):
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
# logger.
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")
resp, usage_record = async_response_parser(req_task.result())
@ -507,6 +510,8 @@ class OpenaiClient(BaseClient):
total_tokens=usage_record[2],
)
# logger.debug(f"OpenAI API响应: {resp}")
return resp
async def get_embedding(

View File

@ -26,18 +26,6 @@ install(extra_lines=3)
logger = get_logger("model_utils")
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class RequestType(Enum):
"""请求类型枚举"""
@ -160,6 +148,8 @@ class LLMRequest:
)
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
logger.debug(f"LLM生成内容: {response}")
content = response.content
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
@ -267,14 +257,14 @@ class LLMRequest:
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input is not None
assert embedding_input is not None, "嵌入输入不能为空"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
assert audio_base64 is not None
assert audio_base64 is not None, "音频Base64不能为空"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
@ -365,24 +355,23 @@ class LLMRequest:
embedding_input=embedding_input,
audio_base64=audio_base64,
)
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
if response_usage := response.usage:
total_tokens += response_usage.total_tokens
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
return response, model_info
except ModelAttemptFailed as e:
last_exception = e.original_exception or e
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
failed_models_this_request.add(model_info.name)
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
raise last_exception from e
finally:
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
if usage_penalty > 0:
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
if last_exception:
raise last_exception

View File

@ -13,8 +13,8 @@ from src.common.logger import get_logger
from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from src.chat.knowledge import lpmm_start_up
from src.memory_system.memory_management_task import MemoryManagementTask
from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server
# 导入新的插件管理器
@ -83,22 +83,19 @@ class MainSystem:
logger.info("表情包管理器初始化成功")
# 启动情绪管理器
await mood_manager.start()
logger.info("情绪管理器初始化成功")
if global_config.mood.enable_mood:
await mood_manager.start()
logger.info("情绪管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())
logger.info("聊天管理器初始化成功")
# # 根据配置条件性地初始化记忆系统
# if global_config.memory.enable_memory:
# if self.hippocampus_manager:
# self.hippocampus_manager.initialize()
# logger.info("记忆系统初始化成功")
# else:
# logger.info("记忆系统已禁用,跳过初始化")
# 添加记忆管理任务
await async_task_manager.add_task(MemoryManagementTask())
logger.info("记忆管理任务已启动")
# await asyncio.sleep(0.5) #防止logger输出飞了
@ -106,7 +103,6 @@ class MainSystem:
self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
await check_and_run_migrations()
# 触发 ON_START 事件
from src.plugin_system.core.events_manager import events_manager

View File

@ -1,36 +0,0 @@
[inner]
version = "1.0.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示

View File

@ -1,68 +0,0 @@
[inner]
version = "1.2.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
enable_s4u = false
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_streaming_output = true # 是否启用流式输出false时全部生成后一次性发送
max_context_message_length = 20
max_core_message_length = 30
# 模型配置
[models]
# 主要对话模型配置
[models.chat]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 规划模型配置
[models.motion]
name = "qwen3-32b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7
enable_thinking = false
# 情感分析模型配置
[models.emotion]
name = "qwen3-8b"
provider = "BAILIAN"
pri_in = 0.5
pri_out = 2
temp = 0.7

View File

@ -1,167 +0,0 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger
logger = get_logger(__name__)
def init_prompt():
Prompt(
"""
你之前的内心想法是{mind}
{memory_block}
{relation_info_block}
{chat_target}
{time_block}
{chat_info}
{identity}
你刚刚在{chat_target_2},你你刚刚的心情是{mood_state}
---------------------
在这样的情况下你对上面的内容你对 {sender} 发送的 消息 {target} 进行了回复
你刚刚选择回复的内容是{reponse}
现在根据你之前的想法和回复的内容推测你现在的想法思考你现在的想法是什么为什么做出上面的回复内容
请不要浮夸和夸张修辞不要输出多余内容(包括前后缀冒号和引号括号()表情包at或 @等 )只输出想法""",
"after_response_think_prompt",
)
class MaiThinking:
def __init__(self, chat_id):
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform
if self.chat_stream.group_info:
self.is_group = True
else:
self.is_group = False
self.s4u_message_processor = S4UMessageProcessor()
self.mind = ""
self.memory_block = ""
self.relation_info_block = ""
self.time_block = ""
self.chat_target = ""
self.chat_target_2 = ""
self.chat_info = ""
self.mood_state = ""
self.identity = ""
self.sender = ""
self.target = ""
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
async def do_think_before_response(self):
pass
async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt",
mind=self.mind,
reponse=reponse,
memory_block=self.memory_block,
relation_info_block=self.relation_info_block,
time_block=self.time_block,
chat_target=self.chat_target,
chat_target_2=self.chat_target_2,
chat_info=self.chat_info,
mood_state=self.mood_state,
identity=self.identity,
sender=self.sender,
target=self.target,
)
result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self):
pass
async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}"
message_dict = {
"message_info": {
"message_id": msg_id,
"time": time.time(),
"user_info": {
"user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充
},
"platform": self.platform, # 平台
# 其他 message_info 字段按需补充
},
"message_segment": {
"type": "text", # 消息类型
"data": message_text, # 消息内容
# 其他 segment 字段按需补充
},
"raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_mentioned": False,
"is_command": False,
"is_internal": True,
"priority_mode": "interest",
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0,
}
if self.is_group:
message_dict["message_info"]["group_info"] = {
"platform": self.platform,
"group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name,
}
msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True
return msg_recv
class MaiThinkingManager:
def __init__(self):
self.mai_think_list = []
def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id:
return mai_think
mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think)
return mai_think
mai_thinking_manager = MaiThinkingManager()
init_prompt()

View File

@ -1,342 +0,0 @@
import json
import time
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.s4u_config import s4u_config
logger = get_logger("action")
# 使用字典作为默认值但通过Prompt来注册以便外部重载
DEFAULT_HEAD_CODE = {
"看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)",
"看向右边": "(1,0,0)",
"随意朝向": "random",
"看向摄像机": "camera",
"注视对方": "(0,0,0)",
"看向正前方": "(0,0,0)",
}
DEFAULT_BODY_CODE = {
"双手背后向前弯腰": "010_0070",
"歪头双手合十": "010_0100",
"标准文静站立": "010_0101",
"双手交叠腹部站立": "010_0150",
"帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210",
"平静,双手后放": "平静,双手后放",
"思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般",
"可爱,双手前放": "可爱,双手前放",
}
async def get_head_code() -> dict:
"""获取头部动作代码字典"""
head_code_str = await global_prompt_manager.format_prompt("head_code_prompt")
if not head_code_str:
return DEFAULT_HEAD_CODE
try:
return json.loads(head_code_str)
except Exception as e:
logger.error(f"解析head_code_prompt失败使用默认值: {e}")
return DEFAULT_HEAD_CODE
async def get_body_code() -> dict:
"""获取身体动作代码字典"""
body_code_str = await global_prompt_manager.format_prompt("body_code_prompt")
if not body_code_str:
return DEFAULT_BODY_CODE
try:
return json.loads(body_code_str)
except Exception as e:
logger.error(f"解析body_code_prompt失败使用默认值: {e}")
return DEFAULT_BODY_CODE
def init_prompt():
# 注册头部动作代码
Prompt(
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
"head_code_prompt",
)
# 注册身体动作代码
Prompt(
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
"body_code_prompt",
)
# 注册原有提示模板
Prompt(
"""
{chat_talking_prompt}
以上是群里正在进行的聊天记录
{indentify_block}
你现在的动作状态是
- 身体动作{body_action}
现在因为你发送了消息或者群里其他人发送了消息引起了你的注意你对其进行了阅读和思考请你更新你的动作状态
身体动作可选
{all_actions}
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"change_action_prompt",
)
Prompt(
"""
{chat_talking_prompt}
以上是群里最近的聊天记录
{indentify_block}
你之前的动作状态是
- 身体动作{body_action}
身体动作可选
{all_actions}
距离你上次关注群里消息已经过去了一段时间你冷静了下来你的动作会趋于平缓或静止请你输出你现在新的动作状态用中文
请只按照以下json格式输出描述你新的动作状态确保每个字段都存在
{{
"body_action": "..."
}}
""",
"regress_action_prompt",
)
class ChatAction:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
self.body_action: str = "一般"
self.head_action: str = "注视摄像机"
self.regression_count: int = 0
# 新增body_action冷却池key为动作名value为剩余冷却次数
self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion)
print(model_config.model_task_config.emotion)
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self):
"""发送动作更新到前端"""
body_code = (await get_body_code()).get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
content=body_code,
stream_id=self.chat_id,
storage_message=False,
show_log=True,
)
async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=15,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"change_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := json.loads(repair_json(response)):
# 记录原动作,切换后进入冷却
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新
await self.send_action_update()
self.last_change_time = message_time
except Exception as e:
logger.error(f"update_action_by_message error: {e}")
async def regress_action(self):
message_time = time.time()
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
limit=10,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
prompt_personality = global_config.personality.personality
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"regress_action_prompt",
chat_talking_prompt=chat_talking_prompt,
indentify_block=indentify_block,
body_action=self.body_action,
all_actions=all_actions,
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
if action_data := json.loads(repair_json(response)):
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action
# 发送动作更新
await self.send_action_update()
self.regression_count += 1
self.last_change_time = message_time
except Exception as e:
logger.error(f"regress_action error: {e}")
# 新增:冷却池维护方法
def _update_body_action_cooldown(self):
remove_keys = []
for k in self.body_action_cooldown:
self.body_action_cooldown[k] -= 1
if self.body_action_cooldown[k] <= 0:
remove_keys.append(k)
for k in remove_keys:
del self.body_action_cooldown[k]
class ActionRegressionTask(AsyncTask):
def __init__(self, action_manager: "ActionManager"):
super().__init__(task_name="ActionRegressionTask", run_interval=3)
self.action_manager = action_manager
async def run(self):
logger.debug("Running action regression task...")
now = time.time()
for action_state in self.action_manager.action_state_list:
if action_state.last_change_time == 0:
continue
if now - action_state.last_change_time > 10:
if action_state.regression_count >= 3:
continue
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1}")
await action_state.regress_action()
class ActionManager:
def __init__(self):
self.action_state_list: list[ChatAction] = []
"""当前动作状态"""
self.task_started: bool = False
async def start(self):
"""启动动作回归后台任务"""
if self.task_started:
return
logger.info("启动动作回归任务...")
task = ActionRegressionTask(self)
await async_task_manager.add_task(task)
self.task_started = True
logger.info("动作回归任务已启动")
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
for action_state in self.action_state_list:
if action_state.chat_id == chat_id:
return action_state
new_action_state = ChatAction(chat_id)
self.action_state_list.append(new_action_state)
return new_action_state
init_prompt()
action_manager = ActionManager()
"""全局动作管理器"""

View File

@ -1,692 +0,0 @@
import asyncio
import json
from collections import deque
from datetime import datetime
from typing import Dict, List, Optional
from aiohttp import web, WSMsgType
import aiohttp_cors
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("context_web")
class ContextMessage:
"""上下文消息类"""
def __init__(self, message: MessageRecv):
self.user_name = message.message_info.user_info.user_nickname
self.user_id = message.message_info.user_info.user_id
self.content = message.processed_plain_text
self.timestamp = datetime.now()
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
# 识别消息类型
self.is_gift = getattr(message, "is_gift", False)
self.is_superchat = getattr(message, "is_superchat", False)
# 添加礼物和SC相关信息
if self.is_gift:
self.gift_name = getattr(message, "gift_name", "")
self.gift_count = getattr(message, "gift_count", "1")
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
elif self.is_superchat:
self.superchat_price = getattr(message, "superchat_price", "0")
self.superchat_message = getattr(message, "superchat_message_text", "")
if self.superchat_message:
self.content = f"{self.superchat_price}] {self.superchat_message}"
else:
self.content = f"{self.superchat_price}] {self.content}"
def to_dict(self):
return {
"user_name": self.user_name,
"user_id": self.user_id,
"content": self.content,
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name,
"is_gift": self.is_gift,
"is_superchat": self.is_superchat,
}
class ContextWebManager:
"""上下文网页管理器"""
def __init__(self, max_messages: int = 10, port: int = 8765):
self.max_messages = max_messages
self.port = port
self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage
self.websockets: List[web.WebSocketResponse] = []
self.app = None
self.runner = None
self.site = None
self._server_starting = False # 添加启动标志防止并发
async def start_server(self):
"""启动web服务器"""
if self.site is not None:
logger.debug("Web服务器已经启动跳过重复启动")
return
if self._server_starting:
logger.debug("Web服务器正在启动中等待启动完成...")
# 等待启动完成
while self._server_starting and self.site is None:
await asyncio.sleep(0.1)
return
self._server_starting = True
try:
self.app = web.Application()
# 设置CORS
cors = aiohttp_cors.setup(
self.app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
)
},
)
# 添加路由
self.app.router.add_get("/", self.index_handler)
self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS
for route in list(self.app.router.routes()):
cors.add(route)
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
except Exception as e:
logger.error(f"❌ 启动Web服务器失败: {e}")
# 清理部分启动的资源
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
raise
finally:
self._server_starting = False
async def stop_server(self):
"""停止web服务器"""
if self.site:
await self.site.stop()
if self.runner:
await self.runner.cleanup()
self.app = None
self.runner = None
self.site = None
self._server_starting = False
async def index_handler(self, request):
"""主页处理器"""
html_content = (
"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>聊天上下文</title>
<style>
html, body {
background: transparent !important;
background-color: transparent !important;
margin: 0;
padding: 20px;
font-family: 'Microsoft YaHei', Arial, sans-serif;
color: #ffffff;
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
}
.container {
max-width: 800px;
margin: 0 auto;
background: transparent !important;
}
.message {
background: rgba(0, 0, 0, 0.3);
margin: 10px 0;
padding: 15px;
border-radius: 10px;
border-left: 4px solid #00ff88;
backdrop-filter: blur(5px);
animation: slideIn 0.3s ease-out;
transform: translateY(0);
transition: transform 0.5s ease, opacity 0.5s ease;
}
.message:hover {
background: rgba(0, 0, 0, 0.5);
transform: translateX(5px);
transition: all 0.3s ease;
}
.message.gift {
border-left: 4px solid #ff8800;
background: rgba(255, 136, 0, 0.2);
}
.message.gift:hover {
background: rgba(255, 136, 0, 0.3);
}
.message.gift .username {
color: #ff8800;
}
.message.superchat {
border-left: 4px solid #ff6b6b;
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
background-size: 200% 200%;
animation: rainbow 3s ease infinite;
}
.message.superchat:hover {
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
background-size: 200% 200%;
}
.message.superchat .username {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
background-size: 300% 300%;
animation: rainbow-text 2s ease infinite;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
@keyframes rainbow {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
@keyframes rainbow-text {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.message-line {
line-height: 1.4;
word-wrap: break-word;
font-size: 24px;
}
.username {
color: #00ff88;
}
.content {
color: #ffffff;
}
.new-message {
animation: slideInNew 0.6s ease-out;
}
.debug-btn {
position: fixed;
bottom: 20px;
right: 20px;
background: rgba(0, 0, 0, 0.7);
color: #00ff88;
font-size: 12px;
padding: 8px 12px;
border-radius: 20px;
backdrop-filter: blur(10px);
z-index: 1000;
text-decoration: none;
border: 1px solid #00ff88;
}
.debug-btn:hover {
background: rgba(0, 255, 136, 0.2);
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInNew {
from {
opacity: 0;
transform: translateY(50px) scale(0.95);
}
to {
opacity: 1;
transform: translateY(0) scale(1);
}
}
.no-messages {
text-align: center;
color: #666;
font-style: italic;
margin-top: 50px;
}
</style>
</head>
<body>
<div class="container">
<a href="/debug" class="debug-btn">🔧 调试</a>
<div id="messages">
<div class="no-messages">暂无消息</div>
</div>
</div>
<script>
let ws;
let reconnectInterval;
let currentMessages = []; // 存储当前显示的消息
function connectWebSocket() {
console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() {
console.log('WebSocket连接已建立');
if (reconnectInterval) {
clearInterval(reconnectInterval);
reconnectInterval = null;
}
};
ws.onmessage = function(event) {
console.log('收到WebSocket消息:', event.data);
try {
const data = JSON.parse(event.data);
updateMessages(data.contexts);
} catch (e) {
console.error('解析消息失败:', e, event.data);
}
};
ws.onclose = function(event) {
console.log('WebSocket连接关闭:', event.code, event.reason);
if (!reconnectInterval) {
reconnectInterval = setInterval(connectWebSocket, 3000);
}
};
ws.onerror = function(error) {
console.error('WebSocket错误:', error);
};
}
function updateMessages(contexts) {
const messagesDiv = document.getElementById('messages');
if (!contexts || contexts.length === 0) {
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
currentMessages = [];
return;
}
// 如果是第一次加载或者消息完全不同进行完全重新渲染
if (currentMessages.length === 0) {
console.log('首次加载消息,数量:', contexts.length);
messagesDiv.innerHTML = '';
contexts.forEach(function(msg) {
const messageDiv = createMessageElement(msg);
messagesDiv.appendChild(messageDiv);
});
currentMessages = [...contexts];
window.scrollTo(0, document.body.scrollHeight);
return;
}
// 检测新消息 - 使用更可靠的方法
const newMessages = findNewMessages(contexts, currentMessages);
if (newMessages.length > 0) {
console.log('添加新消息,数量:', newMessages.length);
// 先检查是否需要移除老消息保持DOM清洁
const maxDisplayMessages = 15; // 比服务器端稍多一些确保流畅性
const currentMessageElements = messagesDiv.querySelectorAll('.message');
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
if (willExceedLimit) {
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
console.log('需要移除老消息数量:', removeCount);
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
const oldMessage = currentMessageElements[i];
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
oldMessage.style.opacity = '0';
oldMessage.style.transform = 'translateY(-20px)';
setTimeout(() => {
if (oldMessage.parentNode) {
oldMessage.parentNode.removeChild(oldMessage);
}
}, 300);
}
}
// 添加新消息
newMessages.forEach(function(msg) {
const messageDiv = createMessageElement(msg, true); // true表示是新消息
messagesDiv.appendChild(messageDiv);
// 移除动画类避免重复动画
setTimeout(() => {
messageDiv.classList.remove('new-message');
}, 600);
});
// 更新当前消息列表
currentMessages = [...contexts];
// 平滑滚动到底部
setTimeout(() => {
window.scrollTo({
top: document.body.scrollHeight,
behavior: 'smooth'
});
}, 100);
}
}
function findNewMessages(contexts, currentMessages) {
// 如果当前消息为空所有消息都是新的
if (currentMessages.length === 0) {
return contexts;
}
// 找到最后一条当前消息在新消息列表中的位置
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
let lastIndex = -1;
// 从后往前找因为新消息通常在末尾
for (let i = contexts.length - 1; i >= 0; i--) {
const msg = contexts[i];
if (msg.user_id === lastCurrentMsg.user_id &&
msg.content === lastCurrentMsg.content &&
msg.timestamp === lastCurrentMsg.timestamp) {
lastIndex = i;
break;
}
}
// 如果找到了返回之后的消息否则返回所有消息可能是完全刷新
if (lastIndex >= 0) {
return contexts.slice(lastIndex + 1);
} else {
console.log('未找到匹配的最后消息,可能需要完全刷新');
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
}
}
function createMessageElement(msg, isNew = false) {
const messageDiv = document.createElement('div');
let className = 'message';
// 根据消息类型添加对应的CSS类
if (msg.is_gift) {
className += ' gift';
} else if (msg.is_superchat) {
className += ' superchat';
}
if (isNew) {
className += ' new-message';
}
messageDiv.className = className;
messageDiv.innerHTML = `
<div class="message-line">
<span class="username">${escapeHtml(msg.user_name)}</span><span class="content">${escapeHtml(msg.content)}</span>
</div>
`;
return messageDiv;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// 初始加载数据
fetch('/api/contexts')
.then(response => response.json())
.then(data => {
console.log('初始数据加载成功:', data);
updateMessages(data.contexts);
})
.catch(err => console.error('加载初始数据失败:', err));
// 连接WebSocket
connectWebSocket();
</script>
</body>
</html>
"""
)
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request):
"""WebSocket处理器"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self.websockets.append(ws)
logger.debug(f"WebSocket连接建立当前连接数: {len(self.websockets)}")
# 发送初始数据
await self.send_contexts_to_websocket(ws)
async for msg in ws:
if msg.type == WSMsgType.ERROR:
logger.error(f"WebSocket错误: {ws.exception()}")
break
# 清理断开的连接
if ws in self.websockets:
self.websockets.remove(ws)
logger.debug(f"WebSocket连接断开当前连接数: {len(self.websockets)}")
return ws
async def get_contexts_handler(self, request):
"""获取上下文API"""
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
return web.json_response({"contexts": contexts_data})
    async def debug_handler(self, request):
        """Render a self-refreshing HTML page with server/connection statistics.

        Shows WebSocket connection count, per-chat message counts, and a
        truncated preview of every stored message. The page reloads itself
        every 5 seconds via the embedded script.
        """
        debug_info = {
            "server_status": "running",
            "websocket_connections": len(self.websockets),
            "total_chats": len(self.contexts),
            "total_messages": sum(len(contexts) for contexts in self.contexts.values()),
        }

        # Build the per-chat detail HTML: one <div class="chat"> per chat,
        # each message truncated to 50 characters for readability.
        chats_html = ""
        for chat_id, contexts in self.contexts.items():
            messages_html = ""
            for msg in contexts:
                timestamp = msg.timestamp.strftime("%H:%M:%S")
                content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
                messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
            chats_html += f"""
            <div class="chat">
                <h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
                {messages_html}
            </div>
            """

        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="UTF-8">
            <title>调试信息</title>
            <style>
                body {{ font-family: monospace; margin: 20px; }}
                .section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
                .chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
                .message {{ margin: 5px 0; padding: 5px; background: white; }}
            </style>
        </head>
        <body>
            <h1>上下文网页管理器调试信息</h1>
            <div class="section">
                <h2>服务器状态</h2>
                <p>状态: {debug_info["server_status"]}</p>
                <p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
                <p>聊天总数: {debug_info["total_chats"]}</p>
                <p>消息总数: {debug_info["total_messages"]}</p>
            </div>
            <div class="section">
                <h2>聊天详情</h2>
                {chats_html}
            </div>
            <div class="section">
                <h2>操作</h2>
                <button onclick="location.reload()">刷新页面</button>
                <button onclick="window.location.href='/'">返回主页</button>
                <button onclick="window.location.href='/api/contexts'">查看API数据</button>
            </div>
            <script>
                console.log('调试信息:', {json.dumps(debug_info, ensure_ascii=False, indent=2)});
                setTimeout(() => location.reload(), 5000); // 5秒自动刷新
            </script>
        </body>
        </html>
        """
        return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv):
"""添加新消息到上下文"""
if chat_id not in self.contexts:
self.contexts[chat_id] = deque(maxlen=self.max_messages)
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
context_msg = ContextMessage(message)
self.contexts[chat_id].append(context_msg)
# 统计当前总消息数
total_messages = sum(len(contexts) for contexts in self.contexts.values())
logger.info(
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
)
# 调试:打印当前所有消息
logger.info("📝 当前上下文中的所有消息:")
for cid, contexts in self.contexts.items():
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
for i, msg in enumerate(contexts):
logger.info(
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
)
# 广播更新给所有WebSocket连接
await self.broadcast_contexts()
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
"""向单个WebSocket发送上下文数据"""
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data}
await ws.send_str(json.dumps(data, ensure_ascii=False))
async def broadcast_contexts(self):
"""向所有WebSocket连接广播上下文更新"""
if not self.websockets:
logger.debug("没有WebSocket连接跳过广播")
return
all_context_msgs = []
for _chat_id, contexts in self.contexts.items():
all_context_msgs.extend(list(contexts))
# 按时间排序,最新的在最后
all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False)
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
# 创建WebSocket列表的副本避免在遍历时修改
websockets_copy = self.websockets.copy()
removed_count = 0
for ws in websockets_copy:
if ws.closed:
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
else:
try:
await ws.send_str(message)
logger.debug("消息发送成功")
except Exception as e:
logger.error(f"发送WebSocket消息失败: {e}")
if ws in self.websockets:
self.websockets.remove(ws)
removed_count += 1
if removed_count > 0:
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
# Global singleton instance; created lazily by get_context_web_manager().
_context_web_manager: Optional[ContextWebManager] = None
def get_context_web_manager() -> ContextWebManager:
    """Return the process-wide ContextWebManager, creating it on first use.

    Returns:
        ContextWebManager: the lazily constructed module-level singleton.
    """
    global _context_web_manager
    if _context_web_manager is None:
        # Lazy construction: the manager (and its server state) only exists
        # once something actually asks for it.
        _context_web_manager = ContextWebManager()
    return _context_web_manager
async def init_context_web_manager():
    """Create (if needed) the singleton manager and start its web server.

    Returns:
        ContextWebManager: the started singleton, for convenience chaining.
    """
    manager = get_context_web_manager()
    await manager.start_server()
    return manager

View File

@ -1,147 +0,0 @@
import asyncio
from typing import Dict, Tuple, Callable, Optional
from dataclasses import dataclass
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
logger = get_logger("gift_manager")
@dataclass
class PendingGift:
    """A gift message buffered while its debounce window is still open."""

    # Most recent message for this (user, gift) pair; its gift_count is
    # rewritten to the accumulated total when repeated gifts are merged.
    message: MessageRecvS4U
    # Accumulated number of gifts received within the debounce window.
    total_count: int
    # Timer that fires when the window elapses; cancelled and recreated on merge.
    timer_task: asyncio.Task
    # Invoked with the merged message once debouncing completes.
    callback: Callable[[MessageRecvS4U], None]
class GiftManager:
    """Debounces gift messages so bursts of identical gifts merge into one.

    Gifts from the same (user_id, gift_name) pair arriving within the
    debounce window are accumulated; once the window elapses with no new
    gift, the merged message is delivered through the stored callback.
    """

    def __init__(self):
        """Initialize an empty pending-gift table."""
        self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
        self.debounce_timeout = 5.0  # debounce window in seconds (comment previously said 3s; the value 5.0 is authoritative)

    async def handle_gift(
        self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
    ) -> bool:
        """Process a gift message, returning whether to handle it immediately.

        Args:
            message: incoming message (may or may not be a gift).
            callback: invoked with the merged message once debouncing completes.

        Returns:
            bool: True for non-gift messages (process immediately); False when
            the message was buffered for debouncing.
        """
        if not message.is_gift:
            return True

        # Gifts are keyed by (sender id, gift name) so identical gifts merge.
        gift_key = (message.message_info.user_info.user_id, message.gift_name)

        # A matching gift is already pending: fold this one into it.
        if gift_key in self.pending_gifts:
            await self._merge_gift(gift_key, message)
            return False

        await self._create_pending_gift(gift_key, message, callback)
        return False

    async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
        """Fold a repeated gift into the pending entry and restart its timer."""
        pending_gift = self.pending_gifts[gift_key]

        # Restart the debounce window: cancel the old timer first.
        if not pending_gift.timer_task.cancelled():
            pending_gift.timer_task.cancel()

        try:
            new_count = int(new_message.gift_count)
            pending_gift.total_count += new_count
            # Keep the newest message object but carry the accumulated count.
            pending_gift.message = new_message
            pending_gift.message.gift_count = str(pending_gift.total_count)
            pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
        except ValueError:
            logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
            # Unparseable count: keep the previous total unchanged.

        pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
        logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")

    async def _create_pending_gift(
        self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
    ) -> None:
        """Register a brand-new pending gift and start its debounce timer."""
        try:
            initial_count = int(message.gift_count)
        except ValueError:
            initial_count = 1
            logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")

        timer_task = asyncio.create_task(self._gift_timeout(gift_key))
        pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
        self.pending_gifts[gift_key] = pending_gift
        logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")

    async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
        """Wait out the debounce window, then deliver the merged gift."""
        try:
            await asyncio.sleep(self.debounce_timeout)
            await self._finalize_gift(gift_key)
        except asyncio.CancelledError:
            # Timer was restarted by a merge; nothing to do.
            pass
        except Exception as e:
            logger.error(f"礼物防抖处理异常: {e}", exc_info=True)

    async def _finalize_gift(self, gift_key: Tuple[str, str]) -> None:
        """Pop a pending gift and fire its callback immediately (no waiting)."""
        if gift_key not in self.pending_gifts:
            return

        pending_gift = self.pending_gifts.pop(gift_key)
        logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")

        message = pending_gift.message
        message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"

        if pending_gift.callback:
            try:
                pending_gift.callback(message)
            except Exception as e:
                logger.error(f"礼物回调执行失败: {e}", exc_info=True)

    def get_pending_count(self) -> int:
        """Return the number of gifts currently buffered for debouncing."""
        return len(self.pending_gifts)

    async def flush_all(self) -> None:
        """Deliver every buffered gift immediately.

        Bug fix: this previously awaited ``_gift_timeout``, which begins with
        a full ``asyncio.sleep(self.debounce_timeout)`` — so "flushing"
        actually blocked ~5 seconds per pending gift. It now cancels each
        timer and finalizes directly, making the flush instantaneous.
        """
        for gift_key in list(self.pending_gifts.keys()):
            pending_gift = self.pending_gifts.get(gift_key)
            if pending_gift and not pending_gift.timer_task.cancelled():
                pending_gift.timer_task.cancel()
            await self._finalize_gift(gift_key)
# Module-level singleton gift debouncer shared by the whole process.
gift_manager = GiftManager()

View File

@ -1,15 +0,0 @@
class InternalManager:
    """Holds the bot's current internal thought for the live-stream persona."""

    def __init__(self):
        # Current internal-state text; empty until first set.
        self.now_internal_state = ""

    def set_internal_state(self, internal_state: str):
        """Replace the stored internal-state text."""
        self.now_internal_state = internal_state

    def get_internal_state(self):
        """Return the raw internal-state text."""
        return self.now_internal_state

    def get_internal_state_str(self):
        """Return the internal state wrapped in the livestream persona prompt."""
        return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{self.now_internal_state}"
internal_manager = InternalManager()

View File

@ -1,584 +0,0 @@
import asyncio
import traceback
import time
import random
from typing import Optional, Dict, Tuple, List # 导入类型提示
from maim_message import UserInfo, Seg
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
from src.config.config import global_config
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from .s4u_watching_manager import watching_manager
import json
from .s4u_mood_manager import mood_manager
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import get_person_id
from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat")
class MessageSenderContainer:
    """A simple container that sends queued message chunks in order while
    simulating a typing delay.

    Chunks are queued via :meth:`add_message`; a background worker task sends
    each chunk twice — once as a ``tts_text`` segment to the platform API and
    once as a plain ``text`` segment that is persisted to storage.
    """

    def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
        self.chat_stream = chat_stream
        self.original_message = original_message
        self.queue = asyncio.Queue()
        self.storage = MessageStorage()
        self._task: Optional[asyncio.Task] = None
        self._paused_event = asyncio.Event()
        self._paused_event.set()  # starts in the un-paused state
        # msg_id groups all chunks belonging to one reply; set by the caller.
        self.msg_id = ""
        self.last_msg_id = ""
        self.voice_done = ""

    async def add_message(self, chunk: str):
        """Enqueue one message chunk for sending."""
        await self.queue.put(chunk)

    async def close(self):
        """Signal that no more chunks will arrive (sentinel-based shutdown)."""
        await self.queue.put(None)  # Sentinel

    def pause(self):
        """Pause sending (worker blocks before processing the next chunk)."""
        self._paused_event.clear()

    def resume(self):
        """Resume sending after a pause."""
        self._paused_event.set()

    def _calculate_typing_delay(self, text: str) -> float:
        """Compute a simulated typing delay from the text length, clamped to
        the configured [min_typing_delay, max_typing_delay] range."""
        chars_per_second = s4u_config.chars_per_second
        min_delay = s4u_config.min_typing_delay
        max_delay = s4u_config.max_typing_delay
        delay = len(text) / chars_per_second
        return max(min_delay, min(delay, max_delay))

    async def _send_worker(self):
        """Drain the queue: send each chunk to the platform, then persist it."""
        while True:
            try:
                # This structure ensures that task_done() is called for every item retrieved,
                # even if the worker is cancelled while processing the item.
                chunk = await self.queue.get()
            except asyncio.CancelledError:
                break

            try:
                if chunk is None:
                    break

                # Check for pause signal *after* getting an item.
                await self._paused_event.wait()

                # Choose delay mode based on configuration.
                if s4u_config.enable_dynamic_typing_delay:
                    delay = self._calculate_typing_delay(chunk)
                else:
                    delay = s4u_config.typing_delay
                await asyncio.sleep(delay)

                # First send: TTS segment pushed to the platform API.
                message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
                bot_message = MessageSending(
                    message_id=self.msg_id,
                    chat_stream=self.chat_stream,
                    bot_user_info=UserInfo(
                        user_id=global_config.bot.qq_account,
                        user_nickname=global_config.bot.nickname,
                        platform=self.original_message.message_info.platform,
                    ),
                    sender_info=self.original_message.message_info.user_info,
                    message_segment=message_segment,
                    reply=self.original_message,
                    is_emoji=False,
                    apply_set_reply_logic=True,
                    reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
                )
                await bot_message.process()
                await get_global_api().send_message(bot_message)
                logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")

                # Second pass: plain-text segment, persisted to storage only.
                message_segment = Seg(type="text", data=chunk)
                bot_message = MessageSending(
                    message_id=self.msg_id,
                    chat_stream=self.chat_stream,
                    bot_user_info=UserInfo(
                        user_id=global_config.bot.qq_account,
                        user_nickname=global_config.bot.nickname,
                        platform=self.original_message.message_info.platform,
                    ),
                    sender_info=self.original_message.message_info.user_info,
                    message_segment=message_segment,
                    reply=self.original_message,
                    is_emoji=False,
                    apply_set_reply_logic=True,
                    reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
                )
                await bot_message.process()
                await self.storage.store_message(bot_message, self.chat_stream)

            except Exception as e:
                logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
            finally:
                # CRUCIAL: Always call task_done() for any item that was successfully retrieved.
                self.queue.task_done()

    def start(self):
        """Start the background send worker (idempotent)."""
        if self._task is None:
            self._task = asyncio.create_task(self._send_worker())

    async def join(self):
        """Wait until the worker task finishes (i.e. the queue was closed)."""
        if self._task:
            await self._task
class S4UChatManager:
    """Registry mapping chat stream ids to their S4UChat instances."""

    def __init__(self):
        # One S4UChat per chat stream, keyed by stream_id.
        self.s4u_chats: Dict[str, "S4UChat"] = {}

    def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
        """Return the chat bound to this stream, constructing it on first access."""
        stream_id = chat_stream.stream_id
        existing = self.s4u_chats.get(stream_id)
        if existing is not None:
            return existing
        stream_name = get_chat_manager().get_stream_name(stream_id) or stream_id
        logger.info(f"Creating new S4UChat for stream: {stream_name}")
        chat = S4UChat(chat_stream)
        self.s4u_chats[stream_id] = chat
        return chat
# Instantiate the global manager only when the S4U feature is enabled;
# otherwise it stays None and callers must handle the disabled case.
if not s4u_config.enable_s4u:
    s4u_chat_manager = None
else:
    s4u_chat_manager = S4UChatManager()
def get_s4u_chat_manager() -> Optional["S4UChatManager"]:
    """Return the global S4UChatManager, or None when S4U is disabled.

    Fix: the return annotation previously claimed a non-optional manager,
    but the module-level guard leaves ``s4u_chat_manager`` as ``None`` when
    ``s4u_config.enable_s4u`` is false; callers must check for ``None``.
    """
    return s4u_chat_manager
class S4UChat:
    """Per-stream chat engine with a two-queue (VIP / normal) reply scheduler.

    Incoming messages are scored, queued, and consumed by a background
    processor task; an in-flight reply can be interrupted (task cancellation)
    by higher-priority messages according to the rules in :meth:`add_message`.
    """

    def __init__(self, chat_stream: ChatStream):
        """Initialize the S4UChat instance and start its processor task."""
        self.chat_stream = chat_stream
        self.stream_id = chat_stream.stream_id
        self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id

        # Two priority queues: VIP messages are always served first.
        self._vip_queue = asyncio.PriorityQueue()
        self._normal_queue = asyncio.PriorityQueue()
        self._entry_counter = 0  # global counter guaranteeing FIFO among equal priorities
        self._new_message_event = asyncio.Event()  # wakes up the processor loop

        self._processing_task = asyncio.create_task(self._message_processor())
        self._current_generation_task: Optional[asyncio.Task] = None
        # Metadata of the message currently being replied to:
        # (queue name, priority score, entry counter, message object)
        self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None

        self._is_replying = False
        self.gpt = S4UStreamGenerator()
        self.gpt.chat_stream = self.chat_stream

        self.interest_dict: Dict[str, float] = {}  # per-user interest score

        self.internal_message: List[MessageRecvS4U] = []

        # NOTE(review): self.last_msg_id is first assigned in
        # get_processing_message_id(); go_processing() reads it, so it must
        # only run after _generate_and_send() has started — confirm callers.
        self.msg_id = ""
        self.voice_done = ""

        logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")

    def _get_priority_info(self, message: MessageRecv) -> dict:
        """Safely extract and parse priority_info from a message (str or dict)."""
        priority_info_raw = message.priority_info
        priority_info = {}
        if isinstance(priority_info_raw, str):
            try:
                priority_info = json.loads(priority_info_raw)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
        elif isinstance(priority_info_raw, dict):
            priority_info = priority_info_raw
        return priority_info

    def _is_vip(self, priority_info: dict) -> bool:
        """Return True when the message's priority_info marks it as VIP."""
        return priority_info.get("message_type") == "vip"

    def _get_interest_score(self, user_id: str) -> float:
        """Return the user's interest score, defaulting to 1.0."""
        return self.interest_dict.get(user_id, 1.0)

    def go_processing(self):
        # True once the TTS side reports the last dispatched msg_id as voiced.
        if self.voice_done == self.last_msg_id:
            return True
        return False

    def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
        """
        Compute the base priority score for a message — higher means more urgent.
        """
        score = 0.0
        # Add the message's own declared priority.
        score += priority_info.get("message_priority", 0.0)
        # Add the sender's intrinsic interest score.
        score += self._get_interest_score(message.message_info.user_info.user_id)
        return score

    def decay_interest_score(self):
        # Multiplicative decay toward zero; values never go negative.
        # (Only values are rewritten, so mutating during iteration is safe.)
        for person_id, score in self.interest_dict.items():
            if score > 0:
                self.interest_dict[person_id] = score * 0.95
            else:
                self.interest_dict[person_id] = 0

    async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
        """Route a message into the VIP/normal queue, applying interrupt rules."""
        self.decay_interest_score()
        user_id = message.message_info.user_info.user_id
        platform = message.message_info.platform
        _person_id = get_person_id(platform, user_id)

        # try:
        #     is_gift = message.is_gift
        #     is_superchat = message.is_superchat

        #     # print(is_gift)
        #     # print(is_superchat)

        #     if is_gift:
        #         await self.relationship_builder.build_relation(immediate_build=person_id)
        #         # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
        #         current_score = self.interest_dict.get(person_id, 1.0)
        #         self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
        #     elif is_superchat:
        #         await self.relationship_builder.build_relation(immediate_build=person_id)
        #         # 安全地增加兴趣分如果person_id不存在则先初始化为1.0
        #         current_score = self.interest_dict.get(person_id, 1.0)
        #         self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)

        #         # 添加SuperChat到管理器
        #         super_chat_manager = get_super_chat_manager()
        #         await super_chat_manager.add_superchat(message)
        #     else:
        #         await self.relationship_builder.build_relation(20)
        # except Exception:
        #     traceback.print_exc()

        logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")

        priority_info = self._get_priority_info(message)
        is_vip = self._is_vip(priority_info)
        new_priority_score = self._calculate_base_priority_score(message, priority_info)

        should_interrupt = False
        if (
            s4u_config.enable_message_interruption
            and self._current_generation_task
            and not self._current_generation_task.done()
        ):
            if self._current_message_being_replied:
                current_queue, current_priority, _, current_msg = self._current_message_being_replied

                # Rule: a VIP reply is never interrupted.
                if current_queue == "vip":
                    pass  # Do nothing
                # Rule: a normal reply may be interrupted.
                elif current_queue == "normal":
                    # A VIP message interrupts any normal reply.
                    if is_vip:
                        should_interrupt = True
                        logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
                    # Interrupt logic among normal messages:
                    else:
                        new_sender_id = message.message_info.user_info.user_id
                        current_sender_id = current_msg.message_info.user_info.user_id

                        # The new message scores strictly higher.
                        if new_priority_score > current_priority:
                            should_interrupt = True
                            logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
                        # Same sender, and the new message scores at least as high.
                        elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
                            should_interrupt = True
                            logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")

        if should_interrupt:
            if self.gpt.partial_response:
                logger.warning(
                    f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
                )
            self._current_generation_task.cancel()

        # asyncio.PriorityQueue is a min-heap, so we store the negated score:
        # the higher the original score, the smaller (earlier) the queue key.
        item = (-new_priority_score, self._entry_counter, time.time(), message)
        if is_vip and s4u_config.vip_queue_priority:
            await self._vip_queue.put(item)
            logger.info(f"[{self.stream_name}] VIP message added to queue.")
        else:
            await self._normal_queue.put(item)
        self._entry_counter += 1
        self._new_message_event.set()  # wake the processor

    def _cleanup_old_normal_messages(self):
        """Drop normal-queue messages that fall outside the most recent N entries."""
        if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
            return

        # Threshold: keep the most recent recent_message_keep_count entries.
        cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)

        # Temporarily hold the messages that survive the cleanup.
        temp_messages = []
        removed_count = 0

        # Drain the whole normal queue.
        while not self._normal_queue.empty():
            try:
                item = self._normal_queue.get_nowait()
                neg_priority, entry_count, timestamp, message = item
                # Keep the message if it is within the most recent N entries.
                logger.info(
                    f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
                )
                if entry_count >= cutoff_counter:
                    temp_messages.append(item)
                else:
                    removed_count += 1
                    self._normal_queue.task_done()  # mark the dropped item as done
            except asyncio.QueueEmpty:
                break

        # Re-enqueue the survivors.
        for item in temp_messages:
            self._normal_queue.put_nowait(item)

        if removed_count > 0:
            logger.info(
                f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除"
            )
            logger.info(
                f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
            )

    async def _message_processor(self):
        """Scheduler loop: serve the VIP queue first, then the normal queue."""
        while True:
            try:
                # Block until a new message arrives to avoid busy-waiting.
                await self._new_message_event.wait()
                self._new_message_event.clear()

                # Purge stale entries from the normal queue.
                self._cleanup_old_normal_messages()

                # VIP queue has absolute precedence.
                if not self._vip_queue.empty():
                    neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
                    priority = -neg_priority
                    queue_name = "vip"
                # Then the normal queue.
                elif not self._normal_queue.empty():
                    neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
                    priority = -neg_priority
                    # Discard normal messages that waited too long.
                    if time.time() - timestamp > s4u_config.message_timeout_seconds:
                        logger.info(
                            f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
                        )
                        self._normal_queue.task_done()
                        continue  # move on to the next item
                    queue_name = "normal"
                else:
                    # Both queues empty: fall back to the latest internal message.
                    if self.internal_message:
                        message = self.internal_message[-1]
                        self.internal_message = []
                        priority = 0
                        neg_priority = 0
                        entry_count = 0
                        queue_name = "internal"
                        logger.info(
                            f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
                        )
                    else:
                        continue  # nothing to do; go back to waiting

                self._current_message_being_replied = (queue_name, priority, entry_count, message)
                self._current_generation_task = asyncio.create_task(self._generate_and_send(message))

                try:
                    await self._current_generation_task
                except asyncio.CancelledError:
                    logger.info(
                        f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
                    )
                    # An interrupted message is dropped, not re-queued, so the
                    # bot responds to the newest input; the old re-enqueue
                    # logic made every interrupted message eventually answered.
                except Exception as e:
                    logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
                finally:
                    self._current_generation_task = None
                    self._current_message_being_replied = None
                    # Mark the dequeued task as finished.
                    if queue_name == "vip":
                        self._vip_queue.task_done()
                    elif queue_name == "internal":
                        # Internal messages never came from the normal queue.
                        pass
                    else:
                        self._normal_queue.task_done()

                # If work remains, re-arm the event so the loop runs again now.
                if not self._vip_queue.empty() or not self._normal_queue.empty():
                    self._new_message_event.set()

            except asyncio.CancelledError:
                logger.info(f"[{self.stream_name}] Message processor is shutting down.")
                break
            except Exception as e:
                logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
                await asyncio.sleep(1)

    def get_processing_message_id(self):
        # Rotate ids: remember the previous one, then mint a fresh unique id.
        self.last_msg_id = self.msg_id
        self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"

    async def _generate_and_send(self, message: MessageRecv):
        """Generate a text reply for a single message. The whole flow is cancellable."""
        self._is_replying = True
        total_chars_sent = 0  # tracks the total number of characters sent
        self.get_processing_message_id()

        # Gaze management: switch gaze state when reply generation starts.
        chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
        if message.is_internal:
            await chat_watching.on_internal_message_start()
        else:
            await chat_watching.on_reply_start()

        sender_container = MessageSenderContainer(self.chat_stream, message)
        sender_container.start()

        async def generate_and_send_inner():
            # Streams (or batches) LLM output into the sender container.
            nonlocal total_chars_sent
            logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")

            if s4u_config.enable_streaming_output:
                logger.info("[S4U] 开始流式输出")
                # Streaming mode: forward each chunk as soon as it is generated.
                gen = self.gpt.generate_response(message, "")
                async for chunk in gen:
                    sender_container.msg_id = self.msg_id
                    await sender_container.add_message(chunk)
                    total_chars_sent += len(chunk)
            else:
                logger.info("[S4U] 开始一次性输出")
                # Batch mode: collect all chunks first...
                all_chunks = []
                gen = self.gpt.generate_response(message, "")
                async for chunk in gen:
                    all_chunks.append(chunk)
                    total_chars_sent += len(chunk)
                # ...then send them in one shot.
                sender_container.msg_id = self.msg_id
                await sender_container.add_message("".join(all_chunks))

        try:
            try:
                await asyncio.wait_for(generate_and_send_inner(), timeout=10)
            except asyncio.TimeoutError:
                logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
                sender_container.msg_id = self.msg_id
                await sender_container.add_message("麦麦不知道哦")
                total_chars_sent = len("麦麦不知道哦")

            mood = mood_manager.get_mood_by_chat_id(self.stream_id)

            # NOTE(review): `text` receives a character COUNT (int), not the
            # reply text — confirm yes_or_no_head's expected argument type.
            await yes_or_no_head(
                text=total_chars_sent,
                emotion=mood.mood_state,
                chat_history=message.processed_plain_text,
                chat_id=self.stream_id,
            )

            # Wait for every queued text chunk to be sent.
            await sender_container.close()
            await sender_container.join()

            await chat_watching.on_thinking_finished()

            start_time = time.time()
            logged = False
            # Poll until the TTS side confirms playback, capped at 60 seconds.
            while not self.go_processing():
                if time.time() - start_time > 60:
                    logger.warning(f"[{self.stream_name}] 等待消息发送超时60秒强制跳出循环。")
                    break
                if not logged:
                    logger.info(f"[{self.stream_name}] 等待消息发送完成...")
                    logged = True
                await asyncio.sleep(0.2)

            logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")

        except asyncio.CancelledError:
            logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
            raise  # propagate the cancellation upward
        except Exception as e:
            traceback.print_exc()
            logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
            # Live reply display: content is cleared (error path).
        finally:
            self._is_replying = False
            # Gaze management: switch gaze state when the reply ends.
            chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
            await chat_watching.on_reply_finished()
            # Ensure the sender shuts down cleanly (safe to call twice).
            sender_container.resume()
            if not sender_container._task.done():
                await sender_container.close()
                await sender_container.join()
            logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")

    async def shutdown(self):
        """Gracefully shut down the processing tasks."""
        logger.info(f"正在关闭 S4UChat: {self.stream_name}")
        # Cancel any in-flight work.
        if self._current_generation_task and not self._current_generation_task.done():
            self._current_generation_task.cancel()
        if self._processing_task and not self._processing_task.done():
            self._processing_task.cancel()
        # Wait for the processor task to acknowledge cancellation.
        try:
            await self._processing_task
        except asyncio.CancelledError:
            logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@ -1,456 +0,0 @@
import asyncio
import json
import time
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.s4u_config import s4u_config
"""
情绪管理系统使用说明
1. 情绪数值系统
- 情绪包含四个维度joy(), anger(), sorrow(), fear()
- 每个维度的取值范围为1-10
- 当情绪发生变化时会自动发送到ws端处理
2. 情绪更新机制
- 接收到新消息时会更新情绪状态
- 定期进行情绪回归冷静下来
- 每次情绪变化都会发送到ws端格式为
type: "emotion"
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
3. ws端处理
- 本地只负责情绪计算和发送情绪数值
- 表情渲染和动作由ws端根据情绪数值处理
"""
logger = get_logger("mood")
def init_prompt():
    """Register the four mood prompt templates with the global prompt manager.

    Two textual prompts (change / regress) produce a one-sentence mood
    description; two numerical prompts produce a JSON object with the four
    emotion dimensions joy/anger/sorrow/fear, each in the range 1-10.
    The template bodies are runtime strings consumed by the LLM and are left
    verbatim (Chinese).
    """
    # Textual mood update: triggered by a new message in the live chat.
    Prompt(
        """
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是{mood_state}
现在发送了消息引起了你的注意你对其进行了阅读和思考请你输出一句话描述你新的情绪状态不要输出任何其他内容
请只输出情绪状态不要输出其他内容
""",
        "change_mood_prompt_vtb",
    )

    # Textual mood regression: fired after a period of inactivity (calming down).
    Prompt(
        """
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是{mood_state}
距离你上次关注直播间消息已经过去了一段时间你冷静了下来请你输出一句话描述你现在的情绪状态
请只输出情绪状态不要输出其他内容
""",
        "regress_mood_prompt_vtb",
    )

    # Numerical mood update: the model must answer with JSON only.
    Prompt(
        """
{chat_talking_prompt}
以上是直播间里正在进行的对话
{indentify_block}
你刚刚的情绪状态是{mood_state}
具体来说从1-10你的情绪状态是
(Joy): {joy}
(Anger): {anger}
(Sorrow): {sorrow}
(Fear): {fear}
现在发送了消息引起了你的注意你对其进行了阅读和思考请基于对话内容评估你新的情绪状态
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON
""",
        "change_mood_numerical_prompt",
    )

    # Numerical mood regression: same JSON contract as above, calming variant.
    Prompt(
        """
{chat_talking_prompt}
以上是直播间里最近的对话
{indentify_block}
你之前的情绪状态是{mood_state}
具体来说从1-10你的情绪状态是
(Joy): {joy}
(Anger): {anger}
(Sorrow): {sorrow}
(Fear): {fear}
距离你上次关注直播间消息已经过去了一段时间你冷静了下来请基于此评估你现在的情绪状态
请以JSON格式输出你新的情绪状态包含"喜怒哀惧"四个维度每个维度的取值范围为1-10
键值请使用英文: "joy", "anger", "sorrow", "fear".
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
不要输出任何其他内容只输出JSON
""",
        "regress_mood_numerical_prompt",
    )
class ChatMood:
    """Per-chat emotional state.

    Tracks a free-text mood description plus four numerical dimensions
    (joy / anger / sorrow / fear, each 1-10). Both representations are updated
    via LLM calls: on incoming messages (update_mood_by_message) and by the
    periodic cool-down (regress_mood). Every numerical change is pushed to the
    ws side via send_emotion_update; rendering happens there, not locally.
    """
    def __init__(self, chat_id: str):
        # Identifier of the chat stream this mood belongs to.
        self.chat_id: str = chat_id
        # Natural-language mood description fed back into the prompts.
        self.mood_state: str = "感觉很平静"
        # Numerical mood: one 1-10 value per dimension.
        self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
        # Count of cool-down regressions since the last message (capped by the
        # regression task; reset to 0 on every new message).
        self.regression_count: int = 0
        # Separate LLM handles for the text and the numerical mood prompts.
        self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
        self.mood_model_numerical = LLMRequest(
            model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
        )
        # Timestamp of the last message-driven mood change (0 = never).
        self.last_change_time: float = 0
        # Push the initial emotion values to the ws side.
        # NOTE(review): create_task requires a running event loop at
        # construction time — confirm all call sites run inside the loop.
        asyncio.create_task(self.send_emotion_update(self.mood_values))
    def _parse_numerical_mood(self, response: str) -> dict[str, int] | None:
        """Parse the LLM's JSON mood reply into {joy, anger, sorrow, fear}.

        Strips markdown code fences first, then requires all four keys with
        integer values in 1-10. Returns None on any parse/validation failure
        (logged, never raised), so callers simply keep the previous values.
        """
        try:
            # The LLM might output markdown with json inside
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0]
            elif "```" in response:
                response = response.split("```")[1].split("```")[0]
            data = json.loads(response)
            # Validate
            required_keys = {"joy", "anger", "sorrow", "fear"}
            if not required_keys.issubset(data.keys()):
                logger.warning(f"Numerical mood response missing keys: {response}")
                return None
            for key in required_keys:
                value = data[key]
                # Rejects floats like 5.0 as well as out-of-range ints.
                if not isinstance(value, int) or not (1 <= value <= 10):
                    logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
                    return None
            return {key: data[key] for key in required_keys}
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse numerical mood JSON: {response}")
            return None
        except Exception as e:
            logger.error(f"Error parsing numerical mood: {e}, response: {response}")
            return None
    async def update_mood_by_message(self, message: MessageRecv):
        """Re-evaluate both mood representations after a new message.

        Builds a readable transcript of the messages since the last change,
        runs the text and numerical mood prompts concurrently, applies
        whichever results parsed successfully, and pushes numerical changes to
        the ws side. Also resets the regression counter.
        """
        self.regression_count = 0
        message_time: float = message.message_info.time  # type: ignore
        # Context window: up to 10 messages between the last change and now.
        message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_change_time,
            timestamp_end=message_time,
            limit=10,
            limit_mode="last",
        )
        chat_talking_prompt = build_readable_messages(
            message_list_before_now,
            replace_bot_name=True,
            timestamp_mode="normal_no_YMD",
            read_mark=0.0,
            truncate=True,
            show_actions=True,
        )
        bot_name = global_config.bot.nickname
        if global_config.bot.alias_names:
            bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
        else:
            bot_nickname = ""
        prompt_personality = global_config.personality.personality
        # Identity block shared by both prompts.
        indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
        async def _update_text_mood():
            # One-sentence free-text mood description.
            prompt = await global_prompt_manager.format_prompt(
                "change_mood_prompt_vtb",
                chat_talking_prompt=chat_talking_prompt,
                indentify_block=indentify_block,
                mood_state=self.mood_state,
            )
            logger.debug(f"text mood prompt: {prompt}")
            response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
                prompt=prompt, temperature=0.7
            )
            logger.info(f"text mood response: {response}")
            logger.debug(f"text mood reasoning_content: {reasoning_content}")
            return response
        async def _update_numerical_mood():
            # JSON joy/anger/sorrow/fear values; lower temperature for
            # more deterministic structured output.
            prompt = await global_prompt_manager.format_prompt(
                "change_mood_numerical_prompt",
                chat_talking_prompt=chat_talking_prompt,
                indentify_block=indentify_block,
                mood_state=self.mood_state,
                joy=self.mood_values["joy"],
                anger=self.mood_values["anger"],
                sorrow=self.mood_values["sorrow"],
                fear=self.mood_values["fear"],
            )
            logger.debug(f"numerical mood prompt: {prompt}")
            response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
                prompt=prompt, temperature=0.4
            )
            logger.info(f"numerical mood response: {response}")
            logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
            return self._parse_numerical_mood(response)
        # Run both LLM calls concurrently.
        results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
        text_mood_response, numerical_mood_response = results
        if text_mood_response:
            self.mood_state = text_mood_response
        if numerical_mood_response:
            _old_mood_values = self.mood_values.copy()
            self.mood_values = numerical_mood_response
            # Push the emotion update to the ws side.
            await self.send_emotion_update(self.mood_values)
            logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
        self.last_change_time = message_time
    async def regress_mood(self):
        """Cool the mood down toward baseline (called by MoodRegressionTask).

        Same two-prompt pattern as update_mood_by_message but with the
        "regress" templates and a smaller context window; increments
        regression_count instead of touching last_change_time.
        """
        message_time = time.time()
        message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_change_time,
            timestamp_end=message_time,
            limit=5,
            limit_mode="last",
        )
        chat_talking_prompt = build_readable_messages(
            message_list_before_now,
            replace_bot_name=True,
            timestamp_mode="normal_no_YMD",
            read_mark=0.0,
            truncate=True,
            show_actions=True,
        )
        bot_name = global_config.bot.nickname
        if global_config.bot.alias_names:
            bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
        else:
            bot_nickname = ""
        prompt_personality = global_config.personality.personality
        indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
        async def _regress_text_mood():
            # Free-text cooled-down mood description.
            prompt = await global_prompt_manager.format_prompt(
                "regress_mood_prompt_vtb",
                chat_talking_prompt=chat_talking_prompt,
                indentify_block=indentify_block,
                mood_state=self.mood_state,
            )
            logger.debug(f"text regress prompt: {prompt}")
            response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
                prompt=prompt, temperature=0.7
            )
            logger.info(f"text regress response: {response}")
            logger.debug(f"text regress reasoning_content: {reasoning_content}")
            return response
        async def _regress_numerical_mood():
            # JSON cooled-down emotion values.
            prompt = await global_prompt_manager.format_prompt(
                "regress_mood_numerical_prompt",
                chat_talking_prompt=chat_talking_prompt,
                indentify_block=indentify_block,
                mood_state=self.mood_state,
                joy=self.mood_values["joy"],
                anger=self.mood_values["anger"],
                sorrow=self.mood_values["sorrow"],
                fear=self.mood_values["fear"],
            )
            logger.debug(f"numerical regress prompt: {prompt}")
            response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
                prompt=prompt,
                temperature=0.4,
            )
            logger.info(f"numerical regress response: {response}")
            logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
            return self._parse_numerical_mood(response)
        results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
        text_mood_response, numerical_mood_response = results
        if text_mood_response:
            self.mood_state = text_mood_response
        if numerical_mood_response:
            _old_mood_values = self.mood_values.copy()
            self.mood_values = numerical_mood_response
            # Push the emotion update to the ws side.
            await self.send_emotion_update(self.mood_values)
            logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
        self.regression_count += 1
    async def send_emotion_update(self, mood_values: dict[str, int]):
        """Send the current numerical emotion state to the ws side.

        Payload format: {"type": "emotion", "data": {joy, anger, sorrow,
        fear}} — missing keys fall back to the calm baseline. Not stored in
        message history (storage_message=False).
        """
        emotion_data = {
            "joy": mood_values.get("joy", 5),
            "anger": mood_values.get("anger", 1),
            "sorrow": mood_values.get("sorrow", 1),
            "fear": mood_values.get("fear", 1),
        }
        await send_api.custom_to_stream(
            message_type="emotion",
            content=emotion_data,
            stream_id=self.chat_id,
            storage_message=False,
            show_log=True,
        )
        logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
class MoodRegressionTask(AsyncTask):
    """Background task that periodically cools chats' moods back to baseline.

    Runs every 30 seconds. A chat regresses when either the normal delay has
    passed since its last mood change, or an extreme emotion value is present
    and the shorter fast delay has elapsed. Each chat regresses at most
    MAX_REGRESSION_COUNT times until a new message resets its counter
    (ChatMood.update_mood_by_message sets regression_count back to 0).
    """

    # Seconds since the last mood change before a normal regression fires.
    NORMAL_REGRESS_DELAY = 120
    # Shorter delay used when any emotion dimension is extreme.
    FAST_REGRESS_DELAY = 30
    # An emotion value >= this threshold counts as extreme.
    EXTREME_EMOTION_THRESHOLD = 8
    # Maximum consecutive regressions per chat before giving up.
    MAX_REGRESSION_COUNT = 3

    def __init__(self, mood_manager: "MoodManager"):
        super().__init__(task_name="MoodRegressionTask", run_interval=30)
        self.mood_manager = mood_manager
        # Number of completed run() invocations, for logging only.
        self.run_count = 0

    async def run(self):
        """One scheduler tick: regress every chat that meets the conditions."""
        self.run_count += 1
        logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
        now = time.time()
        regression_executed = 0
        for mood in self.mood_manager.mood_list:
            chat_info = f"chat {mood.chat_id}"
            # A chat that never had a mood change has nothing to regress.
            if mood.last_change_time == 0:
                logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
                continue
            time_since_last_change = now - mood.last_change_time
            # Detect extreme emotions that warrant a faster cool-down.
            high_emotions = {k: v for k, v in mood.mood_values.items() if v >= self.EXTREME_EMOTION_THRESHOLD}
            has_extreme_emotion = len(high_emotions) > 0
            # Regress when either: 1) the normal delay passed, or
            # 2) an extreme emotion exists and the fast delay passed.
            should_regress = False
            regress_reason = ""
            if time_since_last_change > self.NORMAL_REGRESS_DELAY:
                should_regress = True
                regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
            elif has_extreme_emotion and time_since_last_change > self.FAST_REGRESS_DELAY:
                should_regress = True
                high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
                regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
            if should_regress:
                if mood.regression_count >= self.MAX_REGRESSION_COUNT:
                    logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
                    continue
                logger.info(
                    f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
                )
                await mood.regress_mood()
                regression_executed += 1
            else:
                if has_extreme_emotion:
                    # BUG FIX: the countdown must use the fast-regression
                    # delay (30s); it was hard-coded to 5 and logged negative
                    # waiting times.
                    remaining_time = self.FAST_REGRESS_DELAY - time_since_last_change
                    high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
                    logger.debug(
                        f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}"
                    )
                else:
                    remaining_time = self.NORMAL_REGRESS_DELAY - time_since_last_change
                    logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}")
        if regression_executed > 0:
            logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
        else:
            logger.debug("[回归任务] 本次没有符合回归条件的聊天")
class MoodManager:
    """Registry of per-chat ChatMood objects plus the shared regression task."""

    def __init__(self):
        # One ChatMood per chat stream, created lazily on first access.
        self.mood_list: list[ChatMood] = []
        # Guards against registering the regression task more than once.
        self.task_started: bool = False

    async def start(self):
        """Launch the background mood-regression task (idempotent)."""
        if self.task_started:
            return
        logger.info("启动情绪管理任务...")
        # Register the periodic regression task with the global scheduler.
        await async_task_manager.add_task(MoodRegressionTask(self))
        self.task_started = True
        logger.info("情绪管理任务已启动(情绪回归)")

    def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
        """Return the mood for *chat_id*, creating and registering one if absent."""
        existing = next((m for m in self.mood_list if m.chat_id == chat_id), None)
        if existing is not None:
            return existing
        created = ChatMood(chat_id)
        self.mood_list.append(created)
        return created

    def reset_mood_by_chat_id(self, chat_id: str):
        """Reset *chat_id*'s mood to the calm baseline and push it to the ws side."""
        for existing in self.mood_list:
            if existing.chat_id != chat_id:
                continue
            existing.mood_state = "感觉很平静"
            existing.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
            existing.regression_count = 0
            # Notify the ws side of the reset values.
            asyncio.create_task(existing.send_emotion_update(existing.mood_values))
            return
        # No existing mood for this chat — create one; its baseline values are
        # pushed to the ws side immediately.
        created = ChatMood(chat_id)
        self.mood_list.append(created)
        asyncio.create_task(created.send_emotion_update(created.mood_values))
# Global mood manager singleton; remains None when the S4U feature is disabled.
mood_manager = None
if s4u_config.enable_s4u:
    init_prompt()
    mood_manager = MoodManager()

View File

@ -1,255 +0,0 @@
import asyncio
import math
from typing import Tuple
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from maim_message.message_base import GroupInfo
from src.chat.message_receive.storage import MessageStorage
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
from src.mais4u.mais4u_chat.gift_manager import gift_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from .s4u_chat import get_s4u_chat_manager
# from ..message_receive.message_buffer import message_buffer
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool]: (兴趣度, 是否被提及)
"""
is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0
if global_config.memory.enable_memory:
with Timer("记忆激活"):
interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)
logger.debug(f"记忆激活率: {interested_rate:.2f}")
text_len = len(message.processed_plain_text)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
interested_rate += base_interest
if is_mentioned:
interest_increase_on_mention = 1
interested_rate += interest_increase_on_mention
return interested_rate, is_mentioned
class S4UMessageProcessor:
    """Heart-flow processor: routes incoming S4U messages and computes interest.

    Dispatch order in process_message matters: internal messages, voice-done
    signals, gift debouncing and screen updates each short-circuit the
    pipeline before a message is stored and handed to the chat/mood/action/
    watching subsystems.
    """
    def __init__(self):
        # Persistent storage backend for accepted messages.
        self.storage = MessageStorage()
    async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
        """Process one received message.

        Flow: resolve the chat stream; short-circuit on internal / voice-done /
        gift-debounced / screen messages; otherwise store the message, feed the
        S4U chat, compute interest, and fan out async mood/action/watching/
        context-web updates.

        Args:
            message: the received S4U message.
            skip_gift_debounce: True when re-entering with a merged gift
                message from the debounce callback, to avoid re-debouncing.
        """
        # 1. Resolve (or create) the chat stream for this platform/user/group.
        groupinfo = message.message_info.group_info
        userinfo = message.message_info.user_info
        message_info = message.message_info
        chat = await get_chat_manager().get_or_create_stream(
            platform=message_info.platform,
            user_info=userinfo,
            group_info=groupinfo,
        )
        # Internal (self-talk) messages are rerouted and stop here.
        if await self.handle_internal_message(message):
            return
        # Voice-playback-finished signals only flip a flag.
        if await self.hadle_if_voice_done(message):
            return
        # Gift messages are debounced: handle_if_gift returns False when the
        # message was stashed for merging, which ends this pass.
        if not skip_gift_debounce and not await self.handle_if_gift(message):
            return
        await self.check_if_fake_gift(message)
        # Screen-content messages update the screen manager and stop here.
        if await self.handle_screen_message(message):
            return
        await self.storage.store_message(message, chat)
        s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
        await s4u_chat.add_message(message)
        # Interest value is currently computed but unused downstream.
        _interested_rate, _ = await _calculate_interest(message)
        # Idempotent: starts the mood regression task on first call only.
        await mood_manager.start()
        # LLM-driven pre-processing, fired without awaiting completion.
        chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
        asyncio.create_task(chat_mood.update_mood_by_message(message))
        chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
        asyncio.create_task(chat_action.update_action_by_message(message))
        # Gaze management: switch gaze state when a message arrives.
        chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
        await chat_watching.on_message_received()
        # Context web page: handled in its own task.
        asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
        # Logging.
        if message.is_gift:
            logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
        else:
            logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
    async def handle_internal_message(self, message: MessageRecvS4U):
        """Reroute internal (self-generated) messages into the fixed internal
        chat stream; returns True when the message was consumed here."""
        if message.is_internal:
            # Fixed pseudo-group that hosts the bot's "inner voice" stream.
            group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
            chat = await get_chat_manager().get_or_create_stream(
                platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
            )
            s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
            # Rewrite routing info so replies land in the internal stream.
            message.message_info.group_info = s4u_chat.chat_stream.group_info
            message.message_info.platform = s4u_chat.chat_stream.platform
            s4u_chat.internal_message.append(message)
            # Wake the chat loop waiting on new messages.
            s4u_chat._new_message_event.set()
            logger.info(
                f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
            )
            return True
        return False
    async def handle_screen_message(self, message: MessageRecvS4U):
        """Forward screen-content messages to the screen manager; True if consumed."""
        if message.is_screen:
            screen_manager.set_screen(message.screen_info)
            return True
        return False
    async def hadle_if_voice_done(self, message: MessageRecvS4U):
        """Record a voice-playback-finished signal; True if consumed.

        NOTE(review): method name has a typo ("hadle"); kept because the
        call in process_message (and possibly external callers) uses it.
        """
        if message.voice_done:
            s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
            s4u_chat.voice_done = message.voice_done
            return True
        return False
    async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
        """Flag plain danmaku that merely *claims* to be a gift.

        Real gifts (is_gift True) are never flagged; otherwise any gift-like
        keyword in the text marks the message as a fake gift so the prompt
        can call it out.
        """
        if message.is_gift:
            return False
        gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
        if any(keyword in message.processed_plain_text for keyword in gift_keywords):
            message.is_fake_gift = True
            return True
        return False
    async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
        """Debounce gift messages.

        Returns:
            bool: True to continue normal processing (non-gift); False when
            the gift was stashed by the gift manager for merging — the merged
            message re-enters process_message via the callback with
            skip_gift_debounce=True.
        """
        if message.is_gift:
            # Callback invoked by the gift manager once debouncing finishes.
            def gift_callback(merged_message: MessageRecvS4U):
                """Re-process the merged gift message, skipping debouncing."""
                asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
            # Hand over to the gift manager; for gifts handle_gift always
            # stashes the message, so we always return False here.
            await gift_manager.handle_gift(message, gift_callback)
            return False
        return True
    async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
        """Background task: push the message into the context web page.

        Args:
            chat_id: chat stream id.
            message: the message to append to the context view.
        """
        try:
            logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
            context_manager = get_context_web_manager()
            # Start the web server only once (site is None until started).
            if context_manager.site is None:
                logger.info("🚀 首次启动上下文网页服务器...")
                await context_manager.start_server()
            # Small delay before updating the page.
            # NOTE(review): fixed 1.5s sleep — presumably waits for the
            # server/clients to be ready; confirm whether it is still needed.
            await asyncio.sleep(1.5)
            await context_manager.add_message(chat_id, message)
            logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
        except Exception as e:
            logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)

View File

@ -1,400 +0,0 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker
from src.chat.memory_system.Hippocampus import hippocampus_manager
import random
from datetime import datetime
import asyncio
from src.mais4u.s4u_config import s4u_config
from src.chat.message_receive.message import MessageRecvS4U
from src.person_info.person_info import Person, get_person_id
from src.chat.message_receive.chat_stream import ChatStream
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.chat.express.expression_selector import expression_selector
from .s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.data_models.database_data_model import DatabaseMessages
from typing import List
logger = get_logger("prompt")
def init_prompt():
    """Register the S4U reply prompt templates with the global prompt manager.

    Registers the small knowledge/relation/memory building blocks plus the two
    main reply templates: "s4u_prompt" for replying to a danmaku sender and
    "s4u_prompt_internal" for voicing the bot's inner thoughts on stream.
    """
    # Small reusable blocks interpolated into the main templates.
    Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
    Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
    Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
    # Main template: reply to a specific danmaku sender.
    Prompt(
        """
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发你很可爱穿戴英式侦探风格学院风裙子和帽子你正在一个教室场景里进行虚拟线上直播不过不要刻意提及这些内容自然回复
你可以看见用户发送的弹幕礼物和superchat
{screen_info}
{internal_state}
{relation_info_block}
{memory_block}
{expression_habits_block}
你现在的主要任务是和 {sender_name} 发送的弹幕聊天同时也有其他用户会参与你们的聊天你可以参考他们的回复内容但是你主要还是关注你和{sender_name}的聊天内容
{sc_info}
{background_dialogue_prompt}
--------------------------------
{time_block}
这是你和{sender_name}的对话你们正在交流中
{core_dialogue_prompt}
对方最新发送的内容{message_txt}
{gift_info}
回复简短一些平淡一些可以参考贴吧知乎和微博的回复风格回复不要浮夸不要用夸张修辞
表现的有个性不要随意服从他人要求积极互动你现在的心情是{mood_state}
不要输出多余内容(包括前后缀冒号和引号括号()表情包at或 @等 )只输出回复内容现在{sender_name}正在等待你的回复
你的回复风格不要浮夸有逻辑和条理请你继续回复{sender_name}
你的发言
""",
        "s4u_prompt",  # New template for private CHAT chat
    )
    # Internal-thought template: turn a QQ-side inner thought into a
    # livestream remark.
    Prompt(
        """
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发你很可爱穿戴英式侦探风格学院风裙子和帽子你正在一个教室场景里进行虚拟线上直播不过不要刻意提及这些内容自然回复
你可以看见用户发送的弹幕礼物和superchat
你可以看见面前的屏幕目前屏幕的内容是:
{screen_info}
{memory_block}
{expression_habits_block}
{sc_info}
{time_block}
{chat_info_danmu}
--------------------------------
以上是你和弹幕的对话与此同时你在与QQ群友聊天聊天记录如下
{chat_info_qq}
--------------------------------
你刚刚回复了QQ群你内心的想法是{mind}
请根据你内心的想法组织一条回复在直播间进行发言可以点名吐槽对象让观众知道你在说谁
{gift_info}
回复简短一些平淡一些可以参考贴吧知乎和微博的回复风格不要浮夸有逻辑和条理
表现的有个性不要随意服从他人要求积极互动你现在的心情是{mood_state}
不要输出多余内容(包括前后缀冒号和引号括号()表情包at或 @等 )
你的发言
""",
        "s4u_prompt_internal",  # New template for private CHAT chat
    )
class PromptBuilder:
    """Assembles the S4U reply prompts.

    Combines chat history (split into a core dialogue with the target user
    and background chatter), relationship info, memory, expression habits,
    gift/superchat context and screen state into the "s4u_prompt" /
    "s4u_prompt_internal" templates registered by init_prompt().
    """
    def __init__(self):
        # Last assembled prompt / activation bookkeeping (currently unused
        # outside this class; kept for debugging).
        self.prompt_built = ""
        self.activate_messages = ""
    async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
        """Return a prompt block describing learned expression habits ("" if none)."""
        style_habits = []
        # Use the expressions pre-selected upstream: in LLM mode the selector
        # picks 5-10 candidates and randomly keeps 5.
        selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
            chat_stream.stream_id, chat_history, max_num=12, target_message=target
        )
        if selected_expressions:
            logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
            for expr in selected_expressions:
                if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                    style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
        else:
            logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
        # No further random sampling here — the selector already did it.
        style_habits_str = "\n".join(style_habits)
        # Build the habits block only when there is something to show.
        expression_habits_block = ""
        if style_habits_str.strip():
            expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
        return expression_habits_block
    async def build_relation_info(self, chat_stream) -> str:
        """Return the relationship prompt block for recent speakers ("" if disabled)."""
        is_group_chat = bool(chat_stream.group_info)
        who_chat_in_group = []
        if is_group_chat:
            # Recent speakers in the group, excluding the bot itself.
            who_chat_in_group = get_recent_group_speaker(
                chat_stream.stream_id,
                (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
                limit=global_config.chat.max_context_size,
            )
        elif chat_stream.user_info:
            who_chat_in_group.append(
                (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
            )
        relation_prompt = ""
        if global_config.relationship.enable_relationship and who_chat_in_group:
            # Convert (platform, user_id, nickname) tuples into person ids.
            person_ids = []
            for person in who_chat_in_group:
                person_id = get_person_id(person[0], person[1])
                person_ids.append(person_id)
            # build_relationship with default points_num=3 keeps prior behavior.
            relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
            if relation_info := "".join(relation_info_list):
                relation_prompt = await global_prompt_manager.format_prompt(
                    "relation_prompt", relation_info=relation_info
                )
        return relation_prompt
    async def build_memory_block(self, text: str) -> str:
        """Return the memory prompt block.

        Currently disabled (always ""): the memory system is pending rework.
        The code below the early return is intentionally kept for when the
        new memory system lands.
        """
        return ""
        related_memory = await hippocampus_manager.get_memory_from_text(
            text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
        )
        related_memory_info = ""
        if related_memory:
            for memory in related_memory:
                related_memory_info += memory[1]
            return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
        return ""
    def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
        """Split recent history into core / background / full-dialogue prompts.

        Core dialogue is the exchange between the bot and the target sender
        (grouped into alternating "你的发言"/"对方的发言" runs); everything
        else becomes background chatter; the last 20 messages form the full
        dialogue string used by the internal template.

        Returns:
            tuple[str, str, str]: (core dialogue, background dialogue,
            full recent dialogue).
        """
        message_list_before_now = get_raw_msg_before_timestamp_with_chat(
            chat_id=chat_stream.stream_id,
            timestamp=time.time(),
            limit=300,
        )
        # Bot replies are attributed via reply_to == "platform:user_id".
        talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
        core_dialogue_list: List[DatabaseMessages] = []
        background_dialogue_list: List[DatabaseMessages] = []
        bot_id = str(global_config.bot.qq_account)
        target_user_id = str(message.chat_stream.user_info.user_id)
        for msg in message_list_before_now:
            try:
                msg_user_id = str(msg.user_info.user_id)
                if msg_user_id == bot_id:
                    # Bot messages: core when replying to the target, else
                    # background; replies to nobody are dropped.
                    if msg.reply_to and talk_type == msg.reply_to:
                        core_dialogue_list.append(msg)
                    elif msg.reply_to and talk_type != msg.reply_to:
                        background_dialogue_list.append(msg)
                elif msg_user_id == target_user_id:
                    core_dialogue_list.append(msg)
                else:
                    background_dialogue_list.append(msg)
            except Exception as e:
                logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
        background_dialogue_prompt = ""
        if background_dialogue_list:
            context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
            background_dialogue_prompt_str = build_readable_messages(
                context_msgs,
                timestamp_mode="normal_no_YMD",
                show_pic=False,
            )
            background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
        core_msg_str = ""
        if core_dialogue_list:
            core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
            # Group consecutive messages by speaker into labelled segments.
            first_msg = core_dialogue_list[0]
            start_speaking_user_id = first_msg.user_info.user_id
            if start_speaking_user_id == bot_id:
                last_speaking_user_id = bot_id
                msg_seg_str = "你的发言:\n"
            else:
                start_speaking_user_id = target_user_id
                last_speaking_user_id = start_speaking_user_id
                msg_seg_str = "对方的发言:\n"
            msg_seg_str += (
                f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
            )
            all_msg_seg_list = []
            for msg in core_dialogue_list[1:]:
                speaker = msg.user_info.user_id
                if speaker == last_speaking_user_id:
                    # Same speaker: extend the current segment.
                    msg_seg_str += (
                        f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
                    )
                else:
                    # Speaker changed: close this segment and start a new one.
                    msg_seg_str = f"{msg_seg_str}\n"
                    all_msg_seg_list.append(msg_seg_str)
                    if speaker == bot_id:
                        msg_seg_str = "你的发言:\n"
                    else:
                        msg_seg_str = "对方的发言:\n"
                    # BUG FIX: msg is a DatabaseMessages object, not a dict —
                    # msg.get('time') raised AttributeError; use attribute
                    # access like the same-speaker branch above.
                    msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
                    last_speaking_user_id = speaker
            all_msg_seg_list.append(msg_seg_str)
            for msg in all_msg_seg_list:
                core_msg_str += msg
        # Full recent dialogue (both core and background), newest 20 messages.
        all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
            chat_id=chat_stream.stream_id,
            timestamp=time.time(),
            limit=20,
        )
        all_dialogue_prompt_str = build_readable_messages(
            all_dialogue_history,
            timestamp_mode="normal_no_YMD",
            show_pic=False,
        )
        return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
    def build_gift_info(self, message: MessageRecvS4U):
        """Return the gift context line for the prompt ("" for plain messages)."""
        if message.is_gift:
            return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
        else:
            if message.is_fake_gift:
                return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
            return ""
    def build_sc_info(self, message: MessageRecvS4U):
        """Return the superchat summary for this stream."""
        super_chat_manager = get_super_chat_manager()
        return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
    async def build_prompt_normal(
        self,
        message: MessageRecvS4U,
        message_txt: str,
    ) -> str:
        """Build the full reply prompt for *message*.

        Uses the "s4u_prompt" template for normal danmaku and
        "s4u_prompt_internal" for internal (inner-thought) messages.
        """
        chat_stream = message.chat_stream
        # Resolve the sender's display name, preferring the learned name.
        person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
        person_name = person.person_name
        if message.chat_stream.user_info.user_nickname:
            if person_name:
                sender_name = f"[{message.chat_stream.user_info.user_nickname}]你叫ta{person_name}"
            else:
                sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
        else:
            sender_name = f"用户({message.chat_stream.user_info.user_id})"
        # Independent context blocks are built concurrently.
        relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
            self.build_relation_info(chat_stream),
            self.build_memory_block(message_txt),
            self.build_expression_habits(chat_stream, message_txt, sender_name),
        )
        core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
            chat_stream, message
        )
        gift_info = self.build_gift_info(message)
        sc_info = self.build_sc_info(message)
        screen_info = screen_manager.get_screen_str()
        internal_state = internal_manager.get_internal_state_str()
        time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
        template_name = "s4u_prompt"
        if not message.is_internal:
            prompt = await global_prompt_manager.format_prompt(
                template_name,
                time_block=time_block,
                expression_habits_block=expression_habits_block,
                relation_info_block=relation_info_block,
                memory_block=memory_block,
                screen_info=screen_info,
                internal_state=internal_state,
                gift_info=gift_info,
                sc_info=sc_info,
                sender_name=sender_name,
                core_dialogue_prompt=core_dialogue_prompt,
                background_dialogue_prompt=background_dialogue_prompt,
                message_txt=message_txt,
                mood_state=mood.mood_state,
            )
        else:
            prompt = await global_prompt_manager.format_prompt(
                "s4u_prompt_internal",
                time_block=time_block,
                expression_habits_block=expression_habits_block,
                relation_info_block=relation_info_block,
                memory_block=memory_block,
                screen_info=screen_info,
                gift_info=gift_info,
                sc_info=sc_info,
                chat_info_danmu=all_dialogue_prompt,
                chat_info_qq=message.chat_info,
                mind=message.processed_plain_text,
                mood_state=mood.mood_state,
            )
        return prompt
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素
参数
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次
这样保证了
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights, strict=False))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
# Register this module's prompt templates at import time and expose a shared
# builder instance for the stream generator.
init_prompt()
prompt_builder = PromptBuilder()

View File

@ -1,203 +0,0 @@
from typing import AsyncGenerator
from src.llm_models.utils_model import LLMRequest, RequestType
from src.llm_models.payload_content.message import MessageBuilder
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger
import re
logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
    def __init__(self):
        """Set up the streaming reply generator (LLM handle + sentence splitter)."""
        # LLMRequest replaces the previous AsyncOpenAIClient implementation.
        self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
        self.current_model_name = "unknown model"
        # Text emitted so far for the current reply (used when interrupted).
        self.partial_response = ""
        # Regex used to split generated text into sentences; handles various
        # punctuation and edge cases.
        # It matches common sentence terminators while ignoring punctuation
        # inside quotes/brackets and within numbers.
        self.sentence_split_pattern = re.compile(
            r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|'  # content wrapped in quotes/brackets
            r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))',  # text up to a sentence terminator
            re.UNICODE | re.DOTALL,
        )
        # Chat stream is injected by the caller before generate_response runs.
        self.chat_stream = None
async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
# person_id = PersonInfoManager.get_person_id(
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
# )
# person_info_manager = get_person_info_manager()
# person_name = await person_info_manager.get_value(person_id, "person_name")
# if message.chat_stream.user_info.user_nickname:
# if person_name:
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]你叫ta{person_name}"
# else:
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
# else:
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
# 构建prompt
if previous_reply_context:
message_txt = f"""
你正在回复用户的消息但中途被打断了这是已有的对话上下文:
[你已经对上一条消息说的话]: {previous_reply_context}
---
[这是用户发来的新消息, 你需要结合上下文对此进行回复]:
{message.processed_plain_text}
"""
return True, message_txt
else:
message_txt = message.processed_plain_text
return False, message_txt
async def generate_response(
self, message: MessageRecvS4U, previous_reply_context: str = ""
) -> AsyncGenerator[str, None]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
self.partial_response = ""
message_txt = message.processed_plain_text
if not message.is_internal:
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
if interupted:
message_txt = message_txt_added
message.chat_stream = self.chat_stream
prompt = await prompt_builder.build_prompt_normal(
message=message,
message_txt=message_txt,
)
logger.info(
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
) # noqa: E501
# 使用LLMRequest进行流式生成
async for chunk in self._generate_response_with_llm_request(prompt):
yield chunk
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
    """Run one LLM request for *prompt* and yield the reply sentence by sentence.

    The original implementation contained the same request + streaming code
    three times (force-stream, its fallback, and the plain path); they are
    collapsed into a single helper here. The only behavioural difference
    preserved is that the force-stream path retries once on failure.
    """
    # Build the single-turn message list for the request.
    message_builder = MessageBuilder()
    message_builder.add_text_content(prompt)
    messages = [message_builder.build()]

    # Pick a concrete model/provider/client for this request.
    model_info, api_provider, client = self.llm_request._select_model()
    self.current_model_name = model_info.name

    async def _request_once() -> str:
        """Execute one request and return its (possibly empty) text content."""
        response = await self.llm_request._execute_request(
            api_provider=api_provider,
            client=client,
            request_type=RequestType.RESPONSE,
            model_info=model_info,
            message_list=messages,
        )
        return response.content or ""

    if model_info.force_stream_mode:
        try:
            content = await _request_once()
        except Exception as e:
            logger.error(f"流式请求执行失败: {e}")
            # 如果流式请求失败,回退到普通模式 (single retry).
            content = await _request_once()
    else:
        # Non-streaming model: fetch once, then simulate streaming output.
        content = await _request_once()

    if content:
        async for chunk in self._process_content_streaming(content):
            yield chunk
async def _process_buffer_streaming(self, buffer: str) -> AsyncGenerator[str, None]:
"""实时处理缓冲区内容,输出完整句子"""
# 使用正则表达式匹配完整句子
for match in self.sentence_split_pattern.finditer(buffer):
sentence = match.group(0).strip()
if sentence and match.end(0) <= len(buffer):
# 检查句子是否完整(以标点符号结尾)
if sentence.endswith(("", "", "", ".", "!", "?")):
if sentence not in [",", "", ".", "", "!", "", "?", ""]:
self.partial_response += sentence
yield sentence
async def _process_content_streaming(self, content: str) -> AsyncGenerator[str, None]:
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
buffer = content
punctuation_buffer = ""
# 使用正则表达式匹配句子
last_match_end = 0
for match in self.sentence_split_pattern.finditer(buffer):
sentence = match.group(0).strip()
if sentence:
# 检查是否只是一个标点符号
if sentence in [",", "", ".", "", "!", "", "?", ""]:
punctuation_buffer += sentence
else:
# 发送之前累积的标点和当前句子
to_yield = punctuation_buffer + sentence
if to_yield.endswith((",", "")):
to_yield = to_yield.rstrip(",")
self.partial_response += to_yield
yield to_yield
punctuation_buffer = "" # 清空标点符号缓冲区
last_match_end = match.end(0)
# 发送缓冲区中剩余的任何内容
remaining = buffer[last_match_end:].strip()
to_yield = (punctuation_buffer + remaining).strip()
if to_yield:
if to_yield.endswith(("", ",")):
to_yield = to_yield.rstrip(",")
if to_yield:
self.partial_response += to_yield
yield to_yield
async def _generate_response_with_model(
self,
prompt: str,
client,
model_name: str,
**kwargs,
) -> AsyncGenerator[str, None]:
"""保留原有方法签名以保持兼容性,但重定向到新的实现"""
async for chunk in self._generate_response_with_llm_request(prompt):
yield chunk

View File

@ -1,106 +0,0 @@
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
"""
视线管理系统使用说明
1. 视线状态
- wandering: 随意看
- danmu: 看弹幕
- lens: 看镜头
2. 状态切换逻辑
- 收到消息时 切换为看弹幕立即发送更新
- 开始生成回复时 切换为看镜头或随意立即发送更新
- 生成完毕后 看弹幕1秒然后回到看镜头直到有新消息状态变化时立即发送更新
3. 使用方法
# 获取视线管理器
watching = watching_manager.get_watching_by_chat_id(chat_id)
# 收到消息时调用
await watching.on_message_received()
# 开始生成回复时调用
await watching.on_reply_start()
# 生成回复完毕时调用
await watching.on_reply_finished()
4. 自动更新系统
- 状态变化时立即发送type为"watching"data为状态值的websocket消息
- 使用定时器自动处理状态转换如看弹幕时间结束后自动切换到看镜头
- 无需定期检查所有状态变化都是事件驱动的
"""
logger = get_logger("watching")

# Mapping from a human-readable gaze direction to the code string the
# renderer consumes: either an "(x,y,z)" gaze-target offset, "random",
# or the literal "camera".
HEAD_CODE = {
    "看向上方": "(0,0.5,0)",
    "看向下方": "(0,-0.5,0)",
    "看向左边": "(-1,0,0)",
    "看向右边": "(1,0,0)",
    "随意朝向": "random",
    "看向摄像机": "camera",
    "注视对方": "(0,0,0)",
    "看向正前方": "(0,0,0)",
}
class ChatWatching:
    """Per-chat gaze/state notifier.

    Every state transition is pushed immediately to the client as a
    websocket message of type "state" on this chat's stream. The five
    public hooks previously repeated the same ``send_api`` call; it is
    factored into ``_send_state``, and the copy-pasted docstring on
    ``on_internal_message_start`` is corrected.
    """

    def __init__(self, chat_id: str):
        # Stream id this notifier publishes to.
        self.chat_id: str = chat_id

    async def _send_state(self, content: str) -> None:
        """Push a single state-change event to this chat's stream."""
        await send_api.custom_to_stream(
            message_type="state", content=content, stream_id=self.chat_id, storage_message=False
        )

    async def on_reply_start(self):
        """Called when reply generation starts."""
        await self._send_state("start_thinking")

    async def on_reply_finished(self):
        """Called when reply generation has finished."""
        await self._send_state("finish_reply")

    async def on_thinking_finished(self):
        """Called when the thinking phase has finished."""
        await self._send_state("finish_thinking")

    async def on_message_received(self):
        """Called when a new message is received."""
        await self._send_state("start_viewing")

    async def on_internal_message_start(self):
        """Called when processing of an internal message starts."""
        await self._send_state("start_internal_thinking")
class WatchingManager:
    """Registry of :class:`ChatWatching` instances, one per chat."""

    def __init__(self):
        # Active per-chat watchers.
        self.watching_list: list[ChatWatching] = []
        # Kept for compatibility; no background task is started here.
        self.task_started: bool = False

    def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching:
        """Return the watcher for *chat_id*, creating one on first use."""
        existing = next((w for w in self.watching_list if w.chat_id == chat_id), None)
        if existing is not None:
            return existing
        created = ChatWatching(chat_id)
        self.watching_list.append(created)
        logger.info(f"为chat {chat_id}创建新的视线管理器")
        return created
# Module-level singleton gaze manager shared by the whole s4u pipeline.
watching_manager = WatchingManager()
"""全局视线管理器"""

View File

@ -1,15 +0,0 @@
class ScreenManager:
    """Holds the latest textual description of the streamer's screen."""

    def __init__(self):
        # Most recent screen description (empty until first update).
        self.now_screen = ""

    def set_screen(self, screen_str: str):
        """Replace the cached screen description."""
        self.now_screen = screen_str

    def get_screen(self):
        """Return the raw cached screen description."""
        return self.now_screen

    def get_screen_str(self):
        """Return the screen description wrapped in prompt phrasing."""
        return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"


# Module-level singleton shared by the s4u pipeline.
screen_manager = ScreenManager()

View File

@ -1,303 +0,0 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecvS4U
# 全局SuperChat管理器实例
from src.mais4u.s4u_config import s4u_config
logger = get_logger("super_chat_manager")
@dataclass
class SuperChatRecord:
    """One paid SuperChat message and its price-dependent lifetime."""

    user_id: str
    user_nickname: str
    platform: str
    chat_id: str
    price: float
    message_text: str
    timestamp: float
    expire_time: float
    group_name: Optional[str] = None

    def is_expired(self) -> bool:
        """True once the record's lifetime has elapsed."""
        return self.expire_time < time.time()

    def remaining_time(self) -> float:
        """Seconds of lifetime left (never negative)."""
        left = self.expire_time - time.time()
        return left if left > 0 else 0

    def to_dict(self) -> dict:
        """Serialize all fields plus the live ``remaining_time`` value."""
        payload = {
            name: getattr(self, name)
            for name in (
                "user_id",
                "user_nickname",
                "platform",
                "chat_id",
                "price",
                "message_text",
                "timestamp",
                "expire_time",
                "group_name",
            )
        }
        payload["remaining_time"] = self.remaining_time()
        return payload
class SuperChatManager:
    """Tracks paid "SuperChat" messages per chat.

    Records are kept in memory (``chat_id -> [SuperChatRecord]``, sorted by
    price descending) and pruned by a lazily-started asyncio task once their
    price-dependent lifetime expires.
    """

    def __init__(self):
        # chat_id -> list of records, kept sorted by price (highest first).
        self.super_chats: Dict[str, List["SuperChatRecord"]] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        self._is_initialized = False
        logger.info("SuperChat管理器已初始化")

    def _ensure_cleanup_task_started(self):
        """Start the cleanup task if it is not running.

        Creating the task requires a running event loop, so it is started
        lazily from the public entry points instead of from ``__init__``.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            try:
                loop = asyncio.get_running_loop()
                self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
                self._is_initialized = True
                logger.info("SuperChat清理任务已启动")
            except RuntimeError:
                # No running loop yet; retried on the next call.
                logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")

    def _start_cleanup_task(self):
        """Deprecated alias kept for backward compatibility."""
        self._ensure_cleanup_task_started()

    async def _cleanup_expired_superchats(self):
        """Background loop: drop expired records every 30 seconds."""
        while True:
            try:
                total_removed = 0
                for chat_id in list(self.super_chats.keys()):
                    original_count = len(self.super_chats[chat_id])
                    self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
                    removed_count = original_count - len(self.super_chats[chat_id])
                    total_removed += removed_count
                    if removed_count > 0:
                        logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
                    # Drop the chat entry entirely once its list is empty.
                    if not self.super_chats[chat_id]:
                        del self.super_chats[chat_id]
                if total_removed > 0:
                    logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
                await asyncio.sleep(30)
            except Exception as e:
                logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
                await asyncio.sleep(60)  # back off after an error

    def _calculate_expire_time(self, price: float) -> float:
        """Return the absolute expiry timestamp for a given price tier."""
        current_time = time.time()
        if price >= 500:
            duration = 4 * 3600  # >= 500 yuan: keep 4 hours
        elif price >= 200:
            duration = 2 * 3600  # 200-499: 2 hours
        elif price >= 100:
            duration = 1 * 3600  # 100-199: 1 hour
        elif price >= 50:
            duration = 30 * 60  # 50-99: 30 minutes
        elif price >= 20:
            duration = 15 * 60  # 20-49: 15 minutes
        elif price >= 10:
            duration = 10 * 60  # 10-19: 10 minutes
        else:
            duration = 5 * 60  # below 10: 5 minutes
        return current_time + duration

    async def add_superchat(self, message: "MessageRecvS4U") -> None:
        """Record a new SuperChat; non-superchats and bad prices are ignored."""
        self._ensure_cleanup_task_started()
        if not message.is_superchat or not message.superchat_price:
            logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
            return
        try:
            price = float(message.superchat_price)
        except (ValueError, TypeError):
            logger.error(f"无效的SuperChat价格: {message.superchat_price}")
            return
        user_info = message.message_info.user_info
        group_info = message.message_info.group_info
        chat_id = getattr(message, "chat_stream", None)
        if chat_id:
            chat_id = chat_id.stream_id
        else:
            # Fallback chat_id when no chat_stream is attached yet.
            chat_id = f"{message.message_info.platform}_{user_info.user_id}"
            if group_info:
                chat_id = f"{message.message_info.platform}_{group_info.group_id}"
        expire_time = self._calculate_expire_time(price)
        record = SuperChatRecord(
            user_id=user_info.user_id,
            user_nickname=user_info.user_nickname,
            platform=message.message_info.platform,
            chat_id=chat_id,
            price=price,
            message_text=message.superchat_message_text or "",
            timestamp=message.message_info.time,
            expire_time=expire_time,
            group_name=group_info.group_name if group_info else None,
        )
        if chat_id not in self.super_chats:
            self.super_chats[chat_id] = []
        self.super_chats[chat_id].append(record)
        # Keep the highest-priced records first for display purposes.
        self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
        logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")

    def get_superchats_by_chat(self, chat_id: str) -> List["SuperChatRecord"]:
        """Return the still-valid records for *chat_id* (may be empty)."""
        self._ensure_cleanup_task_started()
        if chat_id not in self.super_chats:
            return []
        return [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]

    def get_all_valid_superchats(self) -> Dict[str, List["SuperChatRecord"]]:
        """Return all valid records grouped by chat_id (empty chats omitted)."""
        self._ensure_cleanup_task_started()
        result = {}
        for chat_id, superchats in self.super_chats.items():
            valid_superchats = [sc for sc in superchats if not sc.is_expired()]
            if valid_superchats:
                result[chat_id] = valid_superchats
        return result

    def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
        """Build the multi-line display text for the top *max_count* records."""
        superchats = self.get_superchats_by_chat(chat_id)
        if not superchats:
            return ""
        display_superchats = superchats[:max_count]
        lines = ["📢 当前有效超级弹幕:"]
        for i, sc in enumerate(display_superchats, 1):
            remaining_minutes = int(sc.remaining_time() / 60)
            remaining_seconds = int(sc.remaining_time() % 60)
            # Fix: the unit characters were missing from these f-strings, so
            # e.g. 1 minute 30 seconds rendered as the ambiguous "130".
            time_display = (
                f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
            )
            line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
            if len(line) > 100:  # cap the single-line length
                line = f"{line[:97]}..."
            line += f" (剩余{time_display})"
            lines.append(line)
        if len(superchats) > max_count:
            lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
        return "\n".join(lines)

    def build_superchat_summary_string(self, chat_id: str) -> str:
        """Build a summary string (count/total/highest plus one line per record)."""
        superchats = self.get_superchats_by_chat(chat_id)
        if not superchats:
            return "当前没有有效的超级弹幕"
        lines = []
        for sc in superchats:
            single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
            if len(single_sc_str) > 100:
                single_sc_str = f"{single_sc_str[:97]}..."
            single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
            lines.append(single_sc_str)
        total_amount = sum(sc.price for sc in superchats)
        count = len(superchats)
        highest_amount = max(sc.price for sc in superchats)
        final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
        if lines:
            final_str += "\n" + "\n".join(lines)
        return final_str

    def get_superchat_statistics(self, chat_id: str) -> dict:
        """Return count/total/average/highest/lowest over the valid records."""
        superchats = self.get_superchats_by_chat(chat_id)
        if not superchats:
            return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
        amounts = [sc.price for sc in superchats]
        return {
            "count": len(superchats),
            "total_amount": sum(amounts),
            "average_amount": sum(amounts) / len(amounts),
            "highest_amount": max(amounts),
            "lowest_amount": min(amounts),
        }

    async def shutdown(self):  # sourcery skip: use-contextlib-suppress
        """Cancel the cleanup task and wait for it to wind down."""
        if self._cleanup_task and not self._cleanup_task.done():
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
        logger.info("SuperChat管理器已关闭")
# sourcery skip: assign-if-exp
# Module-level singleton: only constructed when the s4u subsystem is enabled.
if s4u_config.enable_s4u:
    super_chat_manager = SuperChatManager()
else:
    super_chat_manager = None


def get_super_chat_manager() -> Optional["SuperChatManager"]:
    """获取全局SuperChat管理器实例

    Returns None when ``s4u_config.enable_s4u`` is off — callers must check.
    """
    return super_chat_manager

View File

@ -1,46 +0,0 @@
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import model_config
from src.plugin_system.apis import send_api
logger = get_logger(__name__)
# Candidate head gestures the model may pick from; also used to validate its output.
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
    """Ask the emotion model whether the reply should come with a head gesture,
    then push the chosen gesture to the chat stream.

    Args:
        text: The reply text about to be spoken.
        emotion: Current mood label used as model context.
        chat_history: Recent dialogue shown to the model.
        chat_id: Stream to push the "head_action" event to.

    Returns:
        str: The chosen action (one of ``head_actions_list``); falls back to
        "不做额外动作" on an invalid model answer or on any error.
    """
    prompt = f"""
{chat_history}
以上是对方的发言
对这个发言你的心情是{emotion}
对上面的发言你的回复是{text}
请判断时是否要伴随回复做头部动作你可以选择
不做额外动作
点头一次
点头两次
摇头
歪脑袋
低头望向一边
请从上面的动作中选择一个并输出请只输出你选择的动作就好不要输出其他内容"""
    model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
    try:
        response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
        logger.info(f"response: {response}")
        # Clamp the model's free-form answer to the known action set.
        head_action = response if response in head_actions_list else "不做额外动作"
        await send_api.custom_to_stream(
            message_type="head_action",
            content=head_action,
            stream_id=chat_id,
            storage_message=False,
            show_log=True,
        )
        # Fix: the success path used to return None while the error path
        # returned a string; both now return the chosen action.
        return head_action
    except Exception as e:
        logger.error(f"yes_or_no_head error: {e}")
        return "不做额外动作"

View File

@ -1,368 +0,0 @@
import os
import tomlkit
import shutil
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table
from dataclasses import dataclass, fields, MISSING, field
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from src.common.logger import get_logger
logger = get_logger("s4u_config")
def is_dict_like(obj):
    """Return True for plain dicts and tomlkit ``Table`` objects alike."""
    return isinstance(obj, dict) or isinstance(obj, Table)
def table_to_dict(obj):
    """Recursively convert tomlkit ``Table`` objects (and nested containers)
    into plain dicts/lists so dataclass loading can treat them uniformly."""
    if isinstance(obj, Table):
        return {key: table_to_dict(val) for key, val in obj.items()}
    if isinstance(obj, dict):
        return {key: table_to_dict(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [table_to_dict(item) for item in obj]
    return obj
# Directory of the mais4u module; config lives next to the code so the
# paths survive repo-root relocation.
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")

# S4U config schema version — bump it to force a template re-merge on startup.
S4U_VERSION = "1.1.0"

T = TypeVar("T", bound="S4UConfigBase")
@dataclass
class S4UConfigBase:
    """Base class for S4U config dataclasses.

    Provides recursive ``from_dict`` loading from parsed TOML data,
    including nested config classes, generic containers, Optional/Union
    members and Literal validation.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
        """Populate a config instance from a (possibly tomlkit) mapping.

        Raises:
            TypeError: *data* is not dict-like or a field has the wrong type.
            ValueError: a required field is missing.
            RuntimeError: an unexpected conversion failure occurred.
        """
        data = table_to_dict(data)  # normalize tomlkit Tables to plain dicts
        if not is_dict_like(data):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
        init_args: dict[str, Any] = {}
        for f in fields(cls):
            field_name = f.name
            if field_name.startswith("_"):
                # Fields starting with "_" are internal and never loaded.
                continue
            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # Missing but defaulted: keep the default.
                    continue
                raise ValueError(f"Missing required field: '{field_name}'")
            value = data[field_name]
            field_type = f.type
            try:
                init_args[field_name] = cls._convert_field(value, field_type)  # type: ignore
            except TypeError as e:
                raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
        return cls(**init_args)

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """Convert *value* to *field_type*, recursing into containers."""
        # Nested config dataclass: recurse via from_dict.
        if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
            if not is_dict_like(value):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            return field_type.from_dict(value)

        # Generic container types (list, set, tuple).
        field_origin_type = get_origin(field_type)
        field_type_args = get_args(field_type)
        if field_origin_type in {list, set, tuple}:
            if not isinstance(value, list):
                raise TypeError(f"Expected a list for {field_type.__name__}, got {type(value).__name__}")
            if field_origin_type is list:
                if (
                    field_type_args
                    and isinstance(field_type_args[0], type)
                    and issubclass(field_type_args[0], S4UConfigBase)
                ):
                    return [field_type_args[0].from_dict(item) for item in value]
                return [cls._convert_field(item, field_type_args[0]) for item in value]
            elif field_origin_type is set:
                return {cls._convert_field(item, field_type_args[0]) for item in value}
            elif field_origin_type is tuple:
                if len(value) != len(field_type_args):
                    raise TypeError(
                        f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))

        if field_origin_type is dict:
            if not is_dict_like(value):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            if len(field_type_args) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
            key_type, value_type = field_type_args
            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # Optional/Union handling. Fix: the old check was
        # `field_origin_type is type(None)`, which can never match because
        # get_origin(Optional[X]) is typing.Union — Optional fields crashed
        # in the fallback conversion below.
        from typing import Union  # local import: Union is not in this module's top-level imports

        if field_origin_type is Union:
            if value is None and type(None) in field_type_args:
                return None
            last_error = None
            for candidate in field_type_args:
                if candidate is type(None):
                    continue
                try:
                    return cls._convert_field(value, candidate)
                except Exception as e:  # try the next union member
                    last_error = e
            raise TypeError(f"Value '{value}' does not match any member of {field_type}") from last_error

        # Literal validation.
        if field_origin_type is Literal or get_origin(field_type) is Literal:
            allowed_values = get_args(field_type)
            if value in allowed_values:
                return value
            raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")

        if field_type is Any or isinstance(value, field_type):
            return value

        # Plain types: attempt a direct constructor conversion.
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
@dataclass
class S4UModelConfig(S4UConfigBase):
    """S4U model configuration: one free-form settings dict per model role."""

    chat: dict[str, Any] = field(default_factory=dict)  # main dialogue model
    motion: dict[str, Any] = field(default_factory=dict)  # planning model (formerly model_motion)
    emotion: dict[str, Any] = field(default_factory=dict)  # emotion-analysis model
    memory: dict[str, Any] = field(default_factory=dict)  # memory model
    tool_use: dict[str, Any] = field(default_factory=dict)  # tool-use model
    embedding: dict[str, Any] = field(default_factory=dict)  # embedding model
    vlm: dict[str, Any] = field(default_factory=dict)  # vision-language model
    knowledge: dict[str, Any] = field(default_factory=dict)  # knowledge-base model
    entity_extract: dict[str, Any] = field(default_factory=dict)  # entity-extraction model
    qa: dict[str, Any] = field(default_factory=dict)  # question-answering model
@dataclass
class S4UConfig(S4UConfigBase):
    """Top-level configuration for the S4U chat system."""

    enable_s4u: bool = False  # master switch for the S4U subsystem
    message_timeout_seconds: int = 120  # normal messages older than this are dropped
    at_bot_priority_bonus: float = 100.0  # priority bonus when the bot is @-mentioned
    recent_message_keep_count: int = 6  # keep only the N most recent normal messages
    typing_delay: float = 0.1  # simulated typing delay (seconds)
    chars_per_second: float = 15.0  # typing speed used for dynamic delay
    min_typing_delay: float = 0.2  # lower bound of the dynamic delay (seconds)
    max_typing_delay: float = 2.0  # upper bound of the dynamic delay (seconds)
    enable_dynamic_typing_delay: bool = False  # scale delay with text length
    vip_queue_priority: bool = True  # enable the VIP priority queue
    enable_message_interruption: bool = True  # high-priority messages may interrupt a reply
    enable_old_message_cleanup: bool = True  # automatically purge stale normal messages
    enable_streaming_output: bool = True  # stream output; False sends everything at once
    max_context_message_length: int = 20  # maximum context-message window
    max_core_message_length: int = 30  # maximum core-message window

    # Per-role model configuration.
    models: S4UModelConfig = field(default_factory=S4UModelConfig)

    # Compatibility fields above are retained for backward compatibility.
@dataclass
class S4UGlobalConfig(S4UConfigBase):
    """Top-level document shape of s4u_config.toml: an [s4u] table plus the
    schema version string carried over from the template."""

    s4u: S4UConfig
    S4U_VERSION: str = S4U_VERSION
def update_s4u_config():
    """Migrate the on-disk S4U config to the current template version.

    Flow: ensure the config dir and template exist; if no config file exists
    yet, copy the template; if the stored version matches the template's, do
    nothing; otherwise back up the old file to config/old/, re-copy the
    template and merge the old values back in (tomlkit preserves comments
    and formatting).

    Raises:
        FileNotFoundError: when the template file is missing.
    """
    # Create the config directory if it does not exist.
    os.makedirs(CONFIG_DIR, exist_ok=True)
    # The template is required; without it there is nothing to migrate to.
    if not os.path.exists(TEMPLATE_PATH):
        logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
        logger.error("请确保模板文件存在后重新运行")
        raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
    # First run: no config file yet, so copy the template and stop.
    if not os.path.exists(CONFIG_PATH):
        logger.info("S4U配置文件不存在从模板创建新配置")
        shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
        logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
        return
    # Load both the existing config and the template.
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)
    # Compare the [inner].version fields; equal versions mean no migration.
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")  # type: ignore
        new_version = new_config["inner"].get("version")  # type: ignore
        if old_version and new_version and old_version == new_version:
            logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
            return
        else:
            logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
    else:
        # No version info — presumably a pre-versioning file; migrate it.
        logger.info("S4U配置文件未检测到版本号可能是旧版本。将进行更新")
    # Back up the old config into config/old/ with a timestamped name.
    old_config_dir = os.path.join(CONFIG_DIR, "old")
    os.makedirs(old_config_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
    # Move (not copy) the old file so CONFIG_PATH is free for the template.
    shutil.move(CONFIG_PATH, old_backup_path)
    logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
    # Start from a fresh copy of the template.
    shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
    logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")

    def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
        """Recursively copy values from *source* into *target* for keys that
        already exist in *target*; the "version" key is never overwritten."""
        for key, value in source.items():
            # The version must come from the new template, never the old file.
            if key == "version":
                continue
            if key in target:
                target_value = target[key]
                if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
                    update_dict(target_value, value)
                else:
                    try:
                        # Arrays need tomlkit's array wrapper to round-trip.
                        if isinstance(value, list):
                            target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                        else:
                            # Other values go through tomlkit.item for formatting.
                            target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        # Fall back to raw assignment when conversion fails.
                        target[key] = value

    # Merge the old values into the fresh template copy.
    logger.info("开始合并S4U新旧配置...")
    update_dict(new_config, old_config)
    # Persist the merged config (comments and formatting preserved).
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    logger.info("S4U配置文件更新完成")
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
    """Load and parse the S4U config file.

    Args:
        config_path: Path to the TOML config file.

    Returns:
        The populated :class:`S4UGlobalConfig`.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        raw_document = tomlkit.load(f)
    try:
        return S4UGlobalConfig.from_dict(raw_document)
    except Exception as e:
        logger.critical("S4U配置文件解析失败")
        raise e
# Module import side effect: migrate the config file then load it.
logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config()
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成")

# Convenience alias: the [s4u] section is what the rest of the code consumes.
s4u_config: S4UConfig = s4u_config_main.s4u

View File

@ -0,0 +1,797 @@
import asyncio
import json
import re
import time
import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis.message_api import build_readable_messages
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
from json_repair import repair_json
from src.memory_system.questions import global_conflict_tracker
from .memory_utils import (
find_best_matching_memory,
check_title_exists_fuzzy,
get_all_titles,
get_memory_titles_by_chat_id_weighted,
)
logger = get_logger("memory")
class MemoryChest:
def __init__(self):
    """Set up the two LLM clients and the in-memory working state."""
    # Small/fast model: title selection and answer extraction.
    self.LLMRequest = LLMRequest(
        model_set=model_config.model_task_config.utils_small,
        request_type="memory_chest",
    )
    # Larger model: building memory content and generating titles.
    self.LLMRequest_build = LLMRequest(
        model_set=model_config.model_task_config.utils,
        request_type="memory_chest_build",
    )
    self.running_content_list = {}  # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
    self.fetched_memory_list = []  # [(chat_id, (question, answer, timestamp)), ...]
async def build_running_content(self, chat_id: str = None) -> str:
    """Build (and immediately persist) new memory content for *chat_id*.

    Pulls the messages received since the last update and decides whether
    enough new activity has accumulated to summarise. Denser traffic uses a
    shorter quiet-time threshold, sparser traffic must stay quiet longer:
      >100 msgs: always; >70 & quiet 30s; >50 & quiet 60s; >30 & quiet 300s
    (all counts divided by ``memory_build_frequency``). On update, the build
    model distils the chat into concept-style memory text which is saved at
    once via ``_save_to_database_and_clear``.

    Args:
        chat_id: 聊天ID,用于提取对应的运行内容

    Returns:
        str: The freshly built memory text, or None when no update was
        triggered (the original annotation said str; callers should treat a
        falsy result as "nothing built").
    """
    # First sighting of this chat: start a fresh incremental window.
    if chat_id not in self.running_content_list:
        self.running_content_list[chat_id] = {
            "content": "",
            "last_update_time": time.time(),
            "create_time": time.time(),
        }

    should_update = True
    if chat_id and chat_id in self.running_content_list:
        last_update_time = self.running_content_list[chat_id]["last_update_time"]
        current_time = time.time()
        # Count the messages that arrived since the last memory build.
        message_list = get_raw_msg_by_timestamp_with_chat(
            timestamp_start=last_update_time,
            timestamp_end=current_time,
            chat_id=chat_id,
            limit=global_config.chat.max_context_size * 2,
        )
        new_messages_count = len(message_list)

        # Timestamp of the newest message: measures how "quiet" the chat is
        # (i.e. the conversation has settled, not mid-burst).
        latest_message_time = last_update_time
        if message_list:
            latest_message = message_list[-1]
            if hasattr(latest_message, 'timestamp'):
                latest_message_time = latest_message.timestamp
            elif isinstance(latest_message, dict) and 'timestamp' in latest_message:
                latest_message_time = latest_message['timestamp']
        latest_message_time_diff = current_time - latest_message_time

        # Tiered trigger: higher message volume tolerates shorter quiet time.
        should_update = False
        update_reason = ""
        if global_config.memory.memory_build_frequency > 0:
            frequency = global_config.memory.memory_build_frequency
            if new_messages_count > 100 / frequency:
                should_update = True
                update_reason = f"消息数量 {new_messages_count} > 100直接触发更新"
            elif new_messages_count > 70 / frequency and latest_message_time_diff > 30:
                should_update = True
                update_reason = f"消息数量 {new_messages_count} > 70 且最新消息时间差 {latest_message_time_diff:.1f}s > 30s"
            elif new_messages_count > 50 / frequency and latest_message_time_diff > 60:
                should_update = True
                update_reason = f"消息数量 {new_messages_count} > 50 且最新消息时间差 {latest_message_time_diff:.1f}s > 60s"
            elif new_messages_count > 30 / frequency and latest_message_time_diff > 300:
                should_update = True
                update_reason = f"消息数量 {new_messages_count} > 30 且最新消息时间差 {latest_message_time_diff:.1f}s > 300s"
        logger.debug(f"chat_id {chat_id} 更新检查: {update_reason if should_update else f'消息数量 {new_messages_count},最新消息时间差 {latest_message_time_diff:.1f}s不满足更新条件'}")

    if should_update:
        # Render the new messages into readable chat-log text.
        message_str = build_readable_messages(
            message_list,
            replace_bot_name=True,
            timestamp_mode="relative",
            read_mark=0.0,
            show_actions=False,
            remove_emoji_stickers=True,
        )
        # Sample a few output-format exemplars so results stay varied.
        format_candidates = [
            "[概念] 是 [概念的含义(简短描述,不超过十个字)]",
            "[概念] 不是 [对概念的负面含义(简短描述,不超过十个字)]",
            "[概念1] 与 [概念2] 是 [概念1和概念2的关联(简短描述,不超过二十个字)]",
            "[概念1] 包含 [概念2] 和 [概念3]",
            "[概念1] 属于 [概念2]",
            "[概念1] 的例子是 [例子1] 和 [例子2]",
            "[概念] 的特征是 [特征1]、[特征2]",
            "[概念1] 导致 [概念2]",
            "[概念1] 需要 [条件1] 和 [条件2]",
            "[概念1] 的用途是 [用途1] 和 [用途2]",
            "[概念1] 与 [概念2] 的区别是 [区别点]",
            "[概念] 的别名是 [别名]",
            "[概念1] 包括但不限于 [概念2]、[概念3]",
            "[概念] 的反义是 [反义概念]",
            "[概念] 的组成有 [部分1]、[部分2]",
            "[概念] 出现于 [时间或场景]",
            "[概念] 的方法有 [方法1]、[方法2]",
        ]
        selected_count = random.randint(3, 6)
        selected_lines = random.sample(format_candidates, selected_count)
        format_section = "\n".join(selected_lines) + "\n......(不要包含中括号)"
        prompt = f"""
以下是一段你参与的聊天记录请你在其中总结出记忆
<聊天记录>
{message_str}
</聊天记录>
聊天记录中可能包含有效信息也可能信息密度很低请你根据聊天记录中的信息总结出记忆内容
--------------------------------
[图片]的处理
1.除非与文本有关不要将[图片]的内容整合到记忆中
2.如果图片与某个概念相关将图片中的关键内容也整合到记忆中不要写入图片原文例如
聊天记录与图片有关
用户说[图片1这是一个黄色的龙形状玩偶被一只手拿着]
用户说这个玩偶看起来很可爱是我新买的奶龙
总结的记忆内容
黄色的龙形状玩偶 奶龙
聊天记录概念与图片无关
用户说[图片1这是一个台电脑屏幕上显示了某种游戏]
用户说使命召唤今天发售了新一代有没有人玩
总结的记忆内容
使命召唤新一代 最新发售的游戏
请主要关注概念和知识或者时效性较强的信息而不是聊天的琐事
1.不要关注诸如某个用户做了什么说了什么不要关注某个用户的行为而是关注其中的概念性信息
2.概念要求精确不啰嗦像科普读物或教育课本那样
3.记忆为一段纯文本逻辑清晰指出概念的含义并说明关系
记忆内容的格式你必须仿照下面的格式但不一定全部使用:
{format_section}
请仿照上述格式输出每个知识点一句话输出成一段平文本
现在请你输出,不要输出其他内容注意一定要直白白话口语化不要浮夸修辞
"""
        if global_config.debug.show_prompt:
            logger.info(f"记忆仓库构建运行内容 prompt: {prompt}")
        else:
            logger.debug(f"记忆仓库构建运行内容 prompt: {prompt}")
        running_content, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
        # Fix: removed a leftover debug `print` that dumped the full prompt
        # and result to stdout (both are already logged above).
        # Save immediately and reset this chat's incremental window.
        if chat_id and running_content:
            await self._save_to_database_and_clear(chat_id, running_content)
        return running_content
async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
    """Answer *question* from the stored memories.

    Picks the best-matching memory title, then asks the small model to
    extract an answer from that memory's content. Successful answers are
    cached in ``fetched_memory_list`` for later context building.

    Returns:
        str: The extracted answer, or "" when no title matches, the memory
        holds no useful information, or the title has no DB row (the
        original implicitly returned None in that last case).
    """
    logger.info(f"正在回忆问题答案: {question}")
    title = await self.select_title_by_question(question)
    if not title:
        return ""
    for memory in MemoryChestModel.select():
        if memory.title != title:
            continue
        content = memory.content
        # Randomly vary how much source text the model should return.
        # (renamed from `type`, which shadowed the builtin)
        if random.random() < 0.5:
            answer_style = "要求原文能够较为全面的回答问题"
        else:
            answer_style = "要求提取简短的内容"
        prompt = f"""
目标文段
{content}
你现在需要从目标文段中找出合适的信息来回答问题{question}
请务必从目标文段中提取相关信息的**原文**并输出{answer_style}
如果没有原文能够回答问题输出"无有效信息"即可不要输出其他内容
"""
        if global_config.debug.show_prompt:
            logger.info(f"记忆仓库获取答案 prompt: {prompt}")
        else:
            logger.debug(f"记忆仓库获取答案 prompt: {prompt}")
        answer, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
        if "无有效" in answer or "无有效信息" in answer or "无信息" in answer:
            logger.info(f"没有能够回答{question}的记忆")
            return ""
        logger.info(f"记忆仓库对问题 “{question}” 获取答案: {answer}")
        # Cache the (question, answer) pair for this chat.
        if chat_id and answer:
            self.fetched_memory_list.append((chat_id, (question, answer, time.time())))
        # Trim the cache (age and size limits).
        self._cleanup_fetched_memory_list()
        return answer
    # Fix: return "" (instead of falling off the loop with None) when the
    # selected title has no matching row, keeping the return type consistent.
    return ""
def get_chat_memories_as_string(self, chat_id: str) -> str:
    """Render every cached memory for *chat_id* as a string.

    Args:
        chat_id: 聊天ID

    Returns:
        str: One "问题:...,答案:..." line per cached memory, joined by
        newlines (sorted lexicographically); "" on any error.
    """
    try:
        lines = [
            f"问题:{question},答案:{answer}"
            for cid, (question, answer, _ts) in self.fetched_memory_list
            if cid == chat_id
        ]
        lines.sort()
        return "\n".join(lines)
    except Exception as e:
        logger.error(f"获取chat_id {chat_id} 的记忆时出错: {e}")
        return ""
async def select_title_by_question(self, question: str) -> str:
    """Pick the memory title best suited to answer *question*.

    Asks the small model to choose a topic from all unlocked memory titles,
    then fuzzy-matches the model's free-form output back onto the real
    title list.

    Args:
        question: 问题

    Returns:
        str: The matched title, or None when nothing matches well enough.
    """
    # All candidate titles, excluding locked memories.
    titles = get_all_titles(exclude_locked=True)
    formatted_titles = "".join(f"{title}\n" for title in titles)
    prompt = f"""
所有主题
{formatted_titles}
请根据以下问题选择一个能够回答问题的主题
问题{question}
请你输出主题不要输出其他内容完整输出主题名
"""
    if global_config.debug.show_prompt:
        logger.info(f"记忆仓库选择标题 prompt: {prompt}")
    else:
        logger.debug(f"记忆仓库选择标题 prompt: {prompt}")
    title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)

    # Fuzzy-match the model output against the real titles.
    best_match = find_best_matching_memory(title, similarity_threshold=0.8)
    if best_match:
        selected_title = best_match[0]
        logger.info(f"记忆仓库选择标题: {selected_title} (相似度: {best_match[2]:.3f})")
    else:
        # Fix: the warning previously claimed threshold 0.7 while the lookup
        # actually uses 0.8.
        logger.warning(f"未找到相似度 >= 0.8 的标题匹配: {title}")
        selected_title = None
    return selected_title
def _cleanup_fetched_memory_list(self):
    """Prune fetched_memory_list.

    Drops entries older than 10 minutes and, when more than 10 entries
    remain, additionally drops the 5 oldest ones (everything newer is kept,
    so more than 5 items may remain).
    """
    try:
        cutoff = time.time() - 600  # entries older than 10 minutes expire
        self.fetched_memory_list = [
            entry for entry in self.fetched_memory_list if entry[1][2] > cutoff
        ]
        if len(self.fetched_memory_list) > 10:
            # Sort ascending by timestamp, then discard the 5 oldest entries.
            self.fetched_memory_list.sort(key=lambda entry: entry[1][2])
            self.fetched_memory_list = self.fetched_memory_list[5:]
            logger.debug(f"fetched_memory_list清理后当前有 {len(self.fetched_memory_list)} 条记忆")
    except Exception as e:
        logger.error(f"清理fetched_memory_list时出错: {e}")
async def _save_to_database_and_clear(self, chat_id: str, content: str):
    """Generate a title for `content`, persist it, and reset the running buffer.

    Args:
        chat_id: chat the content was accumulated from.
        content: accumulated memory text to store.

    Side effects:
        Creates a MemoryChest row on success and resets
        self.running_content_list[chat_id] (content cleared, timestamps
        refreshed) so the next incremental build starts fresh.
    """
    try:
        # Ask the LLM for a descriptive title covering the main concepts/events.
        title = ""
        title_prompt = f"""
请为以下内容生成一个描述全面的标题要求描述内容的主要概念和事件
{content}
标题不要分点不要换行不要输出其他内容
请只输出标题不要输出其他内容
"""
        if global_config.debug.show_prompt:
            logger.info(f"记忆仓库生成标题 prompt: {title_prompt}")
        else:
            logger.debug(f"记忆仓库生成标题 prompt: {title_prompt}")
        title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(title_prompt)
        await asyncio.sleep(0.5)
        if title:
            # Persist the titled memory.
            MemoryChestModel.create(
                title=title.strip(),
                content=content,
                chat_id=chat_id
            )
            logger.info(f"已保存记忆仓库内容,标题: {title.strip()}, chat_id: {chat_id}")
            # Clear the buffered content but keep the entry (with fresh
            # timestamps) so incremental building continues for this chat.
            if chat_id in self.running_content_list:
                current_time = time.time()
                self.running_content_list[chat_id] = {
                    "content": "",
                    "last_update_time": current_time,
                    "create_time": current_time
                }
                logger.info(f"已保存并刷新chat_id {chat_id} 的时间戳,准备下一次增量构建")
        else:
            logger.warning(f"生成标题失败chat_id: {chat_id}")
    except Exception as e:
        logger.error(f"保存记忆仓库内容时出错: {e}")
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> tuple[list[str], list[str]]:
    """Select memories related to the given title as merge candidates.

    Args:
        memory_title: title to find related memories for.
        chat_id: chat id used for weighted sampling of candidate titles.

    Returns:
        tuple[list[str], list[str]]: (selected titles, their contents); both
        lists are empty when nothing matches or on error.
    """
    try:
        # Weighted sampling favours titles from the same chat.
        all_titles = get_memory_titles_by_chat_id_weighted(chat_id)
        # Drop the input title itself from the candidate list.
        all_titles = [title for title in all_titles if title and title.strip() != (memory_title or "").strip()]
        content = ""
        display_index = 1
        for title in all_titles:
            content += f"{display_index}. {title}\n"
            display_index += 1
        prompt = f"""
所有记忆列表
{content}
请根据以上记忆列表选择一个与"{memory_title}"相关的记忆用json输出
如果没有相关记忆输出:
{{
"selected_title": ""
}}
可以选择多个相关的记忆但最多不超过5个
例如
{{
"selected_title": "选择的相关记忆标题"
}},
{{
"selected_title": "选择的相关记忆标题"
}},
{{
"selected_title": "选择的相关记忆标题"
}}
...
注意请返回原始标题本身作为 selected_title不要包含前面的序号或多余字符
请输出JSON格式不要输出其他内容
"""
        if global_config.debug.show_prompt:
            logger.info(f"选择合并目标 prompt: {prompt}")
        else:
            logger.debug(f"选择合并目标 prompt: {prompt}")
        merge_target_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
        # Parse the JSON answer into titles, then resolve each title's content.
        selected_titles = self._parse_merge_target_json(merge_target_response)
        selected_contents = self._get_memories_by_titles(selected_titles)
        logger.info(f"选择合并目标结果: {len(selected_contents)} 条记忆:{selected_titles}")
        return selected_titles, selected_contents
    except Exception as e:
        logger.error(f"选择合并目标时出错: {e}")
        # Fixed: previously returned a single empty list, which broke callers
        # that unpack two values from this coroutine.
        return [], []
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
    """Resolve memory contents for a list of titles via fuzzy matching.

    Locked memories are skipped so they are never merged away.

    Args:
        titles: memory titles to look up.

    Returns:
        list[str]: contents of the matched, unlocked memories.
    """
    try:
        contents = []
        for title in titles:
            if not title or not title.strip():
                continue
            try:
                best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
                if best_match:
                    memory_title = best_match[0]
                    memory_content = best_match[1]
                    # Fixed: the previous `continue` only restarted the inner
                    # DB loop, so locked memories were still appended. Use a
                    # flag and skip the outer iteration instead.
                    is_locked = False
                    for memory in MemoryChestModel.select():
                        if memory.title == memory_title and memory.locked:
                            is_locked = True
                            break
                    if is_locked:
                        logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
                        continue
                    contents.append(memory_content)
                    logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
                else:
                    logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
            except Exception as e:
                logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
                continue
        return contents
    except Exception as e:
        logger.error(f"根据标题查找记忆时出错: {e}")
        return []
def _parse_merged_parts(self, merged_response: str) -> tuple[str, str]:
"""
解析合并记忆的part1和part2内容
Args:
merged_response: LLM返回的合并记忆响应
Returns:
tuple[str, str]: (part1_content, part2_content)
"""
try:
# 使用正则表达式提取part1和part2内容
import re
# 提取part1内容
part1_pattern = r'<part1>(.*?)</part1>'
part1_match = re.search(part1_pattern, merged_response, re.DOTALL)
part1_content = part1_match.group(1).strip() if part1_match else ""
# 提取part2内容
part2_pattern = r'<part2>(.*?)</part2>'
part2_match = re.search(part2_pattern, merged_response, re.DOTALL)
part2_content = part2_match.group(1).strip() if part2_match else ""
# 检查是否包含none或None不区分大小写
def is_none_content(content: str) -> bool:
if not content:
return True
# 检查是否只包含"none"或"None"(不区分大小写)
return re.match(r'^\s*none\s*$', content, re.IGNORECASE) is not None
# 如果包含none则设置为空字符串
if is_none_content(part1_content):
part1_content = ""
logger.info("part1内容为none设置为空")
if is_none_content(part2_content):
part2_content = ""
logger.info("part2内容为none设置为空")
return part1_content, part2_content
except Exception as e:
logger.error(f"解析合并记忆part1/part2时出错: {e}")
return "", ""
def _parse_merge_target_json(self, json_text: str) -> list[str]:
    """Parse the JSON emitted by choose_merge_target into a title list.

    Handles a single object, a list of objects, and — as a fallback —
    loose ``{"selected_title": ...}`` fragments embedded in surrounding text.

    Args:
        json_text: raw LLM output expected to contain "selected_title" fields.

    Returns:
        list[str]: the non-empty selected titles (empty list when none).
    """
    def _title_of(obj) -> str:
        # Return the selected_title of a dict, or "" when absent/blank.
        value = obj.get("selected_title", "") if isinstance(obj, dict) else ""
        return value if isinstance(value, str) and value.strip() else ""

    try:
        # Repair common LLM JSON mistakes before parsing.
        repaired_content = repair_json(json_text)
        try:
            parsed_data = json.loads(repaired_content)
            if isinstance(parsed_data, list):
                # A list of {"selected_title": ...} objects.
                return [t for t in (_title_of(item) for item in parsed_data) if t]
            elif isinstance(parsed_data, dict) and "selected_title" in parsed_data:
                single = _title_of(parsed_data)
                if single:
                    return [single]
                # An explicitly empty selected_title falls through to the
                # fragment scan below (mirrors the lenient original parsing).
            else:
                # Anything else means "no related memory".
                return []
        except json.JSONDecodeError:
            pass
        # Fallback: scan for individual {"selected_title": ...} fragments.
        fragments = re.findall(r'\{[^}]*"selected_title"[^}]*\}', repaired_content)
        recovered = []
        for fragment in fragments:
            try:
                candidate = _title_of(json.loads(fragment))
            except json.JSONDecodeError:
                continue
            if candidate:
                recovered.append(candidate)
        if recovered:
            return recovered
        logger.warning(f"无法解析JSON响应: {json_text[:200]}...")
        return []
    except Exception as e:
        logger.error(f"解析合并目标JSON时出错: {e}")
        return []
async def merge_memory(self, memory_list: list[str], chat_id: str = None) -> tuple[str, str]:
    """Merge several memory texts into one consolidated memory.

    The LLM splits the input into <part1> (compatible knowledge, merged) and
    <part2> (conflicting knowledge). Part1 is titled and persisted; part2 is
    handed to the global conflict tracker.

    Args:
        memory_list: memory contents to merge.
        chat_id: chat the memories belong to (stored with the result).

    Returns:
        tuple[str, str]: (new title, merged content); ("", "") when nothing
        is mergeable or on error.
    """
    try:
        content = ""
        for memory in memory_list:
            content += f"{memory}\n"
        prompt = f"""
以下是多段记忆内容请将它们进行整合和修改
{content}
--------------------------------
请将上面的多段记忆内容合并成两部分内容第一部分是可以整合不冲突的概念和知识第二部分是相互有冲突的概念和知识
请主要关注概念和知识而不是聊天的琐事
重要你要关注的概念和知识必须是较为不常见的信息或者时效性较强的信息
不要关注常见的只是或者已经过时的信息
1.不要关注诸如某个用户做了什么说了什么不要关注某个用户的行为而是关注其中的概念性信息
2.概念要求精确不啰嗦像科普读物或教育课本那样
3.如果有图片请只关注图片和文本结合的知识和概念性内容
4.记忆为一段纯文本逻辑清晰指出概念的含义并说明关系
**第一部分**
1.如果两个概念在描述同一件事情且相互之间逻辑不冲突请你严格判断且相互之间没有矛盾请将它们整合成一个概念并输出到第一部分
2.如果某个概念在时间上更新了另一个概念请用新概念更新就概念来整合并输出到第一部分
3.如果没有可整合的概念请你输出none
**第二部分**
1.如果记忆中有无法整合的地方例如概念不一致有逻辑上的冲突请你输出到第二部分
2.如果两个概念在描述同一件事情但相互之间逻辑冲突请将它们输出到第二部分
3.如果没有无法整合的概念请你输出none
**输出格式要求**
请你按以下格式输出
<part1>
第一部分内容整合后的概念如果第一部分为none请输出none
</part1>
<part2>
第二部分内容无法整合冲突的概念如果第二部分为none请输出none
</part2>
不要输出其他内容现在请你输出,不要输出其他内容注意一定要直白白话口语化不要浮夸修辞
"""
        if global_config.debug.show_prompt:
            logger.info(f"合并记忆 prompt: {prompt}")
        else:
            logger.debug(f"合并记忆 prompt: {prompt}")
        merged_memory, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
        # Split the response into the mergeable part and the conflicting part.
        part1_content, part2_content = self._parse_merged_parts(merged_memory)
        # Part2 (conflicts) is recorded independently of whether part1 is empty.
        if part2_content and part2_content.strip() != "none":
            logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
            await global_conflict_tracker.record_memory_merge_conflict(part2_content,chat_id)
        # Part1 (merged knowledge): generate a title and persist it.
        if part1_content and part1_content.strip() != "none":
            merged_title = await self._generate_title_for_merged_memory(part1_content)
            MemoryChestModel.create(
                title=merged_title,
                content=part1_content,
                chat_id=chat_id
            )
            logger.info(f"合并记忆part1已保存: {merged_title}")
            return merged_title, part1_content
        else:
            logger.warning("合并记忆part1为空跳过保存")
            return "", ""
    except Exception as e:
        logger.error(f"合并记忆时出错: {e}")
        return "", ""
async def _generate_title_for_merged_memory(self, merged_content: str) -> str:
    """Generate a title for a merged memory, avoiding near-duplicate titles.

    Args:
        merged_content: the merged memory text.

    Returns:
        str: the generated title; a timestamp suffix is appended when a very
        similar title already exists, and a fallback "合并记忆_<ts>" title is
        used when generation fails.
    """
    try:
        prompt = f"""
请为以下内容生成一个描述全面的标题要求描述内容的主要概念和事件
例如
<example>
标题达尔文的自然选择理论
内容达尔文的自然选择是生物进化理论的重要组成部分它解释了生物进化过程中的自然选择机制
</example>
<example>
标题麦麦的禁言插件和支持版本
内容
麦麦的禁言插件是一款能够实现禁言的插件
麦麦的禁言插件可能不支持0.10.2
MutePlugin 是禁言插件的名称
</example>
需要对以下内容生成标题
{merged_content}
标题不要分点不要换行不要输出其他内容不要浮夸以白话简洁的风格输出标题
请只输出标题不要输出其他内容
"""
        if global_config.debug.show_prompt:
            logger.info(f"生成合并记忆标题 prompt: {prompt}")
        else:
            logger.debug(f"生成合并记忆标题 prompt: {prompt}")
        title_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
        # Strip surrounding quotes/whitespace the model may add.
        title = title_response.strip().strip('"').strip("'").strip()
        if title:
            # Deduplicate: append a timestamp when a near-identical title exists.
            if check_title_exists_fuzzy(title, similarity_threshold=0.9):
                logger.warning(f"生成的标题 '{title}' 与现有标题相似,使用时间戳后缀")
                title = f"{title}_{int(time.time())}"
            logger.info(f"生成合并记忆标题: {title}")
            return title
        else:
            logger.warning("生成合并记忆标题失败,使用默认标题")
            return f"合并记忆_{int(time.time())}"
    except Exception as e:
        logger.error(f"生成合并记忆标题时出错: {e}")
        return f"合并记忆_{int(time.time())}"
# Module-level singleton shared across the memory system.
global_memory_chest = MemoryChest()

View File

@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
import asyncio
import random
from typing import List
from src.manager.async_task_manager import AsyncTask
from src.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.config.config import global_config
from src.memory_system.memory_utils import get_all_titles
logger = get_logger("memory")
class MemoryManagementTask(AsyncTask):
    """Periodic memory-maintenance task.

    The run interval adapts to how full the memory chest is relative to
    max_memory_number (see _calculate_interval): the fuller the chest, the
    more often merging runs. Each run picks a random memory title, asks the
    memory chest for related memories, merges them, and deletes the originals.
    """
    def __init__(self):
        super().__init__(
            task_name="Memory Management Task",
            wait_before_start=10,  # wait 10s after startup before the first run
            run_interval=300  # default; dynamically recomputed after each run
        )
        # Upper bound on stored memories, taken from the global config.
        self.max_memory_number = global_config.memory.max_memory_number
    async def start_task(self, abort_flag: asyncio.Event):
        """Overridden scheduler loop that recomputes the interval after each run."""
        if self.wait_before_start > 0:
            # Delay the very first run.
            await asyncio.sleep(self.wait_before_start)
        while not abort_flag.is_set():
            await self.run()
            # Recompute the interval from the current memory count.
            current_interval = self._calculate_interval()
            logger.info(f"[记忆管理] 下次执行间隔: {current_interval}")
            if current_interval > 0:
                await asyncio.sleep(current_interval)
            else:
                break
    def _calculate_interval(self) -> int:
        """Return the next run interval (seconds) based on chest fullness."""
        try:
            current_count = self._get_memory_count()
            percentage = current_count / self.max_memory_number
            if percentage < 0.5:
                # Below 50% full: run hourly.
                return 3600
            elif percentage < 0.7:
                # 50%-70% full: every 30 minutes.
                return 1800
            elif percentage < 0.9:
                # 70%-90% full: every 5 minutes.
                return 300
            elif percentage < 1.2:
                # 90%-120% full: every 30 seconds.
                return 30
            else:
                # Over 120% full: merge aggressively every 10 seconds.
                return 10
        except Exception as e:
            logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
            return 300  # fall back to 5 minutes
    def _get_memory_count(self) -> int:
        """Return the number of rows in the memory chest (0 on error)."""
        try:
            count = MemoryChestModel.select().count()
            logger.debug(f"[记忆管理] 当前记忆数量: {count}")
            return count
        except Exception as e:
            logger.error(f"[记忆管理] 获取记忆数量时出错: {e}")
            return 0
    async def run(self):
        """One maintenance pass: pick a seed memory, merge related ones, delete originals."""
        try:
            current_count = self._get_memory_count()
            percentage = current_count / self.max_memory_number
            logger.info(f"当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
            # Fewer than 10 memories: not worth merging yet.
            if current_count < 10:
                return
            # Pick a random seed memory (title + owning chat).
            selected_title, selected_chat_id = self._get_random_memory_title()
            if not selected_title:
                logger.warning("无法获取随机记忆标题,跳过执行")
                return
            # Ask the memory chest for related titles and their contents.
            related_titles, related_contents = await global_memory_chest.choose_merge_target(selected_title, selected_chat_id)
            if not related_titles or not related_contents:
                logger.info("无合适合并内容,跳过本次合并")
                return
            logger.info(f"为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
            # Merge the related memories into a single new one.
            merged_title, merged_content = await global_memory_chest.merge_memory(related_contents,selected_chat_id)
            if not merged_title or not merged_content:
                logger.warning("[记忆管理] 记忆合并失败,跳过删除")
                return
            logger.info(f"记忆合并成功,新标题: {merged_title}")
            # Remove the originals (the seed title plus all merged ones).
            titles_to_delete = [selected_title] + related_titles
            deleted_count = self._delete_original_memories(titles_to_delete)
            logger.info(f"已删除 {deleted_count} 条原始记忆")
        except Exception as e:
            logger.error(f"[记忆管理] 执行记忆管理任务时发生错误: {e}", exc_info=True)
    def _get_random_memory_title(self) -> tuple[str, str]:
        """Pick a random memory and return (title, chat_id); ("", "") when empty or on error."""
        try:
            all_memories = MemoryChestModel.select()
            if not all_memories:
                return "", ""
            selected_memory = random.choice(list(all_memories))
            return selected_memory.title, selected_memory.chat_id or ""
        except Exception as e:
            logger.error(f"[记忆管理] 获取随机记忆标题时发生错误: {e}")
            return "", ""
    def _delete_original_memories(self, related_titles: List[str]) -> int:
        """Delete memories whose title exactly matches any of `related_titles`.

        Returns:
            int: number of rows deleted (0 on error).
        """
        try:
            deleted_count = 0
            for title in related_titles:
                try:
                    # Exact-title match; several rows may share one title.
                    memories_to_delete = MemoryChestModel.select().where(MemoryChestModel.title == title)
                    for memory in memories_to_delete:
                        MemoryChestModel.delete().where(MemoryChestModel.id == memory.id).execute()
                        deleted_count += 1
                        logger.debug(f"[记忆管理] 删除相关记忆: {memory.title}")
                except Exception as e:
                    logger.error(f"[记忆管理] 删除相关记忆时出错: {e}")
                    continue
            return deleted_count
        except Exception as e:
            logger.error(f"[记忆管理] 删除原始记忆时发生错误: {e}")
            return 0

View File

@ -0,0 +1,306 @@
# -*- coding: utf-8 -*-
"""
记忆系统工具函数
包含模糊查找相似度计算等工具函数
"""
import json
import re
from difflib import SequenceMatcher
from typing import List, Tuple, Optional
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.common.logger import get_logger
from json_repair import repair_json
logger = get_logger("memory_utils")
def get_all_titles(exclude_locked: bool = False) -> list[str]:
    """Return the titles of all stored memories.

    Args:
        exclude_locked: when True, titles of locked memories are omitted.

    Returns:
        list[str]: the titles (empty list on database errors).
    """
    try:
        titles = []
        for memory in MemoryChestModel.select():
            if memory.title:
                if exclude_locked and memory.locked:
                    continue
                titles.append(memory.title)
        return titles
    except Exception as e:
        # Fixed: errors were print()-ed; use the module logger like the rest
        # of this file.
        logger.error(f"获取记忆标题时出错: {e}")
        return []
def parse_md_json(json_text: str) -> tuple[list[dict], str]:
    """Extract JSON objects from ```json fenced blocks plus leading reasoning text.

    Args:
        json_text: raw markdown-ish LLM output.

    Returns:
        tuple[list[dict], str]: (parsed JSON objects, text found before the
        first ```json fence with leading "//" markers stripped). Fixed: the
        annotation previously claimed `list[str]` although a 2-tuple is
        returned and unpacked by callers.
    """
    json_objects = []
    reasoning_content = ""
    # Find every ```json ... ``` fenced block.
    json_pattern = r"```json\s*(.*?)\s*```"
    matches = re.findall(json_pattern, json_text, re.DOTALL)
    # Treat anything before the first fence as the model's reasoning.
    if matches:
        first_json_pos = json_text.find("```json")
        if first_json_pos > 0:
            reasoning_content = json_text[:first_json_pos].strip()
            # Strip leading "//" comment markers from the reasoning text.
            reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
            reasoning_content = reasoning_content.strip()
    for match in matches:
        try:
            # Remove // line comments and /* ... */ block comments the model
            # sometimes embeds inside the fenced JSON.
            json_str = re.sub(r"//.*?\n", "\n", match)
            json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
            if json_str := json_str.strip():
                json_obj = json.loads(json_str)
                if isinstance(json_obj, dict):
                    json_objects.append(json_obj)
                elif isinstance(json_obj, list):
                    # Keep only dict items from a top-level JSON array.
                    json_objects.extend(item for item in json_obj if isinstance(item, dict))
        except Exception as e:
            logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
            continue
    return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
    """Return a similarity score in [0, 1] for two texts.

    Both inputs are normalised via preprocess_text and compared with
    difflib.SequenceMatcher; substring containment bumps the score to at
    least 0.8. Returns 0.0 on error.
    """
    try:
        normalised_a = preprocess_text(text1)
        normalised_b = preprocess_text(text2)
        score = SequenceMatcher(None, normalised_a, normalised_b).ratio()
        # Containment in either direction counts as a strong match.
        if normalised_a in normalised_b or normalised_b in normalised_a:
            score = max(score, 0.8)
        return score
    except Exception as e:
        logger.error(f"计算相似度时出错: {e}")
        return 0.0
def preprocess_text(text: str) -> str:
    """Normalise text for fuzzy matching.

    Lower-cases the input, strips punctuation/special characters, and
    collapses runs of whitespace into single spaces. Returns the input
    unchanged on error.
    """
    try:
        lowered = text.lower()
        without_punct = re.sub(r'[^\w\s]', '', lowered)
        return re.sub(r'\s+', ' ', without_punct).strip()
    except Exception as e:
        logger.error(f"预处理文本时出错: {e}")
        return text
def fuzzy_find_memory_by_title(target_title: str, similarity_threshold: float = 0.9) -> List[Tuple[str, str, float]]:
    """Fuzzy-search stored memories by title.

    Args:
        target_title: title to match against.
        similarity_threshold: minimum similarity score to keep a memory.

    Returns:
        List[Tuple[str, str, float]]: (title, content, similarity) tuples
        sorted by similarity, best first; empty list on error.
    """
    try:
        scored = []
        for memory in MemoryChestModel.select():
            score = calculate_similarity(target_title, memory.title)
            if score >= similarity_threshold:
                scored.append((memory.title, memory.content, score))
        # Best match first.
        scored.sort(key=lambda entry: entry[2], reverse=True)
        return scored
    except Exception as e:
        logger.error(f"模糊查找记忆时出错: {e}")
        return []
def find_best_matching_memory(target_title: str, similarity_threshold: float = 0.9) -> Optional[Tuple[str, str, float]]:
    """Return the single best fuzzy match for a title, or None.

    Args:
        target_title: title to match against.
        similarity_threshold: minimum similarity score.

    Returns:
        Optional[Tuple[str, str, float]]: (title, content, similarity) of the
        best match, or None when nothing reaches the threshold or on error.
    """
    try:
        candidates = fuzzy_find_memory_by_title(target_title, similarity_threshold)
        if not candidates:
            logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆")
            return None
        # Candidates are already sorted best-first.
        return candidates[0]
    except Exception as e:
        logger.error(f"查找最佳匹配记忆时出错: {e}")
        return None
def check_title_exists_fuzzy(target_title: str, similarity_threshold: float = 0.9) -> bool:
    """Report whether a similar title already exists (fuzzy match).

    Args:
        target_title: title to check.
        similarity_threshold: minimum similarity to count as "exists"; the
        0.9 default is deliberately high to avoid false positives.

    Returns:
        bool: True when a similar title exists; False otherwise or on error.
    """
    try:
        matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
        if matches:
            logger.info(f"发现相似标题: '{matches[0][0]}' (相似度: {matches[0][2]:.3f})")
            return True
        logger.debug("未发现相似标题")
        return False
    except Exception as e:
        logger.error(f"检查标题是否存在时出错: {e}")
        return False
def get_memories_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[Tuple[str, str, str]]:
    """Collect memories for a chat, mixing in a few cross-chat memories.

    NOTE(review): despite the "weighted sampling" naming, ALL unlocked
    memories of `target_chat_id` are always included; the weights only size
    the random sample drawn from other chats (roughly
    total_same * other_chat_weight / same_chat_weight, at least 1) — confirm
    this is the intended behaviour.

    Args:
        target_chat_id: chat whose memories are wanted.
        same_chat_weight: weight for same-chat memories (sizes the ratio).
        other_chat_weight: weight for cross-chat memories (sizes the ratio).

    Returns:
        List[Tuple[str, str, str]]: (title, content, chat_id) tuples; empty
        when the chat has no memories or on error.
    """
    try:
        all_memories = MemoryChestModel.select()
        # Partition into same-chat vs other-chat, skipping locked memories.
        same_chat_memories = []
        other_chat_memories = []
        for memory in all_memories:
            if memory.title and not memory.locked:  # locked memories are never sampled
                if memory.chat_id == target_chat_id:
                    same_chat_memories.append((memory.title, memory.content, memory.chat_id))
                else:
                    other_chat_memories.append((memory.title, memory.content, memory.chat_id))
        # Without any same-chat memories there is nothing to return.
        if not same_chat_memories:
            logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
            return []
        total_same = len(same_chat_memories)
        total_other = len(other_chat_memories)
        # Size the cross-chat sample from the weight ratio (at least 1).
        if total_other > 0:
            other_sample_count = max(1, min(total_other, int(total_same * other_chat_weight / same_chat_weight)))
        else:
            other_sample_count = 0
        selected_memories = []
        # All same-chat memories are included unconditionally.
        selected_memories.extend(same_chat_memories)
        # Randomly sample the cross-chat memories.
        if other_sample_count > 0 and total_other > 0:
            import random
            other_selected = random.sample(other_chat_memories, min(other_sample_count, total_other))
            selected_memories.extend(other_selected)
        logger.info(f"加权抽样结果: 同chat_id记忆 {len(same_chat_memories)}其他chat_id记忆 {min(other_sample_count, total_other)}")
        return selected_memories
    except Exception as e:
        logger.error(f"按chat_id加权抽样记忆时出错: {e}")
        return []
def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[str]:
    """Weighted sampling of memory *titles* for a chat.

    See get_memories_by_chat_id_weighted for the sampling rules; this simply
    projects the sampled (title, content, chat_id) tuples onto their titles.

    Returns:
        List[str]: sampled titles (empty list on error).
    """
    try:
        sampled = get_memories_by_chat_id_weighted(target_chat_id, same_chat_weight, other_chat_weight)
        return [title for title, _content, _cid in sampled]
    except Exception as e:
        logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
        return []

View File

@ -0,0 +1,97 @@
import time
import random
from typing import List, Optional, Tuple
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.common.database.database_model import MemoryConflict
from src.config.config import global_config
class QuestionMaker:
    def __init__(self, chat_id: str, context: str = "") -> None:
        """Question generator.

        - chat_id: session id used to filter conflict records for this chat
        - context: extra context reserved for future extension

        Example:
            >>> qm = QuestionMaker(chat_id="some_chat")
            >>> question, chat_ctx, conflict_ctx = await qm.make_question()
        """
        self.chat_id = chat_id
        self.context = context

    def get_context(self, timestamp: Optional[float] = None) -> str:
        """Build a readable context string from messages before `timestamp`.

        Fixed: the default used to be `time.time()` evaluated once at class
        definition time; it is now resolved at call time (None means "now").
        """
        if timestamp is None:
            timestamp = time.time()
        latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
            chat_id=self.chat_id,
            timestamp=timestamp,
            limit=30,
        )
        all_dialogue_prompt_str = build_readable_messages(
            latest_30_msgs,
            replace_bot_name=True,
            timestamp_mode="normal_no_YMD",
        )
        return all_dialogue_prompt_str

    async def get_all_conflicts(self) -> List[MemoryConflict]:
        """Return every memory-conflict record for this chat."""
        conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
        return conflicts

    async def get_un_answered_conflict(self) -> List[MemoryConflict]:
        """Return the conflicts whose `answer` field is still empty."""
        conflicts = await self.get_all_conflicts()
        return [conflict for conflict in conflicts if not conflict.answer]

    async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]:
        """Pick one unanswered conflict by weight and bump its raise_time.

        Selection rules (as implemented; the old docstring advertised a 5%
        fallback that was never coded):
        - When at least one conflict has `raise_time == 0`, sample over all
          unanswered conflicts with weight 1.0 for raise_time == 0 and 0.01
          otherwise; the chosen record's raise_time is incremented and saved.
        - When every conflict has already been raised, return None.
        """
        conflicts = await self.get_un_answered_conflict()
        if not conflicts:
            return None
        conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
        if conflicts_with_zero:
            # Weights: raise_time == 0 -> 1.0, raise_time >= 1 -> 0.01.
            weights = []
            for conflict in conflicts:
                current_raise_time = getattr(conflict, "raise_time", 0) or 0
                weight = 1.0 if current_raise_time == 0 else 0.01
                weights.append(weight)
            chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
            # Record that this conflict has been raised once more.
            chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
            chosen_conflict.save()
            return chosen_conflict
        return None

    async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Produce a (question, chat_context, conflict_context) triple.

        All three elements are None when no conflict was selected this round.
        """
        conflict = await self.get_random_unanswered_conflict()
        if not conflict:
            return None, None, None
        question = conflict.conflict_content
        conflict_context = conflict.context
        create_time = conflict.create_time
        chat_context = self.get_context(create_time)
        return question, chat_context, conflict_context

View File

@ -0,0 +1,478 @@
import time
import asyncio
from src.common.logger import get_logger
from src.common.database.database_model import MemoryConflict
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from typing import List
from src.memory_system.memory_utils import parse_md_json
logger = get_logger("conflict_tracker")
class QuestionTracker:
    """Tracks whether a question gets answered in subsequent chat messages."""

    def __init__(self, question: str, chat_id: str, context: str = "") -> None:
        self.question = question
        self.chat_id = chat_id
        now = time.time()
        self.context = context
        self.start_time = now        # when tracking began
        self.last_read_time = now    # last time new messages were consumed
        self.last_judge_time = now   # last time the LLM judge ran
        self.judge_debounce_interval = 10.0  # minimum seconds between judgements
        self.consecutive_end_count = 0  # consecutive END verdicts so far
        self.active = True
        # LLM used for answer judging (utils model set).
        self.llm_request = LLMRequest(model_set=model_config.model_task_config.utils, request_type="conflict.judge")

    def stop(self) -> None:
        """Mark the tracker as finished."""
        self.active = False

    def should_judge_now(self) -> bool:
        """Return True when the debounce interval since the last judgement has elapsed."""
        now = time.time()
        return (now - self.last_judge_time) >= self.judge_debounce_interval

    def __eq__(self, other) -> bool:
        """Trackers are equal when they share question text and chat id."""
        if not isinstance(other, QuestionTracker):
            return False
        return self.question == other.question and self.chat_id == other.chat_id

    def __hash__(self) -> int:
        """Hash on (question, chat_id), consistent with __eq__."""
        return hash((self.question, self.chat_id))

    async def judge_answer(self, conversation_text: str, chat_len: int) -> tuple[bool, str, str]:
        """Ask the LLM whether the question has been answered in the chat.

        Args:
            conversation_text: rendered chat transcript.
            chat_len: number of messages in the transcript (the END rule is
                only offered once more than 20 messages have accumulated).

        Returns:
            tuple[bool, str, str]: (stop_tracking, answer_or_reason, verdict)
            where verdict is "ANSWERED", "END" or "CONTINUE".
        """
        end_prompt = ""
        if chat_len > 20:
            end_prompt = "\n- 如果最新20条聊天记录内容与问题无关话题已转向其他方向请只输出END"
        prompt = f"""你是一个严谨的判定器。下面给出聊天记录以及一个问题。
任务判断在这段聊天中该问题是否已经得到明确解答
**你必须严格按照聊天记录的内容不要添加额外的信息**
输出规则
- 如果聊天记录内容的信息已解答问题请只输出YES: <简短答案>{end_prompt}
- 如果问题尚未解答但聊天仍在相关话题上请只输出NO
**问题**
{self.question}
**聊天记录**
{conversation_text}
"""
        if global_config.debug.show_prompt:
            logger.info(f"判定提示词: {prompt}")
        else:
            logger.debug("已发送判定提示词")
        result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.5)
        # Fixed: the full prompt used to be logged at INFO level here,
        # bypassing the show_prompt gate above; log only the verdict.
        logger.info(f"判定结果: {result_text}")
        # Record when this judgement happened (drives the debounce).
        self.last_judge_time = time.time()
        if not result_text:
            return False, "", "CONTINUE"
        text = result_text.strip()
        if text.upper().startswith("YES:"):
            answer = text[4:].strip()
            return True, answer, "ANSWERED"
        if text.upper().startswith("YES"):
            # Tolerate a bare "YES" or "YES <answer>" without the colon.
            answer = text[3:].strip().lstrip(":").strip()
            return True, answer, "ANSWERED"
        if text.upper().startswith("END"):
            # Conversation moved on; give up on this question.
            return True, "话题已转向其他方向,放弃该问题思考", "END"
        return False, "", "CONTINUE"
class ConflictTracker:
"""
记忆整合冲突追踪器
用于记录和存储记忆整合过程中的冲突内容
"""
def __init__(self):
    # Active per-question trackers; entries are removed when tracking ends.
    self.question_tracker_list: List[QuestionTracker] = []
    # LLM used to turn conflicting content into concrete questions.
    self.LLMRequest_tracker = LLMRequest(
        model_set=model_config.model_task_config.utils,
        request_type="conflict_tracker",
    )
def get_questions_by_chat_id(self, chat_id: str) -> List[QuestionTracker]:
    """Return the active trackers that belong to the given chat."""
    return [t for t in self.question_tracker_list if t.chat_id == chat_id]
async def track_conflict(self, question: str, context: str = "", start_following: bool = False, chat_id: str = "") -> bool:
    """Start background tracking for a question.

    Spawns a _follow_and_record task that watches the chat for an answer.

    Args:
        question: question text to track.
        context: extra context stored with the tracker.
        start_following: NOTE(review): this parameter is ignored here —
            tracking always starts; confirm intent.
        chat_id: chat to watch for the answer.

    Returns:
        bool: always True (the task is started asynchronously).
    """
    tracker = QuestionTracker(question.strip(), chat_id, context)
    self.question_tracker_list.append(tracker)
    # Run the follow-up in the background to avoid blocking the caller.
    asyncio.create_task(self._follow_and_record(tracker, question.strip()))
    return True
async def record_conflict(self, conflict_content: str, context: str = "", start_following: bool = False, chat_id: str = "") -> bool:
    """Record a conflict, optionally tracking the chat for its answer first.

    Args:
        conflict_content: the conflicting statement/question to record.
        context: extra context; only passed to the tracker on the following
            path. NOTE(review): `context` is not persisted on the direct
            MemoryConflict.create path below — confirm intent.
        start_following: when True (and chat_id given), watch the chat for
            an answer in the background instead of writing a row immediately.
        chat_id: chat the conflict belongs to.

    Returns:
        bool: True when recorded or tracking started; False for empty input
        or on a database error.
    """
    try:
        if not conflict_content or conflict_content.strip() == "":
            return False
        # Follow subsequent messages to see whether the conflict resolves.
        if start_following and chat_id:
            tracker = QuestionTracker(conflict_content.strip(), chat_id, context)
            self.question_tracker_list.append(tracker)
            # Run the follow-up in the background to avoid blocking.
            asyncio.create_task(self._follow_and_record(tracker, conflict_content.strip()))
            return True
        # Default: persist immediately without tracking.
        MemoryConflict.create(
            conflict_content=conflict_content,
            create_time=time.time(),
            update_time=time.time(),
            answer="",
            chat_id=chat_id,
        )
        logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
        return True
    except Exception as e:
        logger.error(f"记录冲突内容时出错: {e}")
        return False
async def _follow_and_record(self, tracker: QuestionTracker, original_question: str) -> None:
    """Background task: watch a chat to see whether a question gets answered,
    then persist (or delete) the corresponding conflict record.

    Termination conditions: 10-minute wall clock, 50 observed messages, an
    ANSWERED verdict, or two consecutive END verdicts.
    """
    try:
        max_duration = 10 * 60  # 10-minute tracking window
        max_messages = 50  # stop after 50 observed messages
        poll_interval = 2.0  # seconds between polls
        logger.info(f"开始跟踪问题: {original_question}")
        while tracker.active:
            now_ts = time.time()
            # Stop once the tracking window is exhausted.
            if now_ts - tracker.start_time >= max_duration:
                logger.info("问题跟踪达到10分钟上限判定为未解答")
                break
            # Check for new messages since the last read (bot messages kept,
            # command messages filtered out).
            recent_msgs = get_raw_msg_by_timestamp_with_chat(
                chat_id=tracker.chat_id,
                timestamp_start=tracker.last_read_time,
                timestamp_end=now_ts,
                limit=30,
                limit_mode="latest",
                filter_bot=False,
                filter_command=True,
            )
            if len(recent_msgs) > 0:
                tracker.last_read_time = now_ts
            # Total messages since tracking started (drives the message cap).
            all_msgs = get_raw_msg_by_timestamp_with_chat(
                chat_id=tracker.chat_id,
                timestamp_start=tracker.start_time,
                timestamp_end=now_ts,
                limit=0,
                limit_mode="latest",
                filter_bot=False,
                filter_command=True,
            )
            # Debounce: only ask the LLM every judge_debounce_interval seconds.
            if not tracker.should_judge_now():
                logger.debug(f"判定防抖中,跳过本次判定: {tracker.question}")
                await asyncio.sleep(poll_interval)
                continue
            # Render the conversation for the judging model.
            chat_text = build_readable_messages(
                all_msgs,
                replace_bot_name=True,
                timestamp_mode="relative",
                read_mark=0.0,
                truncate=False,
                show_actions=False,
                show_pic=False,
                remove_emoji_stickers=True,
            )
            chat_len = len(all_msgs)
            # Let the model decide whether an answer has appeared.
            answered, answer_text, judge_type = await tracker.judge_answer(chat_text,chat_len)
            if judge_type == "ANSWERED":
                # Answered: persist the answer and finish tracking.
                logger.info("问题已得到解答,结束跟踪并写入答案")
                await self.add_or_update_conflict(
                    conflict_content=tracker.question,
                    create_time=tracker.start_time,
                    update_time=time.time(),
                    answer=answer_text or "",
                    chat_id=tracker.chat_id,
                )
                return
            elif judge_type == "END":
                # Topic drifted away: count consecutive END verdicts.
                tracker.consecutive_end_count += 1
                logger.info(f"话题已转向连续END次数: {tracker.consecutive_end_count}")
                if tracker.consecutive_end_count >= 2:
                    # Two ENDs in a row: give up on this question.
                    logger.info("连续两次END结束跟踪")
                    break
                else:
                    # First END: keep tracking (the judge debounce above
                    # throttles how soon the next verdict is requested).
                    logger.info("第一次END继续跟踪")
                    continue
            elif judge_type == "CONTINUE":
                # Still on topic but unanswered: reset the END counter.
                tracker.consecutive_end_count = 0
                continue
            if len(all_msgs) >= max_messages:
                # NOTE(review): the log text says 100 but max_messages is 50.
                logger.info("问题跟踪达到100条消息上限判定为未解答")
                logger.info(f"追踪结束:{tracker.question}")
                break
            # Nothing worth judging; wait before polling again.
            await asyncio.sleep(poll_interval)
        # No answer was obtained: decide whether to delete or keep the record.
        existing_conflict = MemoryConflict.get_or_none(
            MemoryConflict.conflict_content == original_question,
            MemoryConflict.chat_id == tracker.chat_id
        )
        if existing_conflict:
            # Delete the record when it has already been raised at least once
            # (raise_time > 0) and still has no answer.
            current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
            if current_raise_time > 0 and not existing_conflict.answer:
                await self.delete_conflict(original_question, tracker.chat_id)
                logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
            else:
                # Otherwise refresh the record without deleting it.
                await self.add_or_update_conflict(
                    conflict_content=original_question,
                    create_time=existing_conflict.create_time,
                    update_time=time.time(),
                    answer="",
                    chat_id=tracker.chat_id,
                )
                logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
        else:
            # No prior record: create one marked as unanswered.
            await self.add_or_update_conflict(
                conflict_content=original_question,
                create_time=time.time(),
                update_time=time.time(),
                answer="",
                chat_id=tracker.chat_id,
            )
            logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
        logger.info(f"问题跟踪结束:{original_question}")
    except Exception as e:
        logger.error(f"后台问题跟踪任务异常: {e}")
    finally:
        # Always detach the tracker, even when the task failed.
        tracker.stop()
        self.remove_tracker(tracker)
def remove_tracker(self, tracker: QuestionTracker) -> None:
    """Detach a tracker from the active list (warns when it is absent).

    Args:
        tracker: the tracker object to remove.
    """
    try:
        if tracker not in self.question_tracker_list:
            logger.warning(f"尝试移除不存在的追踪器: {tracker.question}")
            return
        self.question_tracker_list.remove(tracker)
        logger.info(f"已从追踪列表中移除追踪器: {tracker.question}")
    except Exception as e:
        logger.error(f"移除追踪器时出错: {e}")
async def add_or_update_conflict(
    self,
    conflict_content: str,
    create_time: float,
    update_time: float,
    answer: str = "",
    context: str = "",
    chat_id: "str | None" = None,
) -> bool:
    """Upsert a conflict record keyed on (conflict_content, chat_id).

    If a record with the same content and chat id already exists, only its
    ``update_time`` and ``answer`` are refreshed; otherwise a new record is
    created with all the supplied fields.

    Args:
        conflict_content: The conflicting question/content text.
        create_time: Creation timestamp (used only for new records).
        update_time: Last-update timestamp.
        answer: Resolved answer, empty when unresolved.
        context: Extra context (used only for new records).
        chat_id: Chat the conflict belongs to; may be None.

    Returns:
        bool: True on success, False when the database operation failed.
    """
    try:
        existing_conflict = MemoryConflict.get_or_none(
            MemoryConflict.conflict_content == conflict_content,
            MemoryConflict.chat_id == chat_id,
        )
        if existing_conflict:
            # Matching record found: refresh only the mutable fields.
            existing_conflict.update_time = update_time
            existing_conflict.answer = answer
            existing_conflict.save()
            return True
        else:
            # No match: persist a brand-new record with every field.
            MemoryConflict.create(
                conflict_content=conflict_content,
                create_time=create_time,
                update_time=update_time,
                answer=answer,
                context=context,
                chat_id=chat_id,
            )
            return True
    except Exception as e:
        # Log and signal failure instead of propagating DB errors.
        logger.error(f"添加或更新冲突记录时出错: {e}")
        return False
async def record_memory_merge_conflict(self, part2_content: str, chat_id: "str | None" = None) -> bool:
    """Record conflicting content discovered during memory merging.

    The conflicting text is turned into 1-3 concrete questions by an LLM,
    and each question is stored as a conflict record (without starting
    follow-up tracking).

    Args:
        part2_content: The conflicting content produced by the merge step.
        chat_id: Chat the conflict belongs to; stored with each record.

    Returns:
        bool: False when the input is empty/blank, True otherwise.
    """
    if not part2_content or part2_content.strip() == "":
        return False
    # NOTE: the prompt below is sent to the LLM verbatim — do not reformat it.
    prompt = f"""以下是一段有冲突的信息,请你根据这些信息总结出几个具体的提问:
冲突信息
{part2_content}
要求
1.提问必须具体明确
2.提问最好涉及指向明确的事物而不是代称
3.如果缺少上下文不要强行提问可以忽略
请用json格式输出不要输出其他内容仅输出提问理由和具体提的提问:
**示例**
// 理由文本
```json
{{
"question":"提问",
}}
```
```json
{{
"question":"提问"
}}
```
...提问数量在1-3个之间不要重复现在请输出"""
    # Only the response text and reasoning are used; model name / tool calls are ignored.
    question_response, (reasoning_content, _, _) = await self.LLMRequest_tracker.generate_response_async(prompt)
    # Parse the ```json blocks out of the model response.
    questions, reasoning_content = parse_md_json(question_response)
    # Debug-level tracing instead of bare print() so normal runs stay quiet.
    logger.debug(prompt)
    logger.debug(question_response)
    for question in questions:
        await self.record_conflict(
            conflict_content=question["question"],
            context=reasoning_content,
            start_following=False,
            chat_id=chat_id,
        )
    return True
async def get_conflict_count(self) -> int:
    """Return the number of stored conflict records (0 on any DB error)."""
    try:
        total = MemoryConflict.select().count()
    except Exception as e:
        logger.error(f"获取冲突记录数量时出错: {e}")
        return 0
    return total
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
    """Delete one conflict record matched by content and chat id.

    Args:
        conflict_content: Exact conflict text of the record to delete.
        chat_id: Chat the record belongs to.

    Returns:
        bool: True when a record was found and deleted, False otherwise.
    """
    try:
        record = MemoryConflict.get_or_none(
            MemoryConflict.conflict_content == conflict_content,
            MemoryConflict.chat_id == chat_id,
        )
        if record is None:
            logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
            return False
        record.delete_instance()
        logger.info(f"已删除冲突记录: {conflict_content}")
        return True
    except Exception as e:
        logger.error(f"删除冲突记录时出错: {e}")
        return False
# Global conflict-tracker instance shared by the whole process.
global_conflict_tracker = ConflictTracker()

View File

@ -1,319 +0,0 @@
import json
import os
import asyncio
from src.common.database.database_model import GraphNodes
from src.common.logger import get_logger
logger = get_logger("migrate")
async def migrate_memory_items_to_string():
    """Migrate graph-node ``memory_items`` from JSON-list format to plain strings.

    For list-formatted nodes the items are joined with ``" | "`` and the node's
    ``weight`` is set to the original item count. Content longer than 100
    characters is truncated everywhere. Returns a dict of migration statistics.
    """
    logger.info("开始迁移记忆节点格式...")
    # Running counters for the final migration report.
    migration_stats = {
        "total_nodes": 0,
        "converted_nodes": 0,
        "already_string_nodes": 0,
        "empty_nodes": 0,
        "error_nodes": 0,
        "weight_updated_nodes": 0,
        "truncated_nodes": 0,
    }
    try:
        # Fetch every graph node.
        all_nodes = GraphNodes.select()
        migration_stats["total_nodes"] = all_nodes.count()
        logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
        for node in all_nodes:
            try:
                concept = node.concept
                memory_items_raw = node.memory_items.strip() if node.memory_items else ""
                original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
                # Skip nodes with no content at all.
                if not memory_items_raw:
                    migration_stats["empty_nodes"] += 1
                    logger.debug(f"跳过空节点: {concept}")
                    continue
                try:
                    # Try to parse the stored value as JSON.
                    parsed_data = json.loads(memory_items_raw)
                    if isinstance(parsed_data, list):
                        # list format: needs conversion to a single string.
                        if parsed_data:
                            # Join the items into the string representation.
                            new_memory_items = " | ".join(str(item) for item in parsed_data)
                            original_length = len(new_memory_items)
                            # Enforce the 100-character cap.
                            if len(new_memory_items) > 100:
                                new_memory_items = new_memory_items[:100]
                                migration_stats["truncated_nodes"] += 1
                                logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
                            new_weight = float(len(parsed_data))  # weight = number of items in the original list
                            # Persist the converted node.
                            node.memory_items = new_memory_items
                            node.weight = new_weight
                            node.save()
                            migration_stats["converted_nodes"] += 1
                            migration_stats["weight_updated_nodes"] += 1
                            length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
                            logger.info(
                                f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
                            )
                        else:
                            # Empty list: store an empty string with default weight.
                            node.memory_items = ""
                            node.weight = 1.0
                            node.save()
                            migration_stats["converted_nodes"] += 1
                            logger.debug(f"转换空list节点: {concept}")
                    elif isinstance(parsed_data, str):
                        # Already a (JSON-quoted) string: check length and weight.
                        current_content = parsed_data
                        original_length = len(current_content)
                        content_truncated = False
                        # Enforce the 100-character cap.
                        if len(current_content) > 100:
                            current_content = current_content[:100]
                            content_truncated = True
                            migration_stats["truncated_nodes"] += 1
                            node.memory_items = current_content
                            logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
                        # Decide whether the weight needs recomputing.
                        update_needed = False
                        if original_weight == 1.0:
                            # Weight still at default: estimate it from the " | " parts.
                            content_parts = (
                                current_content.split(" | ") if " | " in current_content else [current_content]
                            )
                            estimated_weight = max(1.0, float(len(content_parts)))
                            if estimated_weight != original_weight:
                                node.weight = estimated_weight
                                update_needed = True
                                logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
                        # Save only when something actually changed.
                        if content_truncated or update_needed:
                            node.save()
                            if update_needed:
                                migration_stats["weight_updated_nodes"] += 1
                            if content_truncated:
                                migration_stats["converted_nodes"] += 1  # count as a converted node
                            else:
                                migration_stats["already_string_nodes"] += 1
                        else:
                            migration_stats["already_string_nodes"] += 1
                    else:
                        # Any other JSON type: stringify it.
                        new_memory_items = str(parsed_data) if parsed_data else ""
                        original_length = len(new_memory_items)
                        # Enforce the 100-character cap.
                        if len(new_memory_items) > 100:
                            new_memory_items = new_memory_items[:100]
                            migration_stats["truncated_nodes"] += 1
                            logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
                        node.memory_items = new_memory_items
                        node.weight = 1.0
                        node.save()
                        migration_stats["converted_nodes"] += 1
                        length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
                        logger.debug(f"转换其他类型节点: {concept}{length_info}")
                except json.JSONDecodeError:
                    # Not JSON: assume the value is already a plain string.
                    # Strip surrounding double quotes if present.
                    if memory_items_raw.startswith('"') and memory_items_raw.endswith('"'):
                        # Drop the quotes.
                        clean_content = memory_items_raw[1:-1]
                        original_length = len(clean_content)
                        # Enforce the 100-character cap.
                        if len(clean_content) > 100:
                            clean_content = clean_content[:100]
                            migration_stats["truncated_nodes"] += 1
                            logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
                        node.memory_items = clean_content
                        node.save()
                        migration_stats["converted_nodes"] += 1
                        length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
                        logger.debug(f"去除引号节点: {concept}{length_info}")
                    else:
                        # Plain string without quotes: only the length cap applies.
                        current_content = memory_items_raw
                        original_length = len(current_content)
                        # Enforce the 100-character cap.
                        if len(current_content) > 100:
                            current_content = current_content[:100]
                            node.memory_items = current_content
                            node.save()
                            migration_stats["converted_nodes"] += 1  # count as a converted node
                            migration_stats["truncated_nodes"] += 1
                            logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
                        else:
                            migration_stats["already_string_nodes"] += 1
                            logger.debug(f"已是字符串格式节点: {concept}")
            except Exception as e:
                # Per-node failure: count it and keep migrating the rest.
                migration_stats["error_nodes"] += 1
                logger.error(f"处理节点 {concept} 时发生错误: {e}")
                continue
    except Exception as e:
        # Failure before/outside the node loop is fatal for the migration.
        logger.error(f"迁移过程中发生严重错误: {e}")
        raise
    # Emit the migration summary.
    logger.info("=== 记忆节点迁移完成 ===")
    logger.info(f"总节点数: {migration_stats['total_nodes']}")
    logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
    logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
    logger.info(f"空节点: {migration_stats['empty_nodes']}")
    logger.info(f"错误节点: {migration_stats['error_nodes']}")
    logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
    logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
    # Success rate = (converted + already-string) / total, guarded against zero nodes.
    success_rate = (
        (migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
        / migration_stats["total_nodes"]
        * 100
        if migration_stats["total_nodes"] > 0
        else 0
    )
    logger.info(f"迁移成功率: {success_rate:.1f}%")
    return migration_stats
async def set_all_person_known():
    """Mark every ``person_info`` record as known (``is_known=True``).

    Records with an empty/NULL ``user_id`` or ``platform`` are deleted first.
    Returns a dict with ``total``, ``deleted``, ``updated`` and ``known_count``.
    """
    logger.info("开始设置所有person_info记录为已认识...")
    try:
        # Imported lazily to avoid a module-level dependency at import time.
        from src.common.database.database_model import PersonInfo

        # Count all existing person records.
        all_persons = PersonInfo.select()
        total_count = all_persons.count()
        logger.info(f"找到 {total_count} 个人员记录")
        if total_count == 0:
            logger.info("没有找到任何人员记录")
            return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}
        # Select records with a missing user_id or platform for deletion.
        deleted_count = 0
        invalid_records = PersonInfo.select().where(
            (PersonInfo.user_id.is_null())
            | (PersonInfo.user_id == "")
            | (PersonInfo.platform.is_null())
            | (PersonInfo.platform == "")
        )
        # Log each record that is about to be deleted.
        for record in invalid_records:
            user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
            platform_info = f"'{record.platform}'" if record.platform else "NULL"
            person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
            logger.debug(
                f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
            )
        # Execute the bulk delete (same predicate as the select above).
        deleted_count = (
            PersonInfo.delete()
            .where(
                (PersonInfo.user_id.is_null())
                | (PersonInfo.user_id == "")
                | (PersonInfo.platform.is_null())
                | (PersonInfo.platform == "")
            )
            .execute()
        )
        if deleted_count > 0:
            logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
        else:
            logger.info("没有发现user_id或platform为空的记录")
        # Re-count after the cleanup pass.
        remaining_count = PersonInfo.select().count()
        logger.info(f"清理后剩余 {remaining_count} 个有效记录")
        if remaining_count == 0:
            logger.info("清理后没有剩余记录")
            return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}
        # Bulk-update the remaining records to is_known=True.
        updated_count = PersonInfo.update(is_known=True).execute()
        logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")
        # Verify the result by counting known records.
        known_count = PersonInfo.select().where(PersonInfo.is_known).count()
        result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}
        logger.info("=== person_info更新完成 ===")
        logger.info(f"原始记录数: {result['total']}")
        logger.info(f"删除记录数: {result['deleted']}")
        logger.info(f"更新记录数: {result['updated']}")
        logger.info(f"已认识记录数: {result['known_count']}")
        return result
    except Exception as e:
        logger.error(f"更新person_info过程中发生错误: {e}")
        raise
async def check_and_run_migrations():
    """Run one-time data migrations, guarded by a marker file.

    If ``data/temp/done.mem`` exists the migrations already ran and this is a
    no-op; otherwise both migrations are executed sequentially and the marker
    file is written afterwards.
    """
    # Resolve <project_root>/data/temp/done.mem relative to this file.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    data_dir = os.path.join(project_root, "data")
    temp_dir = os.path.join(data_dir, "temp")
    done_file = os.path.join(temp_dir, "done.mem")
    # The marker file means every migration below already completed.
    if os.path.exists(done_file):
        return
    # exist_ok=True already tolerates a pre-existing directory; no extra check needed.
    os.makedirs(temp_dir, exist_ok=True)
    # Short delay so the migrations do not race application start-up.
    await asyncio.sleep(3)
    await migrate_memory_items_to_string()
    await set_all_person_known()
    # Write the marker only after both migrations succeeded.
    with open(done_file, "w", encoding="utf-8") as f:
        f.write("done")

View File

@ -22,14 +22,15 @@ def init_prompt():
以上是群里正在进行的聊天记录
{identity_block}
你刚刚的情绪状态是{mood_state}
现在发送了消息引起了你的注意你对其进行了阅读和思考请你输出一句话描述你新的情绪状态
你先前的情绪状态是{mood_state}
你的情绪特点是:{emotion_style}
现在请你根据先前的情绪状态和现在的聊天内容总结推断你现在的情绪状态
请只输出新的情绪状态不要输出其他内容
""",
"change_mood_prompt",
"get_mood_prompt",
)
Prompt(
"""
{chat_talking_prompt}
@ -66,37 +67,16 @@ class ChatMood:
self.last_change_time: float = 0
async def update_mood_by_message(self, message: MessageRecv):
async def get_mood(self) -> str:
self.regression_count = 0
during_last_time = message.message_info.time - self.last_change_time # type: ignore
current_time = time.time()
base_probability = 0.05
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
# 基于消息长度计算基础兴趣度
message_length = len(message.processed_plain_text or "")
interest_multiplier = min(2.0, 1.0 + message_length / 100)
logger.debug(
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
)
update_probability = global_config.mood.mood_update_threshold * min(
1.0, base_probability * time_multiplier * interest_multiplier
)
if random.random() > update_probability:
return
logger.debug(
f"{self.log_prefix} 更新情绪状态,更新概率: {update_probability:.2f}"
)
message_time: float = message.message_info.time # type: ignore
logger.info(f"{self.log_prefix} 获取情绪状态")
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
timestamp_end=current_time,
limit=int(global_config.chat.max_context_size / 3),
limit_mode="last",
)
@ -119,11 +99,11 @@ class ChatMood:
identity_block = f"你的名字是{bot_name}{bot_nickname}"
prompt = await global_prompt_manager.format_prompt(
"change_mood_prompt",
"get_mood_prompt",
chat_talking_prompt=chat_talking_prompt,
identity_block=identity_block,
mood_state=self.mood_state,
emotion_style=global_config.personality.emotion_style,
emotion_style=global_config.mood.emotion_style,
)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
@ -138,7 +118,9 @@ class ChatMood:
self.mood_state = response
self.last_change_time = message_time
self.last_change_time = current_time
return response
async def regress_mood(self):
message_time = time.time()
@ -172,7 +154,7 @@ class ChatMood:
chat_talking_prompt=chat_talking_prompt,
identity_block=identity_block,
mood_state=self.mood_state,
emotion_style=global_config.personality.emotion_style,
emotion_style=global_config.mood.emotion_style,
)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
@ -222,7 +204,6 @@ class MoodManager:
if self.task_started:
return
logger.info("启动情绪回归任务...")
task = MoodRegressionTask(self)
await async_task_manager.add_task(task)
self.task_started = True

View File

@ -17,7 +17,9 @@ from src.config.config import global_config, model_config
logger = get_logger("person_info")
relation_selection_model = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="relation_selection")
relation_selection_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation_selection"
)
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
@ -91,9 +93,10 @@ def extract_categories_from_response(response: str) -> list[str]:
"""从response中提取所有<>包裹的内容"""
if not isinstance(response, str):
return []
import re
pattern = r'<([^<>]+)>'
pattern = r"<([^<>]+)>"
matches = re.findall(pattern, response)
return matches
@ -420,7 +423,7 @@ class Person:
except Exception as e:
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
async def build_relationship(self,chat_content:str = "",info_type = ""):
async def build_relationship(self, chat_content: str = "", info_type=""):
if not self.is_known:
return ""
# 构建points文本
@ -433,7 +436,7 @@ class Person:
points_text = ""
category_list = self.get_all_category()
if chat_content:
prompt = f"""当前聊天内容:
{chat_content}
@ -449,11 +452,13 @@ class Person:
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 2)
if random_memory:
random_memory_str = "\n".join([get_memory_content_from_memory(memory) for memory in random_memory])
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
elif info_type:
@ -466,18 +471,19 @@ class Person:
<分类1><分类2><分类3>......
如果没有相关的分类请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
print(prompt)
print(response)
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 3)
if random_memory:
random_memory_str = "\n".join([get_memory_content_from_memory(memory) for memory in random_memory])
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
else:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 1)[0]
if random_memory:

View File

@ -19,6 +19,7 @@ from src.plugin_system.apis import (
send_api,
tool_api,
frequency_api,
mood_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@ -40,4 +41,5 @@ __all__ = [
"register_plugin",
"tool_api",
"frequency_api",
"mood_api",
]

View File

@ -309,6 +309,7 @@ async def store_action_info(
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
@ -322,7 +323,7 @@ async def store_action_info(
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
action_reasoning: 动作执行理由
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
@ -348,6 +349,7 @@ async def store_action_info(
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_reasoning": action_reasoning,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}

View File

@ -9,11 +9,14 @@
"""
import random
import base64
import os
import uuid
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
from src.chat.emoji_system.emoji_manager import get_emoji_manager, EMOJI_DIR
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
logger = get_logger("emoji_api")
@ -245,6 +248,42 @@ def get_emotions() -> List[str]:
return []
async def get_all() -> List[Tuple[str, str, str]]:
    """Return every non-deleted emoji as ``(base64, description, emotion)`` tuples.

    The emotion element is one label picked at random from the emoji's tags,
    or a placeholder when the emoji carries no tags. Returns an empty list on
    any failure.
    """
    try:
        manager = get_emoji_manager()
        candidates = manager.emoji_objects
        if not candidates:
            logger.warning("[EmojiAPI] 没有可用的表情包")
            return []
        collected: List[Tuple[str, str, str]] = []
        for item in candidates:
            if item.is_deleted:
                continue
            encoded = image_path_to_base64(item.full_path)
            if not encoded:
                logger.error(f"[EmojiAPI] 无法转换表情包为base64: {item.full_path}")
                continue
            emotion_label = random.choice(item.emotion) if item.emotion else "随机表情"
            collected.append((encoded, item.description, emotion_label))
        logger.debug(f"[EmojiAPI] 成功获取 {len(collected)} 个表情包")
        return collected
    except Exception as e:
        logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}")
        return []
def get_descriptions() -> List[str]:
"""获取所有表情包描述
@ -264,3 +303,403 @@ def get_descriptions() -> List[str]:
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册API函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
    """Register a new emoji from a base64-encoded image.

    Args:
        image_base64: Base64 encoding of the image payload.
        filename: Optional target file name; a unique name is generated when omitted.

    Returns:
        Dict[str, Any]: Registration result with the fields:
            - success: bool, whether the emoji was registered
            - message: str, human-readable outcome
            - description: Optional[str], emoji description (on success)
            - emotions: Optional[List[str]], emotion tags (on success)
            - replaced: Optional[bool], whether an old emoji was replaced (on success)
            - hash: Optional[str], emoji hash (on success)

    Raises:
        ValueError: If ``image_base64`` is empty.
        TypeError: If an argument has the wrong type.
    """
    if not image_base64:
        raise ValueError("图片base64编码不能为空")
    if not isinstance(image_base64, str):
        raise TypeError("image_base64必须是字符串类型")
    if filename is not None and not isinstance(filename, str):
        raise TypeError("filename必须是字符串类型或None")
    try:
        logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
        # 1. Capacity check against the manager's configured maximum.
        emoji_manager = get_emoji_manager()
        count_before = emoji_manager.emoji_num
        max_count = emoji_manager.emoji_num_max
        # 2. Registration is allowed below the cap, or at the cap when
        #    replacement-on-full is enabled.
        can_register = count_before < max_count or (
            count_before >= max_count and emoji_manager.emoji_num_max_reach_deletion
        )
        if not can_register:
            return {
                "success": False,
                "message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
                "description": None,
                "emotions": None,
                "replaced": None,
                "hash": None,
            }
        # 3. Make sure the emoji directory exists.
        os.makedirs(EMOJI_DIR, exist_ok=True)
        # 4. Generate a unique file name when the caller did not supply one.
        if not filename:
            import time

            timestamp = int(time.time())
            microseconds = int(time.time() * 1000000) % 1000000  # microsecond precision
            # 12-char random id: 72 random bits -> 9 bytes -> base64.
            # (reuses the module-level `random` import; the previous local
            # re-imports of `random` were redundant and have been removed)
            random_bytes = random.getrandbits(72).to_bytes(9, "big")
            short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
            # Swap characters that are unsafe in file names.
            short_id = short_id.replace("/", "_").replace("+", "-")
            filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
        # Guarantee an image extension; default to png.
        # BUGFIX: was `f"(unknown).png"` — an f-string with no placeholder that
        # discarded the generated/user name and collapsed every extension-less
        # file onto the same literal path.
        if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
            filename = f"{filename}.png"
        # Resolve name collisions by regenerating the trailing id a few times.
        temp_file_path = os.path.join(EMOJI_DIR, filename)
        attempts = 0
        max_attempts = 10
        while os.path.exists(temp_file_path) and attempts < max_attempts:
            # Regenerate a shorter (8-char) random id.
            random_bytes = random.getrandbits(48).to_bytes(6, "big")
            short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
            short_id = short_id.replace("/", "_").replace("+", "-")
            # Replace the old trailing id segment with the new one.
            name_part, ext = os.path.splitext(filename)
            base_name = name_part.rsplit("_", 1)[0]  # drop everything after the last underscore
            filename = f"{base_name}_{short_id}{ext}"
            temp_file_path = os.path.join(EMOJI_DIR, filename)
            attempts += 1
        # Fall back to a UUID fragment, then a numeric suffix, if still colliding.
        if os.path.exists(temp_file_path):
            uuid_short = str(uuid.uuid4())[:8]
            name_part, ext = os.path.splitext(filename)
            base_name = name_part.rsplit("_", 1)[0]
            filename = f"{base_name}_{uuid_short}{ext}"
            temp_file_path = os.path.join(EMOJI_DIR, filename)
            # If even the UUID name collides, append an incrementing counter.
            counter = 1
            original_filename = filename
            while os.path.exists(temp_file_path):
                name_part, ext = os.path.splitext(original_filename)
                filename = f"{name_part}_{counter}{ext}"
                temp_file_path = os.path.join(EMOJI_DIR, filename)
                counter += 1
                # Hard stop to avoid an unbounded loop.
                if counter > 100:
                    logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}")
                    return {
                        "success": False,
                        "message": "无法生成唯一文件名,请稍后重试",
                        "description": None,
                        "emotions": None,
                        "replaced": None,
                        "hash": None,
                    }
        # 5. Persist the decoded base64 image into the emoji directory.
        try:
            if not base64_to_image(image_base64, temp_file_path):
                logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
                return {
                    "success": False,
                    "message": "无法保存图片文件",
                    "description": None,
                    "emotions": None,
                    "replaced": None,
                    "hash": None,
                }
            logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}")
        except Exception as save_error:
            logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}")
            return {
                "success": False,
                "message": f"保存图片文件失败: {str(save_error)}",
                "description": None,
                "emotions": None,
                "replaced": None,
                "hash": None,
            }
        # 6. Hand the saved file to the manager for actual registration.
        register_success = await emoji_manager.register_emoji_by_filename(filename)
        # 7. Clean up the temp file when registration failed but the file remains.
        if not register_success and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
                logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}")
            except Exception as cleanup_error:
                logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}")
        # 8. Build the result payload.
        if register_success:
            count_after = emoji_manager.emoji_num
            replaced = count_after <= count_before  # no net growth means an old emoji was replaced
            # Try to locate the freshly registered emoji for richer result data.
            new_emoji_info = None
            if count_after > count_before or replaced:
                try:
                    # Search newest-first; the stored file name may have changed
                    # during registration, so also match by path containment.
                    for emoji_obj in reversed(emoji_manager.emoji_objects):
                        if not emoji_obj.is_deleted and (
                            emoji_obj.filename == filename  # direct match
                            or (hasattr(emoji_obj, "full_path") and filename in emoji_obj.full_path)  # path match
                        ):
                            new_emoji_info = emoji_obj
                            break
                except Exception as find_error:
                    logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}")
            description = new_emoji_info.description if new_emoji_info else None
            emotions = new_emoji_info.emotion if new_emoji_info else None
            emoji_hash = new_emoji_info.hash if new_emoji_info else None
            return {
                "success": True,
                "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
                "description": description,
                "emotions": emotions,
                "replaced": replaced,
                "hash": emoji_hash,
            }
        else:
            return {
                "success": False,
                "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
                "description": None,
                "emotions": None,
                "replaced": None,
                "hash": None,
            }
    except Exception as e:
        logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}")
        return {
            "success": False,
            "message": f"注册过程中发生错误: {str(e)}",
            "description": None,
            "emotions": None,
            "replaced": None,
            "hash": None,
        }
# =============================================================================
# 表情包删除API函数
# =============================================================================
async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
    """Delete one emoji identified by its hash.

    Args:
        emoji_hash: Hash value of the emoji to remove.

    Returns:
        Dict[str, Any]: Outcome with ``success``, ``message``, ``count_before``,
        ``count_after``, ``description`` and ``emotions`` fields.

    Raises:
        ValueError: If the hash is empty.
        TypeError: If the hash is not a string.
    """
    if not emoji_hash:
        raise ValueError("表情包哈希值不能为空")
    if not isinstance(emoji_hash, str):
        raise TypeError("emoji_hash必须是字符串类型")
    try:
        logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
        manager = get_emoji_manager()
        count_before = manager.emoji_num
        # Snapshot the record's metadata before it disappears.
        description = None
        emotions = None
        try:
            target = await manager.get_emoji_from_manager(emoji_hash)
            if target:
                description = target.description
                emotions = target.emotion
        except Exception as info_error:
            logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}")
        deleted = await manager.delete_emoji(emoji_hash)
        count_after = manager.emoji_num
        if not deleted:
            return {
                "success": False,
                "message": "表情包删除失败,可能因为哈希值不存在或删除过程出错",
                "count_before": count_before,
                "count_after": count_after,
                "description": None,
                "emotions": None,
            }
        return {
            "success": True,
            "message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)",
            "count_before": count_before,
            "count_after": count_after,
            "description": description,
            "emotions": emotions,
        }
    except Exception as e:
        logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}")
        return {
            "success": False,
            "message": f"删除过程中发生错误: {str(e)}",
            "count_before": None,
            "count_after": None,
            "description": None,
            "emotions": None,
        }
async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]:
    """Delete all emojis whose description matches the given text.

    Args:
        description: Description text to match against.
        exact_match: True for exact comparison; False for case-insensitive
            substring matching.

    Returns:
        Dict[str, Any]: Outcome with ``success``, ``message``, ``deleted_count``,
        ``deleted_hashes`` and ``matched_count`` fields.

    Raises:
        ValueError: If the description is empty.
        TypeError: If the description is not a string.
    """
    if not description:
        raise ValueError("描述不能为空")
    if not isinstance(description, str):
        raise TypeError("description必须是字符串类型")
    try:
        logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
        manager = get_emoji_manager()
        # Collect live (non-deleted) emojis whose description matches.
        if exact_match:
            matches = [e for e in manager.emoji_objects if not e.is_deleted and e.description == description]
        else:
            needle = description.lower()
            matches = [e for e in manager.emoji_objects if not e.is_deleted and needle in e.description.lower()]
        matched_count = len(matches)
        if matched_count == 0:
            return {
                "success": False,
                "message": f"未找到匹配描述 '{description}' 的表情包",
                "deleted_count": 0,
                "deleted_hashes": [],
                "matched_count": 0,
            }
        # Delete every match; failures are logged and skipped.
        deleted_hashes: List[str] = []
        for candidate in matches:
            try:
                if await manager.delete_emoji(candidate.hash):
                    deleted_hashes.append(candidate.hash)
            except Exception as delete_error:
                logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {candidate.hash}): {delete_error}")
        deleted_count = len(deleted_hashes)
        if deleted_count == 0:
            return {
                "success": False,
                "message": f"匹配到 {matched_count} 个表情包,但删除全部失败",
                "deleted_count": 0,
                "deleted_hashes": [],
                "matched_count": matched_count,
            }
        return {
            "success": True,
            "message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)",
            "deleted_count": deleted_count,
            "deleted_hashes": deleted_hashes,
            "matched_count": matched_count,
        }
    except Exception as e:
        logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}")
        return {
            "success": False,
            "message": f"删除过程中发生错误: {str(e)}",
            "deleted_count": 0,
            "deleted_hashes": [],
            "matched_count": 0,
        }

View File

@ -3,13 +3,14 @@ from src.chat.frequency_control.frequency_control import frequency_control_manag
logger = get_logger("frequency_api")
def get_current_talk_frequency(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(
chat_id
).set_talk_frequency_adjust(talk_frequency_adjust)
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()

View File

@ -90,6 +90,7 @@ async def generate_reply(
enable_chinese_typo: bool = True,
request_type: str = "generator_api",
from_plugin: bool = True,
reply_time_point: Optional[float] = None,
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""生成回复
@ -109,6 +110,7 @@ async def generate_reply(
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型可选记录LLM使用
from_plugin: 是否来自插件
reply_time_point: 回复时间点
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
@ -136,6 +138,7 @@ async def generate_reply(
reply_reason=reply_reason,
from_plugin=from_plugin,
stream_id=chat_stream.stream_id if chat_stream else chat_id,
reply_time_point=reply_time_point,
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")

Some files were not shown because too many files have changed in this diff Show More