mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'MaiM-with-u:dev' into dev
commit
d06f4005a6
|
|
@ -1,3 +1,2 @@
|
|||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
MaiLauncher.bat text eol=crlf working-tree-encoding=GBK
|
||||
*.cmd text eol=crlf
|
||||
|
|
@ -20,6 +20,8 @@ MaiBot-Napcat-Adapter
|
|||
nonebot-maibot-adapter/
|
||||
MaiMBot-LPMM
|
||||
*.zip
|
||||
run_bot.bat
|
||||
run_na.bat
|
||||
run.bat
|
||||
log_debug/
|
||||
run_amds.bat
|
||||
|
|
@ -41,16 +43,13 @@ config/bot_config.toml
|
|||
config/bot_config.toml.bak
|
||||
config/lpmm_config.toml
|
||||
config/lpmm_config.toml.bak
|
||||
src/mais4u/config/s4u_config.toml
|
||||
src/mais4u/config/old
|
||||
template/compare/bot_config_template.toml
|
||||
template/compare/model_config_template.toml
|
||||
(测试版)麦麦生成人格.bat
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
CLAUDE.md
|
||||
s4u.s4u
|
||||
s4u.s4u1
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
|
@ -321,9 +320,14 @@ run_pet.bat
|
|||
/plugins/*
|
||||
!/plugins
|
||||
!/plugins/hello_world_plugin
|
||||
!/plugins/emoji_manage_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/deep_think
|
||||
!/plugins/ChatFrequency/
|
||||
!/plugins/__init__.py
|
||||
|
||||
config.toml
|
||||
|
||||
interested_rates.txt
|
||||
MaiBot.code-workspace
|
||||
*.lock
|
||||
21
bot.py
21
bot.py
|
|
@ -5,16 +5,29 @@ import sys
|
|||
import time
|
||||
import platform
|
||||
import traceback
|
||||
import shutil
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
from rich.traceback import install
|
||||
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
env_path = Path(__file__).parent / ".env"
|
||||
template_env_path = Path(__file__).parent / "template" / "template.env"
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(str(env_path), override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
try:
|
||||
if template_env_path.exists():
|
||||
shutil.copyfile(template_env_path, env_path)
|
||||
print("未找到.env,已从 template/template.env 自动创建")
|
||||
load_dotenv(str(env_path), override=True)
|
||||
else:
|
||||
print("未找到.env文件,也未找到模板 template/template.env")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
except Exception as e:
|
||||
print(f"自动创建 .env 失败: {e}")
|
||||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||
|
|
|
|||
|
|
@ -1,14 +1,27 @@
|
|||
# Changelog
|
||||
|
||||
0.10.4饼 表达方式优化
|
||||
无了
|
||||
## [0.11.0] - 2025-9-22
|
||||
### 🌟 主要功能更改
|
||||
- 重构记忆系统,新的记忆系统更可靠,记忆能力更强大
|
||||
- 麦麦好奇功能,麦麦会自主提出问题
|
||||
- 添加deepthink插件(默认关闭),让麦麦可以深度思考一些问题
|
||||
- 添加表情包管理插件
|
||||
|
||||
## [0.10.3] - 2025-9-1x
|
||||
### 细节功能更改
|
||||
- 修复配置文件转义问题
|
||||
- 情绪系统现在可以由配置文件控制开关
|
||||
- 修复平行动作控制失效的问题
|
||||
- 添加planner防抖,防止短时间快速消耗token
|
||||
- 修复吞字问题
|
||||
- 更新依赖表
|
||||
- 修复负载均衡
|
||||
- 优化了对gemini和不同模型的支持
|
||||
|
||||
## [0.10.3] - 2025-9-22
|
||||
### 🌟 主要功能更改
|
||||
- planner支持多动作,移除Sub_planner
|
||||
- 移除激活度系统,现在回复完全由planner控制
|
||||
- 现可自定义planner行为
|
||||
- 更丰富的聊天行为
|
||||
- 现可自定义planner行为,更优化的聊天频率控制
|
||||
- 支持发送转发和合并转发
|
||||
- 关系现在支持多人的信息
|
||||
- 更好的event系统,正式建立
|
||||
|
|
@ -20,6 +33,8 @@
|
|||
- 优化识图token限制
|
||||
- 为空回复添加重试机制
|
||||
- 加入brainchat模式,为私聊支持做准备
|
||||
- 修复qq号格式
|
||||
|
||||
|
||||
|
||||
## [0.10.2] - 2025-8-31
|
||||
|
|
|
|||
|
|
@ -1,51 +0,0 @@
|
|||
# Changelog
|
||||
|
||||
## [1.0.3] - 2025-3-31
|
||||
### Added
|
||||
- 新增了心流相关配置项:
|
||||
- `heartflow` 配置项,用于控制心流功能
|
||||
|
||||
### Removed
|
||||
- 移除了 `response` 配置项中的 `model_r1_probability` 和 `model_v3_probability` 选项
|
||||
- 移除了次级推理模型相关配置
|
||||
|
||||
## [1.0.1] - 2025-3-30
|
||||
### Added
|
||||
- 增加了流式输出控制项 `stream`
|
||||
- 修复 `LLM_Request` 不会自动为 `payload` 增加流式输出标志的问题
|
||||
|
||||
## [1.0.0] - 2025-3-30
|
||||
### Added
|
||||
- 修复了错误的版本命名
|
||||
- 杀掉了所有无关文件
|
||||
|
||||
## [0.0.11] - 2025-3-12
|
||||
### Added
|
||||
- 新增了 `schedule` 配置项,用于配置日程表生成功能
|
||||
- 新增了 `response_splitter` 配置项,用于控制回复分割
|
||||
- 新增了 `experimental` 配置项,用于实验性功能开关
|
||||
- 新增了 `llm_observation` 和 `llm_sub_heartflow` 模型配置
|
||||
- 新增了 `llm_heartflow` 模型配置
|
||||
- 在 `personality` 配置项中新增了 `prompt_schedule_gen` 参数
|
||||
|
||||
### Changed
|
||||
- 优化了模型配置的组织结构
|
||||
- 调整了部分配置项的默认值
|
||||
- 调整了配置项的顺序,将 `groups` 配置项移到了更靠前的位置
|
||||
- 在 `message` 配置项中:
|
||||
- 新增了 `model_max_output_length` 参数
|
||||
- 在 `willing` 配置项中新增了 `emoji_response_penalty` 参数
|
||||
- 将 `personality` 配置项中的 `prompt_schedule` 重命名为 `prompt_schedule_gen`
|
||||
|
||||
### Removed
|
||||
- 移除了 `min_text_length` 配置项
|
||||
- 移除了 `cq_code` 配置项
|
||||
- 移除了 `others` 配置项(其功能已整合到 `experimental` 中)
|
||||
|
||||
## [0.0.5] - 2025-3-11
|
||||
### Added
|
||||
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
|
||||
|
||||
## [0.0.4] - 2025-3-9
|
||||
### Added
|
||||
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。
|
||||
|
|
@ -28,7 +28,7 @@ version = "1.1.1"
|
|||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek" # 服务商名称(自定义)
|
||||
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
|
||||
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
||||
api_key = "your-api-key-here" # API密钥
|
||||
client_type = "openai" # 客户端类型
|
||||
max_retry = 2 # 最大重试次数
|
||||
|
|
@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒)
|
|||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||
| `base_url` | ✅ | API服务的基础URL | - |
|
||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` |
|
||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` |
|
||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
|
||||
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
|
||||
### 2.3 支持的服务商示例
|
||||
|
||||
#### DeepSeek
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "DeepSeek"
|
||||
base_url = "https://api.deepseek.cn/v1"
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
api_key = "your-deepseek-api-key"
|
||||
client_type = "openai"
|
||||
```
|
||||
|
|
@ -73,7 +73,7 @@ client_type = "openai"
|
|||
```toml
|
||||
[[api_providers]]
|
||||
name = "Google"
|
||||
base_url = "https://api.google.com/v1"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||
api_key = "your-google-api-key"
|
||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||
```
|
||||
|
|
@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考
|
|||
[models.extra_params]
|
||||
thinking = {type = "disabled"} # 禁用思考
|
||||
```
|
||||
|
||||
而对于`gemini`需要单独进行配置
|
||||
```toml
|
||||
[[models]]
|
||||
model_identifier = "gemini-2.5-flash"
|
||||
name = "gemini-2.5-flash"
|
||||
api_provider = "Google"
|
||||
[models.extra_params]
|
||||
thinking_budget = 0 # 禁用思考
|
||||
# thinking_budget = -1 由模型自己决定
|
||||
```
|
||||
|
||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||
|
||||
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
|
||||
### 3.3 配置参数说明
|
||||
|
||||
| 参数 | 必填 | 说明 |
|
||||
|
|
|
|||
57
flake.lock
57
flake.lock
|
|
@ -1,57 +0,0 @@
|
|||
{
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 0,
|
||||
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
|
||||
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
|
||||
"type": "path"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs",
|
||||
"utils": "utils"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
39
flake.nix
39
flake.nix
|
|
@ -1,39 +0,0 @@
|
|||
{
|
||||
description = "MaiMBot Nix Dev Env";
|
||||
|
||||
inputs = {
|
||||
utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = {
|
||||
self,
|
||||
nixpkgs,
|
||||
utils,
|
||||
...
|
||||
}:
|
||||
utils.lib.eachDefaultSystem (system: let
|
||||
pkgs = import nixpkgs {inherit system;};
|
||||
pythonPackages = pkgs.python3Packages;
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
name = "python-venv";
|
||||
venvDir = "./.venv";
|
||||
buildInputs = with pythonPackages; [
|
||||
python
|
||||
venvShellHook
|
||||
scipy
|
||||
numpy
|
||||
];
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
pip install -r requirements.txt
|
||||
'';
|
||||
|
||||
postShellHook = ''
|
||||
# allow pip to install wheels
|
||||
unset SOURCE_DATE_EPOCH
|
||||
'';
|
||||
};
|
||||
});
|
||||
}
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
# BetterFrequency 频率控制插件
|
||||
|
||||
这是一个用于控制MaiBot聊天频率的插件,支持实时调整talk_frequency参数。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 💬 **Talk Frequency控制**: 调整机器人的发言频率
|
||||
- 📊 **状态显示**: 实时查看当前频率控制状态
|
||||
- ⚡ **实时生效**: 设置后立即生效,无需重启
|
||||
- 💾 **不保存消息**: 命令执行反馈不会保存到数据库
|
||||
- 🚀 **简化命令**: 支持完整命令和简化命令两种形式
|
||||
|
||||
## 命令列表
|
||||
|
||||
### 1. 设置Talk Frequency
|
||||
```
|
||||
/chat talk_frequency <数字> # 完整命令
|
||||
/chat t <数字> # 简化命令
|
||||
```
|
||||
- 功能:设置当前聊天的talk_frequency调整值
|
||||
- 参数:支持0到1之间的数值
|
||||
- 示例:
|
||||
- `/chat talk_frequency 1.0` 或 `/chat t 1.0` - 设置发言频率调整为1.0(最高频率)
|
||||
- `/chat talk_frequency 0.5` 或 `/chat t 0.5` - 设置发言频率调整为0.5
|
||||
- `/chat talk_frequency 0.0` 或 `/chat t 0.0` - 设置发言频率调整为0.0(最低频率)
|
||||
|
||||
### 2. 显示当前状态
|
||||
```
|
||||
/chat show # 完整命令
|
||||
/chat s # 简化命令
|
||||
```
|
||||
- 功能:显示当前聊天的频率控制状态
|
||||
- 显示内容:
|
||||
- 当前talk_frequency值
|
||||
- 可用命令提示(包含简化命令)
|
||||
|
||||
## 配置说明
|
||||
|
||||
插件配置文件 `config.toml` 包含以下选项:
|
||||
|
||||
```toml
|
||||
[plugin]
|
||||
name = "better_frequency_plugin"
|
||||
version = "1.0.0"
|
||||
enabled = true
|
||||
|
||||
[frequency]
|
||||
default_talk_adjust = 1.0 # 默认talk_frequency调整值
|
||||
max_adjust_value = 1.0 # 最大调整值
|
||||
min_adjust_value = 0.0 # 最小调整值
|
||||
```
|
||||
|
||||
## 使用场景
|
||||
|
||||
- **提高机器人活跃度**: 设置较高的talk_frequency值(接近1.0)
|
||||
- **降低机器人活跃度**: 设置较低的talk_frequency值(接近0.0)
|
||||
- **精细调节**: 使用小数进行微调
|
||||
- **实时监控**: 通过show命令查看当前状态
|
||||
- **快速操作**: 使用简化命令提高操作效率
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 调整值会立即生效,影响当前聊天的机器人行为
|
||||
2. 命令执行反馈消息不会保存到数据库
|
||||
3. 支持0到1之间的数值
|
||||
4. 每个聊天都有独立的频率控制设置
|
||||
5. 简化命令和完整命令功能完全相同,可根据个人习惯选择
|
||||
|
||||
## 技术实现
|
||||
|
||||
- 基于MaiCore插件系统开发
|
||||
- 使用frequency_api进行频率控制操作
|
||||
- 使用send_api发送反馈消息
|
||||
- 支持异步操作和错误处理
|
||||
- 正则表达式支持多种命令格式
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"manifest_version": 1,
|
||||
"name": "发言频率控制插件|BetterFrequency Plugin",
|
||||
"version": "2.0.0",
|
||||
"description": "控制聊天频率,支持设置focus_value和talk_frequency调整值,提供完整命令和简化命令",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.3"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterFrequency",
|
||||
"keywords": ["frequency", "control", "talk_frequency", "plugin", "shortcut"],
|
||||
"categories": ["Chat", "Frequency", "Control"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "frequency",
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "set_talk_frequency",
|
||||
"description": "设置当前聊天的talk_frequency调整值",
|
||||
"pattern": "/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "frequency_adjust",
|
||||
"description": "调整当前聊天的发言频率",
|
||||
"pattern": "/chat frequency_adjust <数字> 或 /chat f <数字>"
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "show_frequency",
|
||||
"description": "显示当前聊天的频率控制状态",
|
||||
"pattern": "/chat show 或 /chat s"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"设置talk_frequency调整值",
|
||||
"调整当前聊天的发言频率",
|
||||
"显示当前频率控制状态",
|
||||
"实时频率控制调整",
|
||||
"命令执行反馈(不保存消息)",
|
||||
"支持完整命令和简化命令",
|
||||
"快速操作支持"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
from typing import Tuple
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块
|
||||
from src.plugin_system.apis import frequency_api, send_api, config_api, generator_api
|
||||
|
||||
logger = get_logger("frequency_adjust")
|
||||
|
||||
|
||||
class FrequencyAdjustAction(BaseAction):
|
||||
"""频率调节动作 - 调整聊天发言频率"""
|
||||
|
||||
activation_type = ActionActivationType.LLM_JUDGE
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "frequency_adjust"
|
||||
|
||||
action_description = "调整当前聊天的发言频率"
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"direction": "调整方向:'increase'(增加)或'decrease'(降低)",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
bot_name = config_api.get_global_config("bot.nickname")
|
||||
|
||||
|
||||
action_require = [
|
||||
f"当用户提到 {bot_name} 太安静或太活跃时使用",
|
||||
f"有人提到 {bot_name} 的发言太多或太少",
|
||||
f"需要根据聊天氛围调整 {bot_name} 的活跃度",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行频率调节动作"""
|
||||
try:
|
||||
# 1. 获取动作参数
|
||||
direction = self.action_data.get("direction")
|
||||
# multiply = 1.2
|
||||
# multiply = self.action_data.get("multiply")
|
||||
|
||||
if not direction:
|
||||
error_msg = "缺少必要的参数:direction或multiply"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# 2. 获取当前频率值
|
||||
current_frequency = frequency_api.get_current_talk_frequency(self.chat_id)
|
||||
|
||||
# 3. 计算新的频率值(使用比率而不是绝对值)
|
||||
# calculated_frequency = current_frequency * multiply
|
||||
if direction == "increase":
|
||||
calculated_frequency = current_frequency * 1.2
|
||||
if calculated_frequency > 1.0:
|
||||
new_frequency = 1.0
|
||||
action_desc = f"增加到最大值"
|
||||
# 记录超出限制的action
|
||||
logger.warning(f"{self.log_prefix} 尝试调整频率超出最大值: current={current_frequency:.2f}, calculated={calculated_frequency:.2f}")
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"你尝试调整发言频率到{calculated_frequency:.2f},但最大值只能为1.0,已设置为最大值",
|
||||
action_done=True,
|
||||
)
|
||||
return True, f"调整发言频率超出限制: {current_frequency:.2f} → {new_frequency:.2f}"
|
||||
else:
|
||||
new_frequency = calculated_frequency
|
||||
action_desc = f"增加"
|
||||
elif direction == "decrease":
|
||||
calculated_frequency = current_frequency * 0.8
|
||||
new_frequency = max(0.0, calculated_frequency)
|
||||
action_desc = f"降低"
|
||||
else:
|
||||
error_msg = f"无效的调整方向: {direction}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# 4. 设置新的频率值
|
||||
frequency_api.set_talk_frequency_adjust(self.chat_id, new_frequency)
|
||||
|
||||
# 5. 发送反馈消息
|
||||
feedback_msg = f"已{action_desc}发言频率:{current_frequency:.2f} → {new_frequency:.2f}"
|
||||
result_status, data = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_data={
|
||||
"raw_reply": feedback_msg,
|
||||
"reason": "表达自己已经调整了发言频率,不一定要说具体数值,可以有趣一些",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if result_status:
|
||||
for reply_seg in data.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await self.send_text(send_data)
|
||||
logger.info(f"{self.log_prefix} {send_data}")
|
||||
|
||||
# 6. 存储动作信息(仅在未超出限制时)
|
||||
if calculated_frequency <= 1.0:
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"你{action_desc}了发言频率,从{current_frequency:.2f}调整到{new_frequency:.2f}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
return True, f"成功调整发言频率: {current_frequency:.2f} → {new_frequency:.2f}"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"频率调节失败: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}", exc_info=True)
|
||||
await self.send_text("频率调节失败")
|
||||
return False, error_msg
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
from typing import List, Tuple, Type, Any, Optional
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseCommand,
|
||||
ComponentInfo,
|
||||
ConfigField
|
||||
)
|
||||
from src.plugin_system.apis import send_api, frequency_api
|
||||
from .frequency_adjust_action import FrequencyAdjustAction
|
||||
|
||||
|
||||
class SetTalkFrequencyCommand(BaseCommand):
|
||||
"""设置当前聊天的talk_frequency值"""
|
||||
command_name = "set_talk_frequency"
|
||||
command_description = "设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>"
|
||||
command_pattern = r"^/chat\s+(?:talk_frequency|t)\s+(?P<value>[+-]?\d*\.?\d+)$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
try:
|
||||
# 获取命令参数 - 使用命名捕获组
|
||||
if not self.matched_groups or "value" not in self.matched_groups:
|
||||
return False, "命令格式错误", False
|
||||
|
||||
value_str = self.matched_groups["value"]
|
||||
if not value_str:
|
||||
return False, "无法获取数值参数", False
|
||||
|
||||
value = float(value_str)
|
||||
|
||||
# 获取聊天流ID
|
||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
||||
return False, "无法获取聊天流信息", False
|
||||
|
||||
chat_id = self.message.chat_stream.stream_id
|
||||
|
||||
# 设置talk_frequency
|
||||
frequency_api.set_talk_frequency_adjust(chat_id, value)
|
||||
|
||||
# 发送反馈消息(不保存到数据库)
|
||||
await send_api.text_to_stream(
|
||||
f"已设置当前聊天的talk_frequency调整值为: {value}",
|
||||
chat_id,
|
||||
storage_message=False
|
||||
)
|
||||
|
||||
return True, None, False
|
||||
|
||||
except ValueError:
|
||||
error_msg = "数值格式错误,请输入有效的数字"
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
except Exception as e:
|
||||
error_msg = f"设置talk_frequency失败: {str(e)}"
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
|
||||
|
||||
class ShowFrequencyCommand(BaseCommand):
|
||||
"""显示当前聊天的频率控制状态"""
|
||||
command_name = "show_frequency"
|
||||
command_description = "显示当前聊天的频率控制状态:/chat show 或 /chat s"
|
||||
command_pattern = r"^/chat\s+(?:show|s)$"
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
try:
|
||||
# 获取聊天流ID
|
||||
if not self.message.chat_stream or not hasattr(self.message.chat_stream, "stream_id"):
|
||||
return False, "无法获取聊天流信息", False
|
||||
|
||||
chat_id = self.message.chat_stream.stream_id
|
||||
|
||||
# 获取当前频率控制状态
|
||||
current_talk_frequency = frequency_api.get_current_talk_frequency(chat_id)
|
||||
talk_frequency_adjust = frequency_api.get_talk_frequency_adjust(chat_id)
|
||||
|
||||
# 构建显示消息
|
||||
status_msg = f"""当前聊天频率控制状态
|
||||
Talk Frequency (发言频率):
|
||||
• 当前值: {current_talk_frequency:.2f}
|
||||
|
||||
使用命令:
|
||||
• /chat talk_frequency <数字> 或 /chat t <数字> - 设置发言频率调整
|
||||
• /chat show 或 /chat s - 显示当前状态"""
|
||||
|
||||
# 发送状态消息(不保存到数据库)
|
||||
await send_api.text_to_stream(status_msg, chat_id, storage_message=False)
|
||||
|
||||
return True, None, False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取频率控制状态失败: {str(e)}"
|
||||
# 使用内置的send_text方法发送错误消息
|
||||
await self.send_text(error_msg, storage_message=False)
|
||||
return False, error_msg, False
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@register_plugin
|
||||
class BetterFrequencyPlugin(BasePlugin):
|
||||
"""BetterFrequency插件 - 控制聊天频率的插件"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "better_frequency_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[str] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"frequency": "频率控制配置"
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
},
|
||||
"frequency": {
|
||||
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
|
||||
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
|
||||
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
return [
|
||||
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
|
||||
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
|
||||
(FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction),
|
||||
]
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"manifest_version": 1,
|
||||
"name": "Deep Think插件 (Deep Think Actions)",
|
||||
"version": "1.0.0",
|
||||
"description": "可以深度思考",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.11.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": ["deep", "think", "action", "built-in"],
|
||||
"categories": ["Deep Think"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "deep_think",
|
||||
"description": "发送深度思考"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
from typing import List, Tuple, Type, Any
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.base_tool import BaseTool, ToolParamType
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugins.built_in.relation.relation import BuildRelationAction
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
logger = get_logger("relation_actions")
|
||||
|
||||
|
||||
|
||||
class DeepThinkTool(BaseTool):
|
||||
"""获取用户信息"""
|
||||
|
||||
name = "deep_think"
|
||||
description = "深度思考,对某个知识,概念或逻辑问题进行全面且深入的思考,当面临复杂环境或重要问题时,使用此获得更好的解决方案。"
|
||||
parameters = [
|
||||
("question", ToolParamType.STRING, "需要思考的问题,越具体越好(从上下文中总结)", True, None),
|
||||
]
|
||||
|
||||
available_for_llm = True
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
question: str = function_args.get("question") # type: ignore
|
||||
|
||||
print(f"question: {question}")
|
||||
|
||||
prompt = f"""
|
||||
请你思考以下问题,以简洁的一段话回答:
|
||||
{question}
|
||||
"""
|
||||
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("replyer") # 使用字典访问方式
|
||||
|
||||
success, thinking_result, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="deep_think"
|
||||
)
|
||||
|
||||
logger.info(f"{question}: {thinking_result}")
|
||||
|
||||
thinking_result =f"思考结果:{thinking_result}\n**注意** 因为你进行了深度思考,最后的回复内容可以回复的长一些,更加详细一些,不用太简洁。\n"
|
||||
|
||||
return {"content": thinking_result}
|
||||
|
||||
|
||||
@register_plugin
|
||||
class DeepThinkPlugin(BasePlugin):
|
||||
"""关系动作插件
|
||||
|
||||
系统内置插件,提供基础的聊天交互功能:
|
||||
- Reply: 回复动作
|
||||
- NoReply: 不回复动作
|
||||
- Emoji: 表情动作
|
||||
|
||||
注意:插件基本信息优先从_manifest.json文件中读取
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "deep_think" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components.append((DeepThinkTool.get_tool_info(), DeepThinkTool))
|
||||
|
||||
return components
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
{
|
||||
"manifest_version": 1,
|
||||
"name": "BetterEmoji",
|
||||
"version": "1.0.0",
|
||||
"description": "更好的表情包管理插件",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/SengokuCola"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.4"
|
||||
},
|
||||
"homepage_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"repository_url": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"keywords": ["emoji", "manage", "plugin"],
|
||||
"categories": ["Examples", "Tutorial"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "emoji_manage",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "hello_greeting",
|
||||
"description": "向用户发送问候消息"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "bye_greeting",
|
||||
"description": "向用户发送告别消息",
|
||||
"activation_modes": ["keyword"],
|
||||
"keywords": ["再见", "bye", "88", "拜拜"]
|
||||
},
|
||||
{
|
||||
"type": "command",
|
||||
"name": "time",
|
||||
"description": "查询当前时间",
|
||||
"pattern": "/time"
|
||||
}
|
||||
],
|
||||
"features": [
|
||||
"问候和告别功能",
|
||||
"时间查询命令",
|
||||
"配置文件示例",
|
||||
"新手教程代码"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,399 @@
|
|||
from typing import List, Tuple, Type
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseCommand,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
ReplyContentType,
|
||||
emoji_api,
|
||||
)
|
||||
from maim_message import Seg
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("emoji_manage_plugin")
|
||||
|
||||
|
||||
class AddEmojiCommand(BaseCommand):
|
||||
command_name = "add_emoji"
|
||||
command_description = "添加表情包"
|
||||
command_pattern = r".*/emoji add.*"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
# 查找消息中的表情包
|
||||
# logger.info(f"查找消息中的表情包: {self.message.message_segment}")
|
||||
|
||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
||||
|
||||
if not emoji_base64_list:
|
||||
return False, "未在消息中找到表情包或图片", False
|
||||
|
||||
# 注册找到的表情包
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
||||
try:
|
||||
# 使用emoji_api注册表情包(让API自动生成唯一文件名)
|
||||
result = await emoji_api.register_emoji(emoji_base64)
|
||||
|
||||
if result["success"]:
|
||||
success_count += 1
|
||||
description = result.get("description", "未知描述")
|
||||
emotions = result.get("emotions", [])
|
||||
replaced = result.get("replaced", False)
|
||||
|
||||
result_msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
|
||||
if description:
|
||||
result_msg += f"\n描述: {description}"
|
||||
if emotions:
|
||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
|
||||
results.append(result_msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
error_msg = result.get("message", "注册失败")
|
||||
results.append(f"表情包 {i + 1} 注册失败: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
results.append(f"表情包 {i + 1} 注册时发生错误: {str(e)}")
|
||||
|
||||
# 构建返回消息
|
||||
total_count = success_count + fail_count
|
||||
summary_msg = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
||||
|
||||
# 如果有结果详情,添加到返回消息中
|
||||
details_msg = ""
|
||||
if results:
|
||||
details_msg = "\n" + "\n".join(results)
|
||||
final_msg = summary_msg + details_msg
|
||||
else:
|
||||
final_msg = summary_msg
|
||||
|
||||
# 使用表达器重写回复
|
||||
try:
|
||||
from src.plugin_system.apis import generator_api
|
||||
|
||||
# 构建重写数据
|
||||
rewrite_data = {
|
||||
"raw_reply": summary_msg,
|
||||
"reason": f"注册了表情包:{details_msg}\n",
|
||||
}
|
||||
|
||||
# 调用表达器重写
|
||||
result_status, data = await generator_api.rewrite_reply(
|
||||
chat_stream=self.message.chat_stream,
|
||||
reply_data=rewrite_data,
|
||||
)
|
||||
|
||||
if result_status:
|
||||
# 发送重写后的回复
|
||||
for reply_seg in data.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await self.send_text(send_data)
|
||||
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
else:
|
||||
# 如果重写失败,发送原始消息
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
# 如果表达器调用失败,发送原始消息
|
||||
logger.error(f"[add_emoji] 表达器重写失败: {e}")
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
||||
emoji_base64_list = []
|
||||
|
||||
# 处理单个Seg对象的情况
|
||||
if isinstance(message_segments, Seg):
|
||||
if message_segments.type == "emoji":
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
||||
return emoji_base64_list
|
||||
|
||||
# 处理Seg列表的情况
|
||||
for seg in message_segments:
|
||||
if seg.type == "emoji":
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
||||
return emoji_base64_list
|
||||
|
||||
|
||||
class ListEmojiCommand(BaseCommand):
|
||||
"""列表表情包Command - 响应/emoji list命令"""
|
||||
|
||||
command_name = "emoji_list"
|
||||
command_description = "列表表情包"
|
||||
|
||||
# === 命令设置(必须填写)===
|
||||
command_pattern = r"^/emoji list(\s+\d+)?$" # 匹配 "/emoji list" 或 "/emoji list 数量"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
"""执行列表表情包"""
|
||||
from src.plugin_system.apis import emoji_api
|
||||
import datetime
|
||||
|
||||
# 解析命令参数
|
||||
import re
|
||||
|
||||
match = re.match(r"^/emoji list(?:\s+(\d+))?$", self.message.raw_message)
|
||||
max_count = 10 # 默认显示10个
|
||||
if match and match.group(1):
|
||||
max_count = min(int(match.group(1)), 50) # 最多显示50个
|
||||
|
||||
# 获取当前时间
|
||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime(time_format)
|
||||
|
||||
# 获取表情包信息
|
||||
emoji_count = emoji_api.get_count()
|
||||
emoji_info = emoji_api.get_info()
|
||||
|
||||
# 构建返回消息
|
||||
message_lines = [
|
||||
f"📊 表情包统计信息 ({time_str})",
|
||||
f"• 总数: {emoji_count} / {emoji_info['max_count']}",
|
||||
f"• 可用: {emoji_info['available_emojis']}",
|
||||
]
|
||||
|
||||
if emoji_count == 0:
|
||||
message_lines.append("\n❌ 暂无表情包")
|
||||
final_message = "\n".join(message_lines)
|
||||
await self.send_text(final_message)
|
||||
return True, final_message, True
|
||||
|
||||
# 获取所有表情包
|
||||
all_emojis = await emoji_api.get_all()
|
||||
if not all_emojis:
|
||||
message_lines.append("\n❌ 无法获取表情包列表")
|
||||
final_message = "\n".join(message_lines)
|
||||
await self.send_text(final_message)
|
||||
return False, final_message, True
|
||||
|
||||
# 显示前N个表情包
|
||||
display_emojis = all_emojis[:max_count]
|
||||
message_lines.append(f"\n📋 显示前 {len(display_emojis)} 个表情包:")
|
||||
|
||||
for i, (_, description, emotion) in enumerate(display_emojis, 1):
|
||||
# 截断过长的描述
|
||||
short_desc = description[:50] + "..." if len(description) > 50 else description
|
||||
message_lines.append(f"{i}. {short_desc} [{emotion}]")
|
||||
|
||||
# 如果还有更多表情包,显示总数
|
||||
if len(all_emojis) > max_count:
|
||||
message_lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
|
||||
|
||||
final_message = "\n".join(message_lines)
|
||||
|
||||
# 直接发送文本消息
|
||||
await self.send_text(final_message)
|
||||
|
||||
return True, final_message, True
|
||||
|
||||
|
||||
class DeleteEmojiCommand(BaseCommand):
|
||||
command_name = "delete_emoji"
|
||||
command_description = "删除表情包"
|
||||
command_pattern = r".*/emoji delete.*"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
# 查找消息中的表情包图片
|
||||
logger.info(f"查找消息中的表情包用于删除: {self.message.message_segment}")
|
||||
|
||||
emoji_base64_list = self.find_and_return_emoji_in_message(self.message.message_segment)
|
||||
|
||||
if not emoji_base64_list:
|
||||
return False, "未在消息中找到表情包或图片", False
|
||||
|
||||
# 删除找到的表情包
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_base64 in enumerate(emoji_base64_list):
|
||||
try:
|
||||
# 计算图片的哈希值来查找对应的表情包
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(emoji_base64, str):
|
||||
emoji_base64_clean = emoji_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
else:
|
||||
emoji_base64_clean = str(emoji_base64)
|
||||
|
||||
# 计算哈希值
|
||||
image_bytes = base64.b64decode(emoji_base64_clean)
|
||||
emoji_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 使用emoji_api删除表情包
|
||||
result = await emoji_api.delete_emoji(emoji_hash)
|
||||
|
||||
if result["success"]:
|
||||
success_count += 1
|
||||
description = result.get("description", "未知描述")
|
||||
count_before = result.get("count_before", 0)
|
||||
count_after = result.get("count_after", 0)
|
||||
emotions = result.get("emotions", [])
|
||||
|
||||
result_msg = f"表情包 {i + 1} 删除成功"
|
||||
if description:
|
||||
result_msg += f"\n描述: {description}"
|
||||
if emotions:
|
||||
result_msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
result_msg += f"\n表情包数量: {count_before} → {count_after}"
|
||||
|
||||
results.append(result_msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
error_msg = result.get("message", "删除失败")
|
||||
results.append(f"表情包 {i + 1} 删除失败: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
fail_count += 1
|
||||
results.append(f"表情包 {i + 1} 删除时发生错误: {str(e)}")
|
||||
|
||||
# 构建返回消息
|
||||
total_count = success_count + fail_count
|
||||
summary_msg = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total_count} 个"
|
||||
|
||||
# 如果有结果详情,添加到返回消息中
|
||||
details_msg = ""
|
||||
if results:
|
||||
details_msg = "\n" + "\n".join(results)
|
||||
final_msg = summary_msg + details_msg
|
||||
else:
|
||||
final_msg = summary_msg
|
||||
|
||||
# 使用表达器重写回复
|
||||
try:
|
||||
from src.plugin_system.apis import generator_api
|
||||
|
||||
# 构建重写数据
|
||||
rewrite_data = {
|
||||
"raw_reply": summary_msg,
|
||||
"reason": f"删除了表情包:{details_msg}\n",
|
||||
}
|
||||
|
||||
# 调用表达器重写
|
||||
result_status, data = await generator_api.rewrite_reply(
|
||||
chat_stream=self.message.chat_stream,
|
||||
reply_data=rewrite_data,
|
||||
)
|
||||
|
||||
if result_status:
|
||||
# 发送重写后的回复
|
||||
for reply_seg in data.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await self.send_text(send_data)
|
||||
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
else:
|
||||
# 如果重写失败,发送原始消息
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
# 如果表达器调用失败,发送原始消息
|
||||
logger.error(f"[delete_emoji] 表达器重写失败: {e}")
|
||||
await self.send_text(final_msg)
|
||||
return success_count > 0, final_msg, success_count > 0
|
||||
|
||||
def find_and_return_emoji_in_message(self, message_segments) -> List[str]:
|
||||
emoji_base64_list = []
|
||||
|
||||
# 处理单个Seg对象的情况
|
||||
if isinstance(message_segments, Seg):
|
||||
if message_segments.type == "emoji":
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(message_segments.data)
|
||||
elif message_segments.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(message_segments.data))
|
||||
return emoji_base64_list
|
||||
|
||||
# 处理Seg列表的情况
|
||||
for seg in message_segments:
|
||||
if seg.type == "emoji":
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "image":
|
||||
# 假设图片数据是base64编码的
|
||||
emoji_base64_list.append(seg.data)
|
||||
elif seg.type == "seglist":
|
||||
# 递归处理嵌套的Seg列表
|
||||
emoji_base64_list.extend(self.find_and_return_emoji_in_message(seg.data))
|
||||
return emoji_base64_list
|
||||
|
||||
|
||||
class RandomEmojis(BaseCommand):
|
||||
command_name = "random_emojis"
|
||||
command_description = "发送多张随机表情包"
|
||||
command_pattern = r"^/random_emojis$"
|
||||
|
||||
async def execute(self):
|
||||
emojis = await emoji_api.get_random(5)
|
||||
if not emojis:
|
||||
return False, "未找到表情包", False
|
||||
emoji_base64_list = []
|
||||
for emoji in emojis:
|
||||
emoji_base64_list.append(emoji[0])
|
||||
return await self.forward_images(emoji_base64_list)
|
||||
|
||||
async def forward_images(self, images: List[str]):
|
||||
"""
|
||||
把多张图片用合并转发的方式发给用户
|
||||
"""
|
||||
success = await self.send_forward([("0", "神秘用户", [(ReplyContentType.IMAGE, img)]) for img in images])
|
||||
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@register_plugin
|
||||
class EmojiManagePlugin(BasePlugin):
|
||||
"""表情包管理插件 - 管理表情包"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "emoji_manage_plugin" # 内部标识符
|
||||
enable_plugin: bool = False
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "emoji": "表情包功能配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.0.1", description="配置文件版本"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
return [
|
||||
(RandomEmojis.get_command_info(), RandomEmojis),
|
||||
(AddEmojiCommand.get_command_info(), AddEmojiCommand),
|
||||
(ListEmojiCommand.get_command_info(), ListEmojiCommand),
|
||||
(DeleteEmojiCommand.get_command_info(), DeleteEmojiCommand),
|
||||
]
|
||||
|
|
@ -237,8 +237,7 @@ class HelloWorldPlugin(BasePlugin):
|
|||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="hello_world_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"config_version": ConfigField(type=str, default="1.0.0", description="配置文件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
},
|
||||
"greeting": {
|
||||
|
|
|
|||
|
|
@ -1,56 +1,37 @@
|
|||
[project]
|
||||
name = "MaiBot"
|
||||
version = "0.8.1"
|
||||
version = "0.11.0"
|
||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.14",
|
||||
"apscheduler>=3.11.0",
|
||||
"aiohttp-cors>=0.8.1",
|
||||
"colorama>=0.4.6",
|
||||
"cryptography>=45.0.5",
|
||||
"customtkinter>=5.2.2",
|
||||
"dotenv>=0.9.9",
|
||||
"faiss-cpu>=1.11.0",
|
||||
"fastapi>=0.116.0",
|
||||
"google-genai>=1.39.1",
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"jsonlines>=4.0.0",
|
||||
"maim-message>=0.3.8",
|
||||
"maim-message",
|
||||
"matplotlib>=3.10.3",
|
||||
"networkx>=3.4.2",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"psutil>=7.0.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pymongo>=4.13.2",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dateutil>=2.9.0.post0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-igraph>=0.11.9",
|
||||
"quick-algo>=0.1.3",
|
||||
"reportportal-client>=5.6.5",
|
||||
"requests>=2.32.4",
|
||||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
"scikit-learn>=1.7.0",
|
||||
"scipy>=1.15.3",
|
||||
"seaborn>=0.13.2",
|
||||
"setuptools>=80.9.0",
|
||||
"strawberry-graphql[fastapi]>=0.275.5",
|
||||
"structlog>=25.4.0",
|
||||
"toml>=0.10.2",
|
||||
"tomli>=2.2.1",
|
||||
"tomli-w>=1.2.0",
|
||||
"tomlkit>=0.13.3",
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,271 +0,0 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.txt -o requirements.lock
|
||||
aenum==3.1.16
|
||||
# via reportportal-client
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.14
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# reportportal-client
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.9.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
apscheduler==3.11.0
|
||||
# via -r requirements.txt
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonlines
|
||||
certifi==2025.7.9
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# reportportal-client
|
||||
# requests
|
||||
cffi==1.17.1
|
||||
# via cryptography
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
# via uvicorn
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# click
|
||||
# tqdm
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
cryptography==45.0.5
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
customtkinter==5.2.2
|
||||
# via -r requirements.txt
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
darkdetect==0.8.0
|
||||
# via customtkinter
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
dnspython==2.7.0
|
||||
# via pymongo
|
||||
dotenv==0.9.9
|
||||
# via -r requirements.txt
|
||||
faiss-cpu==1.11.0
|
||||
# via -r requirements.txt
|
||||
fastapi==0.116.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
# strawberry-graphql
|
||||
fonttools==4.58.5
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
graphql-core==3.2.6
|
||||
# via strawberry-graphql
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via openai
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
igraph==0.11.9
|
||||
# via python-igraph
|
||||
jieba==0.42.1
|
||||
# via -r requirements.txt
|
||||
jiter==0.10.0
|
||||
# via openai
|
||||
joblib==1.5.1
|
||||
# via scikit-learn
|
||||
json-repair==0.47.6
|
||||
# via -r requirements.txt
|
||||
jsonlines==4.0.0
|
||||
# via -r requirements.txt
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
maim-message==0.3.8
|
||||
# via -r requirements.txt
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.10.3
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
networkx==3.5
|
||||
# via -r requirements.txt
|
||||
numpy==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# contourpy
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# pandas
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# seaborn
|
||||
openai==1.95.0
|
||||
# via -r requirements.txt
|
||||
packaging==25.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# customtkinter
|
||||
# faiss-cpu
|
||||
# matplotlib
|
||||
# strawberry-graphql
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# seaborn
|
||||
peewee==3.18.2
|
||||
# via -r requirements.txt
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==7.0.0
|
||||
# via -r requirements.txt
|
||||
pyarrow==20.0.0
|
||||
# via -r requirements.txt
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# fastapi
|
||||
# maim-message
|
||||
# openai
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygments==2.19.2
|
||||
# via rich
|
||||
pymongo==4.13.2
|
||||
# via -r requirements.txt
|
||||
pyparsing==3.2.3
|
||||
# via matplotlib
|
||||
pypinyin==0.54.0
|
||||
# via -r requirements.txt
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# matplotlib
|
||||
# pandas
|
||||
# strawberry-graphql
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# dotenv
|
||||
python-igraph==0.11.9
|
||||
# via -r requirements.txt
|
||||
python-multipart==0.0.20
|
||||
# via strawberry-graphql
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
quick-algo==0.1.3
|
||||
# via -r requirements.txt
|
||||
reportportal-client==5.6.5
|
||||
# via -r requirements.txt
|
||||
requests==2.32.4
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# reportportal-client
|
||||
rich==14.0.0
|
||||
# via -r requirements.txt
|
||||
ruff==0.12.2
|
||||
# via -r requirements.txt
|
||||
scikit-learn==1.7.0
|
||||
# via -r requirements.txt
|
||||
scipy==1.16.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# scikit-learn
|
||||
seaborn==0.13.2
|
||||
# via -r requirements.txt
|
||||
setuptools==80.9.0
|
||||
# via -r requirements.txt
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# openai
|
||||
starlette==0.46.2
|
||||
# via fastapi
|
||||
strawberry-graphql==0.275.5
|
||||
# via -r requirements.txt
|
||||
structlog==25.4.0
|
||||
# via -r requirements.txt
|
||||
texttable==1.7.0
|
||||
# via igraph
|
||||
threadpoolctl==3.6.0
|
||||
# via scikit-learn
|
||||
toml==0.10.2
|
||||
# via -r requirements.txt
|
||||
tomli==2.2.1
|
||||
# via -r requirements.txt
|
||||
tomli-w==1.2.0
|
||||
# via -r requirements.txt
|
||||
tomlkit==0.13.3
|
||||
# via -r requirements.txt
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# openai
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# fastapi
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# strawberry-graphql
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# pandas
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via apscheduler
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# requests
|
||||
uvicorn==0.35.0
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# -r requirements.txt
|
||||
# maim-message
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
|
|
@ -1,49 +1,28 @@
|
|||
APScheduler
|
||||
Pillow
|
||||
aiohttp
|
||||
aiohttp-cors
|
||||
colorama
|
||||
customtkinter
|
||||
dotenv
|
||||
faiss-cpu
|
||||
fastapi
|
||||
jieba
|
||||
jsonlines
|
||||
maim_message
|
||||
quick_algo
|
||||
matplotlib
|
||||
networkx
|
||||
numpy
|
||||
openai
|
||||
pandas
|
||||
peewee
|
||||
pyarrow
|
||||
pydantic
|
||||
pypinyin
|
||||
python-dateutil
|
||||
python-dotenv
|
||||
python-igraph
|
||||
pymongo
|
||||
requests
|
||||
ruff
|
||||
scipy
|
||||
setuptools
|
||||
toml
|
||||
tomli
|
||||
tomli_w
|
||||
tomlkit
|
||||
tqdm
|
||||
urllib3
|
||||
uvicorn
|
||||
websockets
|
||||
strawberry-graphql[fastapi]
|
||||
packaging
|
||||
rich
|
||||
psutil
|
||||
cryptography
|
||||
json-repair
|
||||
reportportal-client
|
||||
scikit-learn
|
||||
seaborn
|
||||
structlog
|
||||
google.genai
|
||||
aiohttp>=3.12.14
|
||||
aiohttp-cors>=0.8.1
|
||||
colorama>=0.4.6
|
||||
faiss-cpu>=1.11.0
|
||||
fastapi>=0.116.0
|
||||
google-genai>=1.39.1
|
||||
jieba>=0.42.1
|
||||
json-repair>=0.47.6
|
||||
maim-message
|
||||
matplotlib>=3.10.3
|
||||
numpy>=2.2.6
|
||||
openai>=1.95.0
|
||||
pandas>=2.3.1
|
||||
peewee>=3.18.2
|
||||
pillow>=11.3.0
|
||||
pyarrow>=20.0.0
|
||||
pydantic>=2.11.7
|
||||
pypinyin>=0.54.0
|
||||
python-dotenv>=1.1.1
|
||||
quick-algo>=0.1.3
|
||||
rich>=14.0.0
|
||||
ruff>=0.12.2
|
||||
setuptools>=80.9.0
|
||||
structlog>=25.4.0
|
||||
toml>=0.10.2
|
||||
tomlkit>=0.13.3
|
||||
urllib3>=2.5.0
|
||||
uvicorn>=0.35.0
|
||||
|
|
@ -0,0 +1,389 @@
|
|||
import argparse
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
def clean_output_text(text: str) -> str:
|
||||
"""
|
||||
清理输出文本,移除表情包和回复内容
|
||||
- 移除 [表情包:...] 格式的内容
|
||||
- 移除 [回复...] 格式的内容
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# 移除表情包内容:[表情包:...]
|
||||
text = re.sub(r'\[表情包:[^\]]*\]', '', text)
|
||||
|
||||
# 移除回复内容:[回复...],说:... 的完整模式
|
||||
text = re.sub(r'\[回复[^\]]*\],说:[^@]*@[^:]*:', '', text)
|
||||
|
||||
# 清理多余的空格和换行
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err: Optional[Exception] = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def fetch_messages_between(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""使用 find_messages 获取指定区间的消息,可选按 chat_info_platform 过滤。按时间升序返回。"""
|
||||
filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
|
||||
if platform:
|
||||
filter_query["chat_info_platform"] = platform
|
||||
# 当 limit==0 时,sort 生效,这里按时间升序
|
||||
return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
|
||||
|
||||
|
||||
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
|
||||
groups: Dict[str, List[DatabaseMessages]] = {}
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
|
||||
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
|
||||
"""
|
||||
将相邻、同一 user_id 且 5 分钟内的消息 bucket 合并为一条。
|
||||
processed_plain_text 合并(以换行连接),其余字段取最新一条(时间最大)。
|
||||
"""
|
||||
if not bucket:
|
||||
raise ValueError("bucket 为空,无法合并")
|
||||
|
||||
latest = bucket[-1]
|
||||
merged_texts: List[str] = []
|
||||
for m in bucket:
|
||||
text = m.processed_plain_text or ""
|
||||
if text:
|
||||
merged_texts.append(text)
|
||||
|
||||
merged = DatabaseMessages(
|
||||
# 其他信息采用最新消息
|
||||
message_id=latest.message_id,
|
||||
time=latest.time,
|
||||
chat_id=latest.chat_id,
|
||||
reply_to=latest.reply_to,
|
||||
interest_value=latest.interest_value,
|
||||
key_words=latest.key_words,
|
||||
key_words_lite=latest.key_words_lite,
|
||||
is_mentioned=latest.is_mentioned,
|
||||
is_at=latest.is_at,
|
||||
reply_probability_boost=latest.reply_probability_boost,
|
||||
processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
|
||||
display_message=latest.display_message,
|
||||
priority_mode=latest.priority_mode,
|
||||
priority_info=latest.priority_info,
|
||||
additional_config=latest.additional_config,
|
||||
is_emoji=latest.is_emoji,
|
||||
is_picid=latest.is_picid,
|
||||
is_command=latest.is_command,
|
||||
is_notify=latest.is_notify,
|
||||
selected_expressions=latest.selected_expressions,
|
||||
user_id=latest.user_info.user_id,
|
||||
user_nickname=latest.user_info.user_nickname,
|
||||
user_cardname=latest.user_info.user_cardname,
|
||||
user_platform=latest.user_info.platform,
|
||||
chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
|
||||
chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
|
||||
chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
|
||||
chat_info_user_id=latest.chat_info.user_info.user_id,
|
||||
chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
|
||||
chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
|
||||
chat_info_user_platform=latest.chat_info.user_info.platform,
|
||||
chat_info_stream_id=latest.chat_info.stream_id,
|
||||
chat_info_platform=latest.chat_info.platform,
|
||||
chat_info_create_time=latest.chat_info.create_time,
|
||||
chat_info_last_active_time=latest.chat_info.last_active_time,
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||
"""按 5 分钟窗口合并相邻同 user_id 的消息。输入需按时间升序。"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: List[DatabaseMessages] = []
|
||||
bucket: List[DatabaseMessages] = []
|
||||
|
||||
def flush_bucket() -> None:
|
||||
nonlocal bucket
|
||||
if bucket:
|
||||
merged.append(_merge_bucket_to_message(bucket))
|
||||
bucket = []
|
||||
|
||||
for msg in messages:
|
||||
if not bucket:
|
||||
bucket = [msg]
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
else:
|
||||
flush_bucket()
|
||||
bucket = [msg]
|
||||
|
||||
flush_bucket()
|
||||
return merged
|
||||
|
||||
|
||||
def build_pairs_for_chat(
|
||||
original_messages: List[DatabaseMessages],
|
||||
merged_messages: List[DatabaseMessages],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
target_user_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
对每条合并后的消息作为 output,从其前面取 20-30 条(可配置)的原始消息作为 input。
|
||||
input 使用原始未合并的消息构建上下文。
|
||||
output 使用合并后消息的 processed_plain_text。
|
||||
如果指定了 target_user_id,则只处理该用户的消息作为 output。
|
||||
"""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n_merged = len(merged_messages)
|
||||
n_original = len(original_messages)
|
||||
|
||||
if n_merged == 0 or n_original == 0:
|
||||
return pairs
|
||||
|
||||
# 为每个合并后的消息找到对应的原始消息位置
|
||||
merged_to_original_map = {}
|
||||
original_idx = 0
|
||||
|
||||
for merged_idx, merged_msg in enumerate(merged_messages):
|
||||
# 找到这个合并消息对应的第一个原始消息
|
||||
while (original_idx < n_original and
|
||||
original_messages[original_idx].time < merged_msg.time):
|
||||
original_idx += 1
|
||||
|
||||
# 如果找到了时间匹配的原始消息,建立映射
|
||||
if (original_idx < n_original and
|
||||
original_messages[original_idx].time == merged_msg.time):
|
||||
merged_to_original_map[merged_idx] = original_idx
|
||||
|
||||
for merged_idx in range(n_merged):
|
||||
merged_msg = merged_messages[merged_idx]
|
||||
|
||||
# 如果指定了 target_user_id,只处理该用户的消息作为 output
|
||||
if target_user_id and merged_msg.user_info.user_id != target_user_id:
|
||||
continue
|
||||
|
||||
# 找到对应的原始消息位置
|
||||
if merged_idx not in merged_to_original_map:
|
||||
continue
|
||||
|
||||
original_idx = merged_to_original_map[merged_idx]
|
||||
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, original_idx - window)
|
||||
context_msgs = original_messages[start:original_idx]
|
||||
|
||||
# 使用原始未合并消息构建 input
|
||||
input_str = build_readable_messages(
|
||||
messages=context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
)
|
||||
|
||||
# 输出取合并后消息的 processed_plain_text 并清理表情包和回复内容
|
||||
output_text = merged_msg.processed_plain_text or ""
|
||||
output_text = clean_output_text(output_text)
|
||||
output_id = merged_msg.message_id or ""
|
||||
pairs.append((input_str, output_text, output_id))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def build_pairs(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str],
|
||||
user_id: Optional[str],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 获取所有消息(不按user_id过滤),这样input上下文可以包含所有用户的消息
|
||||
messages = fetch_messages_between(start_ts, end_ts, platform)
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
# 对消息进行合并,用于output
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
# 传递原始消息和合并后消息,input使用原始消息,output使用合并后消息
|
||||
pairs = build_pairs_for_chat(msgs, merged, min_ctx, max_ctx, user_id)
|
||||
all_pairs.extend(pairs)
|
||||
|
||||
return all_pairs
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
|
||||
# 若未提供参数,则进入交互模式
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
if len(argv) == 0:
|
||||
return run_interactive()
|
||||
|
||||
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表,支持按用户ID筛选消息")
|
||||
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
|
||||
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
|
||||
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
|
||||
parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
|
||||
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
|
||||
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
start_ts = parse_datetime_to_timestamp(args.start)
|
||||
end_ts = parse_datetime_to_timestamp(args.end)
|
||||
if end_ts <= start_ts:
|
||||
raise ValueError("结束时间必须大于起始时间")
|
||||
|
||||
if args.max_ctx < args.min_ctx:
|
||||
raise ValueError("max_ctx 不能小于 min_ctx")
|
||||
|
||||
pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
|
||||
|
||||
if args.output:
|
||||
# 保存为 JSONL,每行一个 {input, output, message_id}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {args.output}")
|
||||
else:
|
||||
# 打印到 stdout
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
|
||||
suffix = f"[{default}]" if default not in (None, "") else ""
|
||||
value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
|
||||
if value == "" and default is not None:
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def run_interactive() -> int:
|
||||
print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
|
||||
start_str = _prompt_with_default("请输入起始时间", None)
|
||||
end_str = _prompt_with_default("请输入结束时间", None)
|
||||
platform = _prompt_with_default("平台(可留空表示不限)", "")
|
||||
user_id = _prompt_with_default("用户ID(可留空表示不限)", "")
|
||||
try:
|
||||
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
|
||||
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
|
||||
except Exception:
|
||||
print("上下文条数输入有误,使用默认 20/30")
|
||||
min_ctx, max_ctx = 20, 30
|
||||
output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")
|
||||
|
||||
if not start_str or not end_str:
|
||||
print("必须提供起始与结束时间。")
|
||||
return 2
|
||||
|
||||
try:
|
||||
start_ts = parse_datetime_to_timestamp(start_str)
|
||||
end_ts = parse_datetime_to_timestamp(end_str)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"时间解析失败:{e}")
|
||||
return 2
|
||||
|
||||
if end_ts <= start_ts:
|
||||
print("结束时间必须大于起始时间。")
|
||||
return 2
|
||||
|
||||
if max_ctx < min_ctx:
|
||||
print("最多条数不能小于最少条数。")
|
||||
return 2
|
||||
|
||||
platform_val = platform if platform != "" else None
|
||||
user_id_val = user_id if user_id != "" else None
|
||||
pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
|
||||
|
||||
if output_path:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {output_path}")
|
||||
else:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
print(f"总计 {len(pairs)} 条。")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
import time
|
||||
import sys
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def get_expression_data() -> List[Tuple[float, float, str, str]]:
|
||||
"""获取Expression表中的数据,返回(create_date, count, chat_id, expression_type)的列表"""
|
||||
expressions = Expression.select()
|
||||
data = []
|
||||
|
||||
for expr in expressions:
|
||||
# 如果create_date为空,跳过该记录
|
||||
if expr.create_date is None:
|
||||
continue
|
||||
|
||||
data.append((
|
||||
expr.create_date,
|
||||
expr.count,
|
||||
expr.chat_id,
|
||||
expr.type
|
||||
))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 分离数据
|
||||
create_dates = [item[0] for item in data]
|
||||
counts = [item[1] for item in data]
|
||||
chat_ids = [item[2] for item in data]
|
||||
expression_types = [item[3] for item in data]
|
||||
|
||||
# 转换时间戳为datetime对象
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
time_span = max(dates) - min(dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 创建散点图
|
||||
scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap='viridis')
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('表达式使用次数随时间分布散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 添加颜色条
|
||||
cbar = plt.colorbar(scatter)
|
||||
cbar.set_label('数据点顺序', fontsize=10)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 数据统计 ===")
|
||||
print(f"总数据点数量: {len(data)}")
|
||||
print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
|
||||
print(f"平均使用次数: {np.mean(counts):.2f}")
|
||||
print(f"中位数使用次数: {np.median(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建按聊天分组的散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 按chat_id分组
|
||||
chat_groups = {}
|
||||
for item in data:
|
||||
chat_id = item[2]
|
||||
if chat_id not in chat_groups:
|
||||
chat_groups[chat_id] = []
|
||||
chat_groups[chat_id].append(item)
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
|
||||
# 为每个聊天分配不同颜色
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
|
||||
|
||||
for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
|
||||
create_dates = [item[0] for item in chat_data]
|
||||
counts = [item[1] for item in chat_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
chat_name = get_chat_name(chat_id)
|
||||
# 截断过长的聊天名称
|
||||
display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{display_name} ({len(chat_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按聊天分组的表达式使用次数散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 分组统计 ===")
|
||||
print(f"总聊天数量: {len(chat_groups)}")
|
||||
for chat_id, chat_data in chat_groups.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
counts = [item[1] for item in chat_data]
|
||||
print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n分组散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建按表达式类型分组的散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 按type分组
|
||||
type_groups = {}
|
||||
for item in data:
|
||||
expr_type = item[3]
|
||||
if expr_type not in type_groups:
|
||||
type_groups[expr_type] = []
|
||||
type_groups[expr_type].append(item)
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = '%Y-%m-%d'
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = '%Y-%m-%d %H:%M'
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 为每个类型分配不同颜色
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
|
||||
|
||||
for i, (expr_type, type_data) in enumerate(type_groups.items()):
|
||||
create_dates = [item[0] for item in type_data]
|
||||
counts = [item[1] for item in type_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
ax.scatter(dates, counts, alpha=0.7, s=40,
|
||||
c=[colors[i]], label=f"{expr_type} ({len(type_data)}个)",
|
||||
edgecolors='black', linewidth=0.5)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel('创建日期 (Create Date)', fontsize=12)
|
||||
ax.set_ylabel('使用次数 (Count)', fontsize=12)
|
||||
ax.set_title('按表达式类型分组的散点图', fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print(f"\n=== 类型统计 ===")
|
||||
for expr_type, type_data in type_groups.items():
|
||||
counts = [item[1] for item in type_data]
|
||||
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"\n类型散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("开始分析表达式数据...")
|
||||
|
||||
# 获取数据
|
||||
data = get_expression_data()
|
||||
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据(create_date不为空的数据)")
|
||||
return
|
||||
|
||||
print(f"找到 {len(data)} 条有效数据")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 生成时间戳用于文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 1. 创建基础散点图
|
||||
print("\n1. 创建基础散点图...")
|
||||
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
|
||||
|
||||
# 2. 创建按聊天分组的散点图
|
||||
print("\n2. 创建按聊天分组的散点图...")
|
||||
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
|
||||
|
||||
# 3. 创建按类型分组的散点图
|
||||
print("\n3. 创建按类型分组的散点图...")
|
||||
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
|
||||
|
||||
print("\n分析完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,285 +0,0 @@
|
|||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams # noqa
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of interest_value"""
|
||||
distribution = {
|
||||
"0.000-0.010": 0,
|
||||
"0.010-0.050": 0,
|
||||
"0.050-0.100": 0,
|
||||
"0.100-0.500": 0,
|
||||
"0.500-1.000": 0,
|
||||
"1.000-2.000": 0,
|
||||
"2.000-5.000": 0,
|
||||
"5.000-10.000": 0,
|
||||
"10.000+": 0,
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.interest_value is None or msg.interest_value == 0.0:
|
||||
continue
|
||||
|
||||
value = float(msg.interest_value)
|
||||
if value < 0.010:
|
||||
distribution["0.000-0.010"] += 1
|
||||
elif value < 0.050:
|
||||
distribution["0.010-0.050"] += 1
|
||||
elif value < 0.100:
|
||||
distribution["0.050-0.100"] += 1
|
||||
elif value < 0.500:
|
||||
distribution["0.100-0.500"] += 1
|
||||
elif value < 1.000:
|
||||
distribution["0.500-1.000"] += 1
|
||||
elif value < 2.000:
|
||||
distribution["1.000-2.000"] += 1
|
||||
elif value < 5.000:
|
||||
distribution["2.000-5.000"] += 1
|
||||
elif value < 10.000:
|
||||
distribution["5.000-10.000"] += 1
|
||||
else:
|
||||
distribution["10.000+"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for interest_value"""
|
||||
values = [
|
||||
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
|
||||
]
|
||||
|
||||
if not values:
|
||||
return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
|
||||
|
||||
values.sort()
|
||||
count = len(values)
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"avg": sum(values) / count,
|
||||
"median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.chat_id == chat_id)
|
||||
& (Messages.interest_value.is_null(False))
|
||||
& (Messages.interest_value != 0.0)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取聊天列表失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
now = time.time()
|
||||
|
||||
if choice == "1":
|
||||
return now - 24 * 3600, now
|
||||
elif choice == "2":
|
||||
return now - 3 * 24 * 3600, now
|
||||
elif choice == "3":
|
||||
return now - 7 * 24 * 3600, now
|
||||
elif choice == "4":
|
||||
return now - 30 * 24 * 3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print("时间格式错误,将不限制时间范围")
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def analyze_interest_values(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
) -> None:
|
||||
"""Analyze interest values with optional filters"""
|
||||
|
||||
# 构建查询条件
|
||||
query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
messages = list(query)
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_interest_value_distribution(messages)
|
||||
stats = get_interest_value_stats(messages)
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Interest Value 分析结果 ===")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||
elif end_time:
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
||||
print(f"最小值: {stats['min']:.3f}")
|
||||
print(f"最大值: {stats['max']:.3f}")
|
||||
print(f"平均值: {stats['avg']:.3f}")
|
||||
print(f"中位数: {stats['median']:.3f}")
|
||||
|
||||
print("\nInterest Value 分布:")
|
||||
total = stats["count"]
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name}: {count} ({percentage:.2f}%)")
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for interest value analysis"""
|
||||
|
||||
while True:
|
||||
print("\n" + "=" * 50)
|
||||
print("Interest Value 分析工具")
|
||||
print("=" * 50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
chat_id = None
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到有interest_value数据的聊天")
|
||||
continue
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条有效消息)")
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
chat_id = chats[chat_choice - 1][0]
|
||||
else:
|
||||
print("无效选择")
|
||||
continue
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
# 执行分析
|
||||
analyze_interest_values(chat_id, start_time, end_time)
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
|
|
@ -1,397 +0,0 @@
|
|||
import time
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams # noqa
|
||||
|
||||
|
||||
def contains_emoji_or_image_tags(text: str) -> bool:
|
||||
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否包含 [表情包] 或 [图片] 标记
|
||||
emoji_pattern = r"\[表情包[^\]]*\]"
|
||||
image_pattern = r"\[图片[^\]]*\]"
|
||||
|
||||
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
|
||||
|
||||
|
||||
def clean_reply_text(text: str) -> str:
|
||||
"""Remove reply references like [回复 xxxx...] from text"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# 匹配 [回复 xxxx...] 格式的内容
|
||||
# 使用非贪婪匹配,匹配到第一个 ] 就停止
|
||||
cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text)
|
||||
|
||||
# 去除多余的空白字符
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
return cleaned_text
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of processed_plain_text length"""
|
||||
distribution = {
|
||||
"0": 0, # 空文本
|
||||
"1-5": 0, # 极短文本
|
||||
"6-10": 0, # 很短文本
|
||||
"11-20": 0, # 短文本
|
||||
"21-30": 0, # 较短文本
|
||||
"31-50": 0, # 中短文本
|
||||
"51-70": 0, # 中等文本
|
||||
"71-100": 0, # 较长文本
|
||||
"101-150": 0, # 长文本
|
||||
"151-200": 0, # 很长文本
|
||||
"201-300": 0, # 超长文本
|
||||
"301-500": 0, # 极长文本
|
||||
"501-1000": 0, # 巨长文本
|
||||
"1000+": 0, # 超巨长文本
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
continue
|
||||
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
|
||||
if length == 0:
|
||||
distribution["0"] += 1
|
||||
elif length <= 5:
|
||||
distribution["1-5"] += 1
|
||||
elif length <= 10:
|
||||
distribution["6-10"] += 1
|
||||
elif length <= 20:
|
||||
distribution["11-20"] += 1
|
||||
elif length <= 30:
|
||||
distribution["21-30"] += 1
|
||||
elif length <= 50:
|
||||
distribution["31-50"] += 1
|
||||
elif length <= 70:
|
||||
distribution["51-70"] += 1
|
||||
elif length <= 100:
|
||||
distribution["71-100"] += 1
|
||||
elif length <= 150:
|
||||
distribution["101-150"] += 1
|
||||
elif length <= 200:
|
||||
distribution["151-200"] += 1
|
||||
elif length <= 300:
|
||||
distribution["201-300"] += 1
|
||||
elif length <= 500:
|
||||
distribution["301-500"] += 1
|
||||
elif length <= 1000:
|
||||
distribution["501-1000"] += 1
|
||||
else:
|
||||
distribution["1000+"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for processed_plain_text length"""
|
||||
lengths = []
|
||||
null_count = 0
|
||||
excluded_count = 0 # 被排除的消息数量
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
null_count += 1
|
||||
elif contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
# 排除包含表情包或图片标记的消息
|
||||
excluded_count += 1
|
||||
else:
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
lengths.append(len(cleaned_text))
|
||||
|
||||
if not lengths:
|
||||
return {
|
||||
"count": 0,
|
||||
"null_count": null_count,
|
||||
"excluded_count": excluded_count,
|
||||
"min": 0,
|
||||
"max": 0,
|
||||
"avg": 0,
|
||||
"median": 0,
|
||||
}
|
||||
|
||||
lengths.sort()
|
||||
count = len(lengths)
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"null_count": null_count,
|
||||
"excluded_count": excluded_count,
|
||||
"min": min(lengths),
|
||||
"max": max(lengths),
|
||||
"avg": sum(lengths) / count,
|
||||
"median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2,
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id,排除特殊类型消息
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.chat_id == chat_id)
|
||||
& (Messages.is_emoji != 1)
|
||||
& (Messages.is_picid != 1)
|
||||
& (Messages.is_command != 1)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取聊天列表失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
now = time.time()
|
||||
|
||||
if choice == "1":
|
||||
return now - 24 * 3600, now
|
||||
elif choice == "2":
|
||||
return now - 3 * 24 * 3600, now
|
||||
elif choice == "3":
|
||||
return now - 7 * 24 * 3600, now
|
||||
elif choice == "4":
|
||||
return now - 30 * 24 * 3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print("时间格式错误,将不限制时间范围")
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
|
||||
"""Get top N longest messages"""
|
||||
message_lengths = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is not None:
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
chat_name = get_chat_name(msg.chat_id)
|
||||
time_str = format_timestamp(msg.time)
|
||||
# 截取前100个字符作为预览
|
||||
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
|
||||
message_lengths.append((chat_name, length, time_str, preview))
|
||||
|
||||
# 按长度排序,取前N个
|
||||
message_lengths.sort(key=lambda x: x[1], reverse=True)
|
||||
return message_lengths[:top_n]
|
||||
|
||||
|
||||
def analyze_text_lengths(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
) -> None:
|
||||
"""Analyze processed_plain_text lengths with optional filters"""
|
||||
|
||||
# 构建查询条件,排除特殊类型的消息
|
||||
query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
messages = list(query)
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_text_length_distribution(messages)
|
||||
stats = get_text_length_stats(messages)
|
||||
top_longest = get_top_longest_messages(messages, 10)
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Processed Plain Text 长度分析结果 ===")
|
||||
print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||
elif end_time:
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"总消息数量: {len(messages)}")
|
||||
print(f"有文本消息数量: {stats['count']}")
|
||||
print(f"空文本消息数量: {stats['null_count']}")
|
||||
print(f"被排除的消息数量: {stats['excluded_count']}")
|
||||
if stats["count"] > 0:
|
||||
print(f"最短长度: {stats['min']} 字符")
|
||||
print(f"最长长度: {stats['max']} 字符")
|
||||
print(f"平均长度: {stats['avg']:.2f} 字符")
|
||||
print(f"中位数长度: {stats['median']:.2f} 字符")
|
||||
|
||||
print("\n文本长度分布:")
|
||||
total = stats["count"]
|
||||
if total > 0:
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
|
||||
|
||||
# 显示最长的消息
|
||||
if top_longest:
|
||||
print(f"\n最长的 {len(top_longest)} 条消息:")
|
||||
for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
|
||||
print(f"{i}. [{chat_name}] {time_str}")
|
||||
print(f" 长度: {length} 字符")
|
||||
print(f" 预览: {preview}")
|
||||
print()
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for text length analysis"""
|
||||
|
||||
while True:
|
||||
print("\n" + "=" * 50)
|
||||
print("Processed Plain Text 长度分析工具")
|
||||
print("=" * 50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
chat_id = None
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到聊天数据")
|
||||
continue
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条消息)")
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
chat_id = chats[chat_choice - 1][0]
|
||||
else:
|
||||
print("无效选择")
|
||||
continue
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
# 执行分析
|
||||
analyze_text_lengths(chat_id, start_time, end_time)
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
|
|
@ -16,8 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
|
|||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
|
|
@ -96,7 +95,6 @@ class BrainChatting:
|
|||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.more_plan = False
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
|
@ -171,10 +169,8 @@ class BrainChatting:
|
|||
|
||||
if len(recent_messages_list) >= 1:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(
|
||||
recent_messages_list=recent_messages_list
|
||||
)
|
||||
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
|
|
@ -233,11 +229,11 @@ class BrainChatting:
|
|||
|
||||
async def _observe(
|
||||
self, # interest_value: float = 0.0,
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None
|
||||
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
|
||||
) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
|
|
@ -286,7 +282,7 @@ class BrainChatting:
|
|||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info, _ = await self.action_planner.plan(
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
|
@ -334,7 +330,7 @@ class BrainChatting:
|
|||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
reply_text = reply_text_from_reply
|
||||
_reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
|
|
@ -347,7 +343,7 @@ class BrainChatting:
|
|||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
_reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
|
@ -401,7 +397,7 @@ class BrainChatting:
|
|||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
action_reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
|
|
@ -417,8 +413,8 @@ class BrainChatting:
|
|||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果
|
||||
result = await action_handler.execute()
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.run()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
|
|
@ -484,13 +480,12 @@ class BrainChatting:
|
|||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_action逻辑,不再通过动作系统
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_action信息到数据库
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
|
|
@ -498,9 +493,9 @@ class BrainChatting:
|
|||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_action",
|
||||
action_name="no_reply",
|
||||
)
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
|
|
@ -517,7 +512,9 @@ class BrainChatting:
|
|||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
|
|
|||
|
|
@ -152,10 +152,10 @@ class BrainPlanner:
|
|||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_action")
|
||||
action = action_json.get("action", "no_reply")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_action动作需要target_message_id
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
|
|
@ -215,12 +215,11 @@ class BrainPlanner:
|
|||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
) -> Tuple[List[ActionPlannerInfo], Optional["DatabaseMessages"]]:
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
target_message: Optional["DatabaseMessages"] = None
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
|
|
@ -274,12 +273,7 @@ class BrainPlanner:
|
|||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 获取target_message(如果有非no_action的动作)
|
||||
non_no_actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
if non_no_actions:
|
||||
target_message = non_no_actions[0].action_message
|
||||
|
||||
return actions, target_message
|
||||
return actions
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
|
|
@ -307,7 +301,9 @@ class BrainPlanner:
|
|||
|
||||
if chat_target_info:
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
|
||||
chat_context_description = (
|
||||
f"你正在和 {chat_target_info.person_name or chat_target_info.user_nickname or '对方'} 聊天中"
|
||||
)
|
||||
|
||||
# 构建动作选项块
|
||||
action_options_block = await self._build_action_options_block(current_available_actions)
|
||||
|
|
@ -487,19 +483,19 @@ class BrainPlanner:
|
|||
else:
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_action动作
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_action"""
|
||||
"""创建no_reply"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
|
|
|
|||
|
|
@ -1,604 +0,0 @@
|
|||
import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 15 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
|
||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要涉及具体的人名,但是可以涉及具体名词
|
||||
3. 思考有没有特殊的梗,一并总结成语言风格
|
||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
|
||||
当"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
现在请你概括
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习参数
|
||||
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||
|
||||
def can_learn_for_chat(self) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许学习表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许学习
|
||||
"""
|
||||
try:
|
||||
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
return enable_learning
|
||||
except Exception as e:
|
||||
logger.error(f"检查学习权限失败: {e}")
|
||||
return False
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发学习
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 获取该聊天流的学习强度
|
||||
try:
|
||||
_, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
||||
return False
|
||||
|
||||
# 检查是否允许学习
|
||||
if not enable_learning:
|
||||
return False
|
||||
|
||||
# 根据学习强度计算最短学习时间间隔
|
||||
min_interval = self.min_learning_interval / learning_intensity
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = current_time - self.last_learning_time
|
||||
if time_diff < min_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self) -> bool:
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
|
||||
# 更新学习时间
|
||||
self.last_learning_time = time.time()
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
当时间差为0天时,衰减值为0(最近活跃的不衰减)
|
||||
当时间差为7天时,衰减值为0.002(中等衰减)
|
||||
当时间差为30天或更长时,衰减值为0.01(高衰减)
|
||||
使用二次函数进行曲线插值
|
||||
"""
|
||||
if time_diff_days <= 0:
|
||||
return 0.0 # 刚激活的表达式不衰减
|
||||
|
||||
if time_diff_days >= DECAY_DAYS:
|
||||
return 0.01 # 长时间未活跃的表达式大幅衰减
|
||||
|
||||
# 使用二次函数插值:在0-30天之间从0衰减到0.01
|
||||
# 使用简单的二次函数:y = a * x^2
|
||||
# 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
|
||||
a = 0.01 / (DECAY_DAYS**2)
|
||||
decay = a * (time_diff_days**2)
|
||||
|
||||
return min(0.01, decay)
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
# 检查是否允许在此聊天流中学习(在函数最前面检查)
|
||||
if not self.can_learn_for_chat():
|
||||
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
|
||||
return []
|
||||
|
||||
res = await self.learn_expression(num)
|
||||
|
||||
if res is None:
|
||||
return []
|
||||
learnt_expressions, chat_id = res
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
else:
|
||||
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
|
||||
learnt_expressions_str = ""
|
||||
for _chat_id, situation, style in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {group_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
if not learnt_expressions:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style})
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
"""
|
||||
type_str = "语言风格"
|
||||
prompt = "learn_style_prompt"
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
logger.debug(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
return None
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line.find('当"')
|
||||
if idx_when == -1:
|
||||
continue
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
continue
|
||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
||||
# 查找"使用"
|
||||
idx_use = line.find('使用"', idx_quote2)
|
||||
if idx_use == -1:
|
||||
continue
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((chat_id, situation, style))
|
||||
return expressions
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
self._auto_migrate_json_to_db()
|
||||
self._migrate_old_data_create_date()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
然后检查done.done2,如果没有就删除所有grammar表达并创建该标记文件。
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
done_flag = os.path.join(base_dir, "done.done")
|
||||
done_flag2 = os.path.join(base_dir, "done.done2")
|
||||
|
||||
# 确保基础目录存在
|
||||
try:
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {base_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式目录失败: {e}")
|
||||
return
|
||||
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
else:
|
||||
logger.info("开始迁移表达方式JSON到数据库...")
|
||||
migrated_count = 0
|
||||
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
logger.debug(f"目录不存在,跳过: {type_dir}")
|
||||
continue
|
||||
|
||||
try:
|
||||
chat_ids = os.listdir(type_dir)
|
||||
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录失败 {type_dir}: {e}")
|
||||
continue
|
||||
|
||||
for chat_id in chat_ids:
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||
continue
|
||||
|
||||
for expr in expressions:
|
||||
if not isinstance(expr, dict):
|
||||
continue
|
||||
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
|
||||
if not situation or not style_val:
|
||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 标记迁移完成
|
||||
try:
|
||||
# 确保done.done文件的父目录存在
|
||||
done_parent_dir = os.path.dirname(done_flag)
|
||||
if not os.path.exists(done_parent_dir):
|
||||
os.makedirs(done_parent_dir, exist_ok=True)
|
||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||
except OSError as e:
|
||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
# 检查并处理grammar表达删除
|
||||
if not os.path.exists(done_flag2):
|
||||
logger.info("开始删除所有grammar类型的表达...")
|
||||
try:
|
||||
deleted_count = self.delete_all_grammar_expressions()
|
||||
logger.info(f"grammar表达删除完成,共删除 {deleted_count} 个表达")
|
||||
|
||||
# 创建done.done2标记文件
|
||||
with open(done_flag2, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("已创建done.done2标记文件,grammar表达删除标记完成")
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达或创建标记文件失败: {e}")
|
||||
else:
|
||||
logger.info("grammar表达已删除,跳过重复删除")
|
||||
|
||||
def _migrate_old_data_create_date(self):
|
||||
"""
|
||||
为没有create_date的老数据设置创建日期
|
||||
使用last_active_time作为create_date的默认值
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
def delete_all_grammar_expressions(self) -> int:
|
||||
"""
|
||||
检查expression库中所有type为"grammar"的表达并全部删除
|
||||
|
||||
Returns:
|
||||
int: 删除的grammar表达数量
|
||||
"""
|
||||
try:
|
||||
# 查询所有type为"grammar"的表达
|
||||
grammar_expressions = Expression.select().where(Expression.type == "grammar")
|
||||
grammar_count = grammar_expressions.count()
|
||||
|
||||
if grammar_count == 0:
|
||||
logger.info("expression库中没有找到grammar类型的表达")
|
||||
return 0
|
||||
|
||||
logger.info(f"找到 {grammar_count} 个grammar类型的表达,开始删除...")
|
||||
|
||||
# 删除所有grammar类型的表达
|
||||
deleted_count = 0
|
||||
for expr in grammar_expressions:
|
||||
try:
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"成功删除 {deleted_count} 个grammar类型的表达")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除grammar表达过程中发生错误: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
|
|
@ -1,316 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用累积权重的方法进行加权抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
weights_copy = weights.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
# 选择一个元素
|
||||
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
weights_copy.pop(chosen_idx)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
return selected_style
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
|
||||
# logger.info(f"模型名称: {model_name}")
|
||||
# logger.info(f"LLM返回结果: {content}")
|
||||
# if reasoning_content:
|
||||
# logger.info(f"LLM推理: {reasoning_content}")
|
||||
# else:
|
||||
# logger.info(f"LLM推理: 无")
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
|
|
@ -43,4 +43,4 @@ class FrequencyControlManager:
|
|||
|
||||
|
||||
# 创建全局实例
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
frequency_control_manager = FrequencyControlManager()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from multiprocessing import context
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
|
|
@ -17,14 +18,15 @@ from src.chat.planner_actions.action_modifier import ActionModifier
|
|||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.memory_system.question_maker import QuestionMaker
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
|
|
@ -98,11 +100,15 @@ class HeartFChatting:
|
|||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.last_read_time = time.time() - 2
|
||||
|
||||
self.talk_threshold = global_config.chat.talk_value
|
||||
|
||||
self.no_reply_until_call = False
|
||||
|
||||
self.is_mute = False
|
||||
|
||||
self.last_active_time = time.time() # 记录上一次非noreply时间
|
||||
|
||||
self.questioned = False
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||
|
||||
|
|
@ -154,16 +160,19 @@ class HeartFChatting:
|
|||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
if elapsed < 0.1:
|
||||
# 不显示小于0.1秒的计时器
|
||||
continue
|
||||
formatted_time = f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore
|
||||
+ (f"详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
async def _loopbody(self): # sourcery skip: hoist-if-from-if
|
||||
async def _loopbody(self):
|
||||
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
|
|
@ -174,7 +183,43 @@ class HeartFChatting:
|
|||
filter_command=True,
|
||||
)
|
||||
|
||||
question_probability = 0
|
||||
if time.time() - self.last_active_time > 3600:
|
||||
question_probability = 0.001
|
||||
elif time.time() - self.last_active_time > 1200:
|
||||
question_probability = 0.0003
|
||||
else:
|
||||
question_probability = 0.0001
|
||||
|
||||
question_probability = question_probability * global_config.chat.get_auto_chat_value(self.stream_id)
|
||||
|
||||
# print(f"{self.log_prefix} questioned: {self.questioned},len: {len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id))}")
|
||||
if question_probability > 0 and not self.questioned and len(global_conflict_tracker.get_questions_by_chat_id(self.stream_id)) == 0: #长久没有回复,可以试试主动发言,提问概率随着时间增加
|
||||
# logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,概率: {question_probability}")
|
||||
if random.random() < question_probability: # 30%概率主动发言
|
||||
try:
|
||||
self.questioned = True
|
||||
self.last_active_time = time.time()
|
||||
# print(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
logger.info(f"{self.log_prefix} 长久没有回复,可以试试主动发言,开始生成问题")
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
question_maker = QuestionMaker(self.stream_id)
|
||||
question, context,conflict_context = await question_maker.make_question()
|
||||
if question:
|
||||
logger.info(f"{self.log_prefix} 问题: {question}")
|
||||
await global_conflict_tracker.track_conflict(question, conflict_context, True, self.stream_id)
|
||||
await self._lift_question_reply(question,context,thinking_id)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 无问题")
|
||||
# self.end_cycle(cycle_timers, thinking_id)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 主动提问失败: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
if len(recent_messages_list) >= 1:
|
||||
# for message in recent_messages_list:
|
||||
# print(message.processed_plain_text)
|
||||
# !处理no_reply_until_call逻辑
|
||||
if self.no_reply_until_call:
|
||||
for message in recent_messages_list:
|
||||
|
|
@ -185,6 +230,7 @@ class HeartFChatting:
|
|||
or time.time() - self.last_read_time > 600
|
||||
):
|
||||
self.no_reply_until_call = False
|
||||
self.last_read_time = time.time()
|
||||
break
|
||||
# 没有提到,继续保持沉默
|
||||
if self.no_reply_until_call:
|
||||
|
|
@ -200,15 +246,22 @@ class HeartFChatting:
|
|||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||
mentioned_message = message
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
||||
|
||||
# *控制频率用
|
||||
if mentioned_message:
|
||||
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
||||
elif random.random() < global_config.chat.talk_value * frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust():
|
||||
elif (
|
||||
random.random()
|
||||
< global_config.chat.get_talk_value(self.stream_id)
|
||||
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
||||
):
|
||||
await self._observe(recent_messages_list=recent_messages_list)
|
||||
else:
|
||||
# 没有提到,继续保持沉默,等待5秒防止频繁触发
|
||||
await asyncio.sleep(10)
|
||||
return True
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
await asyncio.sleep(0.2)
|
||||
return True
|
||||
return True
|
||||
|
|
@ -272,12 +325,14 @@ class HeartFChatting:
|
|||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
if s4u_config.enable_s4u:
|
||||
await send_typing()
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
|
||||
|
||||
|
|
@ -322,7 +377,7 @@ class HeartFChatting:
|
|||
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
action_to_use_info, _ = await self.action_planner.plan(
|
||||
action_to_use_info = await self.action_planner.plan(
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
|
@ -344,6 +399,10 @@ class HeartFChatting:
|
|||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
|
||||
)
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
action_tasks = [
|
||||
asyncio.create_task(
|
||||
|
|
@ -361,21 +420,26 @@ class HeartFChatting:
|
|||
action_success = False
|
||||
action_reply_text = ""
|
||||
|
||||
excute_result_str = ""
|
||||
for result in results:
|
||||
excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n"
|
||||
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
action_reply_text = result["reply_text"]
|
||||
action_reply_text = result["result"]
|
||||
elif result["action_type"] == "reply":
|
||||
if result["success"]:
|
||||
reply_loop_info = result["loop_info"]
|
||||
reply_text_from_reply = result["reply_text"]
|
||||
reply_text_from_reply = result["result"]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
||||
|
||||
self.action_planner.add_plan_excute_log(result=excute_result_str)
|
||||
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
|
|
@ -405,12 +469,12 @@ class HeartFChatting:
|
|||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
"""S4U内容,暂时保留"""
|
||||
if s4u_config.enable_s4u:
|
||||
await stop_typing()
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
"""S4U内容,暂时保留"""
|
||||
|
||||
end_time = time.time()
|
||||
if end_time - start_time < global_config.chat.planner_smooth:
|
||||
wait_time = global_config.chat.planner_smooth - (end_time - start_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
|
|
@ -435,7 +499,7 @@ class HeartFChatting:
|
|||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
|
|
@ -446,11 +510,11 @@ class HeartFChatting:
|
|||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
action_reasoning: 决策理由
|
||||
action_data: 动作数据,包含不同动作需要的参数
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
|
||||
action_message: 消息数据
|
||||
返回:
|
||||
tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令)
|
||||
"""
|
||||
|
|
@ -460,33 +524,101 @@ class HeartFChatting:
|
|||
action_handler = self.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
action_reasoning=action_reasoning,
|
||||
action_message=action_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
return False, ""
|
||||
|
||||
if not action_handler:
|
||||
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
|
||||
return False, "", ""
|
||||
|
||||
# 处理动作并获取结果
|
||||
# 处理动作并获取结果(固定记录一次动作信息)
|
||||
result = await action_handler.execute()
|
||||
success, action_text = result
|
||||
command = ""
|
||||
|
||||
return success, action_text, command
|
||||
|
||||
return success, action_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
return False, ""
|
||||
|
||||
async def _lift_question_reply(self, question: str, question_context: str, thinking_id: str):
|
||||
reason = f"在聊天中:\n{question_context}\n你对问题\"{question}\"感到好奇,想要和群友讨论"
|
||||
new_msg = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=1,
|
||||
)
|
||||
|
||||
reply_action_info = ActionPlannerInfo(
|
||||
action_type="reply",
|
||||
reasoning= "",
|
||||
action_data={},
|
||||
action_message=new_msg[0],
|
||||
available_actions=None,
|
||||
loop_start_time=time.time(),
|
||||
action_reasoning=reason)
|
||||
self.action_planner.add_plan_log(reasoning=f"你对问题\"{question}\"感到好奇,想要和群友讨论", actions=[reply_action_info])
|
||||
|
||||
success, llm_response = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_data={
|
||||
"raw_reply": f"我对这个问题感到好奇:{question}",
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
logger.info("主动提问发言失败")
|
||||
self.action_planner.add_plan_excute_log(result="主动回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "主动回复生成失败", "loop_info": None}
|
||||
|
||||
if success:
|
||||
for reply_seg in llm_response.reply_set.reply_data:
|
||||
send_data = reply_seg.content
|
||||
await send_api.text_to_stream(
|
||||
text=send_data,
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": llm_response.reply_set.reply_data[0].content},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": [reply_action_info],
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": llm_response.reply_set.reply_data[0].content,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
self.last_active_time = time.time()
|
||||
self.action_planner.add_plan_excute_log(result=f"你提问:{question}")
|
||||
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"result": f"你提问:{question}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
|
||||
async def _send_response(
|
||||
self,
|
||||
|
|
@ -498,7 +630,7 @@ class HeartFChatting:
|
|||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
need_reply = new_message_count >= random.randint(2, 3)
|
||||
|
||||
if need_reply:
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
|
|
@ -543,59 +675,83 @@ class HeartFChatting:
|
|||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
|
||||
# 直接当场执行no_reply逻辑
|
||||
if action_planner_info.action_type == "no_reply":
|
||||
# 直接处理no_action逻辑,不再通过动作系统
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_action信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_action",
|
||||
action_data={},
|
||||
action_name="no_reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "wait_time":
|
||||
action_planner_info.action_data = action_planner_info.action_data or {}
|
||||
logger.info(f"{self.log_prefix} 等待{action_planner_info.action_data['time']}秒后回复")
|
||||
await asyncio.sleep(action_planner_info.action_data["time"])
|
||||
return {"action_type": "wait_time", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "no_reply_until_call":
|
||||
# 直接当场执行no_reply_until_call逻辑
|
||||
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
|
||||
reason = action_planner_info.reasoning or "选择不回复"
|
||||
|
||||
self.no_reply_until_call = True
|
||||
return {"action_type": "no_reply_until_call", "success": True, "reply_text": "", "command": ""}
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="no_reply_until_call",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
return {"action_type": "no_reply_until_call", "success": True, "result": "保持沉默,直到有人直接叫的名字", "command": ""}
|
||||
|
||||
elif action_planner_info.action_type == "reply":
|
||||
try:
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=action_planner_info.reasoning or "",
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
# 直接当场执行reply逻辑
|
||||
self.questioned = False
|
||||
# 刷新主动发言状态
|
||||
|
||||
reason = action_planner_info.reasoning or "选择回复"
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={},
|
||||
action_name="reply",
|
||||
action_reasoning=reason,
|
||||
)
|
||||
|
||||
success, llm_response = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message=action_planner_info.action_message,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_action_plan_infos,
|
||||
reply_reason=reason,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="replyer",
|
||||
from_plugin=False,
|
||||
reply_time_point = action_planner_info.action_data.get("loop_start_time", time.time()),
|
||||
)
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None}
|
||||
|
||||
if not success or not llm_response or not llm_response.reply_set:
|
||||
if action_planner_info.action_message:
|
||||
logger.info(
|
||||
f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败"
|
||||
)
|
||||
else:
|
||||
logger.info("回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
response_set = llm_response.reply_set
|
||||
selected_expressions = llm_response.selected_expressions
|
||||
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||
|
|
@ -606,30 +762,30 @@ class HeartFChatting:
|
|||
actions=chosen_action_plan_infos,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"result": f"你回复内容{reply_text}",
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
|
||||
# 其他动作
|
||||
else:
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_planner_info.action_type,
|
||||
action_planner_info.reasoning or "",
|
||||
action_planner_info.action_data or {},
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_planner_info.action_message,
|
||||
success, result = await self._handle_action(
|
||||
action = action_planner_info.action_type,
|
||||
action_reasoning = action_planner_info.action_reasoning or "",
|
||||
action_data = action_planner_info.action_data or {},
|
||||
cycle_timers = cycle_timers,
|
||||
thinking_id = thinking_id,
|
||||
action_message= action_planner_info.action_message,
|
||||
)
|
||||
|
||||
self.last_active_time = time.time()
|
||||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -638,7 +794,7 @@ class HeartFChatting:
|
|||
return {
|
||||
"action_type": action_planner_info.action_type,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"result": "",
|
||||
"loop_info": None,
|
||||
"error": str(e),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.chat_message_builder import replace_user_references
|
||||
from src.common.logger import get_logger
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
|
|
@ -75,11 +72,7 @@ class HeartFCMessageReceiver:
|
|||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
|
|
@ -107,7 +100,7 @@ class HeartFCMessageReceiver:
|
|||
replace_bot_name=True,
|
||||
)
|
||||
# if not processed_plain_text:
|
||||
# print(message)
|
||||
# print(message)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,241 +0,0 @@
|
|||
import json
|
||||
import random
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
Args:
|
||||
json_str: JSON格式的字符串
|
||||
|
||||
Returns:
|
||||
List[str]: 关键词列表
|
||||
"""
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json = repair_json(json_str)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
return result.get("keywords", [])
|
||||
except Exception as e:
|
||||
logger.error(f"解析关键词JSON失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# --- Group Chat Prompt ---
|
||||
memory_activator_prompt = """
|
||||
你需要根据以下信息来挑选合适的记忆编号
|
||||
以下是一段聊天记录,请根据这些信息,和下方的记忆,挑选和群聊内容有关的记忆编号
|
||||
|
||||
聊天记录:
|
||||
{obs_info_text}
|
||||
你想要回复的消息:
|
||||
{target_message}
|
||||
|
||||
记忆:
|
||||
{memory_info}
|
||||
|
||||
请输出一个json格式,包含以下字段:
|
||||
{{
|
||||
"memory_ids": "记忆1编号,记忆2编号,记忆3编号,......"
|
||||
}}
|
||||
不要输出其他多余内容,只输出json格式就好
|
||||
"""
|
||||
|
||||
Prompt(memory_activator_prompt, "memory_activator_prompt")
|
||||
|
||||
|
||||
class MemoryActivator:
|
||||
def __init__(self):
|
||||
self.key_words_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.activator",
|
||||
)
|
||||
# 用于记忆选择的 LLM 模型
|
||||
self.memory_selection_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.selection",
|
||||
)
|
||||
|
||||
async def activate_memory_with_chat_history(
|
||||
self, target_message, chat_history: List[DatabaseMessages]
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
# 如果记忆系统被禁用,直接返回空列表
|
||||
if not global_config.memory.enable_memory:
|
||||
return []
|
||||
|
||||
keywords_list = set()
|
||||
|
||||
for msg in chat_history:
|
||||
keywords = parse_keywords_string(msg.key_words)
|
||||
if keywords:
|
||||
if len(keywords_list) < 30:
|
||||
# 最多容纳30个关键词
|
||||
keywords_list.update(keywords)
|
||||
logger.debug(f"提取关键词: {keywords_list}")
|
||||
else:
|
||||
break
|
||||
|
||||
if not keywords_list:
|
||||
logger.debug("没有提取到关键词,返回空记忆列表")
|
||||
return []
|
||||
|
||||
# 从海马体获取相关记忆
|
||||
related_memory = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
|
||||
)
|
||||
|
||||
# logger.info(f"当前记忆关键词: {keywords_list}")
|
||||
logger.debug(f"获取到的记忆: {related_memory}")
|
||||
|
||||
if not related_memory:
|
||||
logger.debug("海马体没有返回相关记忆")
|
||||
return []
|
||||
|
||||
used_ids = set()
|
||||
candidate_memories = []
|
||||
|
||||
# 为每个记忆分配随机ID并过滤相关记忆
|
||||
for memory in related_memory:
|
||||
keyword, content = memory
|
||||
found = any(kw in content for kw in keywords_list)
|
||||
if found:
|
||||
# 随机分配一个不重复的2位数id
|
||||
while True:
|
||||
random_id = "{:02d}".format(random.randint(0, 99))
|
||||
if random_id not in used_ids:
|
||||
used_ids.add(random_id)
|
||||
break
|
||||
candidate_memories.append({"memory_id": random_id, "keyword": keyword, "content": content})
|
||||
|
||||
if not candidate_memories:
|
||||
logger.info("没有找到相关的候选记忆")
|
||||
return []
|
||||
|
||||
# 如果只有少量记忆,直接返回
|
||||
if len(candidate_memories) <= 2:
|
||||
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||
|
||||
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
|
||||
|
||||
async def _select_memories_with_llm(
|
||||
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
使用 LLM 选择合适的记忆
|
||||
|
||||
Args:
|
||||
target_message: 目标消息
|
||||
chat_history_prompt: 聊天历史
|
||||
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
|
||||
"""
|
||||
try:
|
||||
# 构建聊天历史字符串
|
||||
obs_info_text = build_readable_messages(
|
||||
chat_history,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 构建记忆信息字符串
|
||||
memory_lines = []
|
||||
for memory in candidate_memories:
|
||||
memory_id = memory["memory_id"]
|
||||
keyword = memory["keyword"]
|
||||
content = memory["content"]
|
||||
|
||||
# 将 content 列表转换为字符串
|
||||
if isinstance(content, list):
|
||||
content_str = " | ".join(str(item) for item in content)
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
|
||||
|
||||
memory_info = "\n".join(memory_lines)
|
||||
|
||||
# 获取并格式化 prompt
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
|
||||
formatted_prompt = prompt_template.format(
|
||||
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
|
||||
formatted_prompt, temperature=0.3, max_tokens=150
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.info(f"LLM 记忆选择响应: {response}")
|
||||
else:
|
||||
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
|
||||
logger.debug(f"LLM 记忆选择响应: {response}")
|
||||
|
||||
# 解析响应获取选择的记忆编号
|
||||
try:
|
||||
fixed_json = repair_json(response)
|
||||
|
||||
# 解析为 Python 对象
|
||||
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||
|
||||
# 提取 memory_ids 字段并解析逗号分隔的编号
|
||||
if memory_ids_str := result.get("memory_ids", ""):
|
||||
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
|
||||
# 过滤掉空字符串和无效编号
|
||||
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
|
||||
selected_memory_ids = valid_memory_ids
|
||||
else:
|
||||
selected_memory_ids = []
|
||||
except Exception as e:
|
||||
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
|
||||
selected_memory_ids = []
|
||||
|
||||
# 根据编号筛选记忆
|
||||
selected_memories = []
|
||||
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
|
||||
|
||||
selected_memories = [
|
||||
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
|
||||
]
|
||||
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
|
||||
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
|
||||
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
|
||||
# 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
|
@ -3,19 +3,18 @@ import os
|
|||
import re
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo, Seg
|
||||
from maim_message import UserInfo, Seg, GroupInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
# 定义日志配置
|
||||
|
|
@ -27,7 +26,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
|||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
|
|
@ -40,14 +39,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
|||
"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
|
|
@ -61,10 +60,10 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
|||
# 检查text是否为None或空字符串
|
||||
if text is None or not text:
|
||||
return False
|
||||
|
||||
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
|
|
@ -78,8 +77,6 @@ class ChatBot:
|
|||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
|
|
@ -153,35 +150,10 @@ class ChatBot:
|
|||
if message.message_info.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.info("notice消息")
|
||||
# print(message)
|
||||
print(message)
|
||||
|
||||
return True
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_id=message.message_info.user_info.user_id, # type: ignore
|
||||
nickname=user_info.user_nickname, # type: ignore
|
||||
)
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
|
||||
|
|
@ -219,11 +191,6 @@ class ChatBot:
|
|||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
platform = message_data["message_info"].get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
|
|
@ -251,6 +218,21 @@ class ChatBot:
|
|||
# return
|
||||
pass
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text,
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
) or _check_ban_regex(
|
||||
message.raw_message, # type: ignore
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
):
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
|
|
@ -261,21 +243,10 @@ class ChatBot:
|
|||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
message.raw_message, # type: ignore
|
||||
chat,
|
||||
user_info, # type: ignore
|
||||
):
|
||||
return
|
||||
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ class MessageRecv(Message):
|
|||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
# print(f"self.message_segment: {self.message_segment}")
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
|
|
@ -208,129 +209,6 @@ class MessageRecv(Message):
|
|||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.is_internal = False
|
||||
self.voice_done = None
|
||||
|
||||
self.chat_info = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class ActionManager:
|
|||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
action_reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
|
|
@ -46,7 +46,7 @@ class ActionManager:
|
|||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
reasoning: 执行理由
|
||||
action_reasoning: 执行理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流
|
||||
|
|
@ -77,7 +77,7 @@ class ActionManager:
|
|||
# 创建动作实例
|
||||
instance = component_class(
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
import traceback
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
|
|
@ -44,14 +44,13 @@ def init_prompt():
|
|||
**聊天内容**
|
||||
{chat_content_block}
|
||||
|
||||
**动作记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
**可用的action**
|
||||
**可选的action**
|
||||
reply
|
||||
动作描述:
|
||||
1.你可以选择呼叫了你的名字,但是你没有做出回应的消息进行回复
|
||||
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
|
||||
3.不要回复你自己发送的消息
|
||||
4.不要单独对表情包进行回复
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"想要回复的消息id",
|
||||
|
|
@ -76,7 +75,11 @@ no_reply_until_call
|
|||
|
||||
{action_options_text}
|
||||
|
||||
请选择合适的action,并说明触发action的消息id和选择该action的原因。消息id格式:m+数字
|
||||
**你之前的action执行和思考记录**
|
||||
{actions_before_now_block}
|
||||
|
||||
请选择**可选的**且符合使用条件的action,并说明触发action的消息id(消息id格式:m+数字)
|
||||
不要回复你自己发送的消息
|
||||
先输出你的选择思考理由,再输出你选择的action,理由是一段平文本,不要分点,精简。
|
||||
**动作选择要求**
|
||||
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
|
||||
|
|
@ -99,9 +102,7 @@ no_reply_until_call
|
|||
"target_message_id":"触发动作的消息id",
|
||||
//对应参数
|
||||
}}
|
||||
```
|
||||
|
||||
""",
|
||||
```""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
|
|
@ -109,7 +110,7 @@ no_reply_until_call
|
|||
"""
|
||||
{action_name}
|
||||
动作描述:{action_description}
|
||||
使用条件:
|
||||
使用条件{parallel_text}:
|
||||
{action_require}
|
||||
{{
|
||||
"action": "{action_name}",{action_parameters},
|
||||
|
|
@ -133,6 +134,9 @@ class ActionPlanner:
|
|||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
|
||||
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
|
||||
) -> Optional["DatabaseMessages"]:
|
||||
|
|
@ -157,15 +161,16 @@ class ActionPlanner:
|
|||
action_json: dict,
|
||||
message_id_list: List[Tuple[str, "DatabaseMessages"]],
|
||||
current_available_actions: List[Tuple[str, ActionInfo]],
|
||||
extracted_reasoning: str = "",
|
||||
) -> List[ActionPlannerInfo]:
|
||||
"""解析单个action JSON并返回ActionPlannerInfo列表"""
|
||||
action_planner_infos = []
|
||||
|
||||
try:
|
||||
action = action_json.get("action", "no_action")
|
||||
action = action_json.get("action", "no_reply")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
|
||||
# 非no_action动作需要target_message_id
|
||||
# 非no_reply动作需要target_message_id
|
||||
target_message = None
|
||||
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
|
|
@ -202,6 +207,7 @@ class ActionPlanner:
|
|||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning if extracted_reasoning else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -216,6 +222,7 @@ class ActionPlanner:
|
|||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning if extracted_reasoning else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -225,12 +232,11 @@ class ActionPlanner:
|
|||
self,
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float = 0.0,
|
||||
) -> Tuple[List[ActionPlannerInfo], Optional["DatabaseMessages"]]:
|
||||
) -> List[ActionPlannerInfo]:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
target_message: Optional["DatabaseMessages"] = None
|
||||
|
||||
# 获取聊天上下文
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
|
|
@ -276,7 +282,7 @@ class ActionPlanner:
|
|||
)
|
||||
|
||||
# 调用LLM获取决策
|
||||
actions = await self._execute_main_planner(
|
||||
reasoning, actions = await self._execute_main_planner(
|
||||
prompt=prompt,
|
||||
message_id_list=message_id_list,
|
||||
filtered_actions=filtered_actions,
|
||||
|
|
@ -284,12 +290,34 @@ class ActionPlanner:
|
|||
loop_start_time=loop_start_time,
|
||||
)
|
||||
|
||||
# 获取target_message(如果有非no_action的动作)
|
||||
non_no_actions = [a for a in actions if a.action_type != "no_reply"]
|
||||
if non_no_actions:
|
||||
target_message = non_no_actions[0].action_message
|
||||
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
self.add_plan_log(reasoning, actions)
|
||||
|
||||
return actions
|
||||
|
||||
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
|
||||
self.plan_log.append((reasoning, time.time(), actions))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def add_plan_excute_log(self, result: str):
|
||||
self.plan_log.append(("", time.time(), result))
|
||||
if len(self.plan_log) > 20:
|
||||
self.plan_log.pop(0)
|
||||
|
||||
def get_plan_log_str(self) -> str:
|
||||
plan_log_str = ""
|
||||
for reasoning, time, content in self.plan_log:
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
time = datetime.fromtimestamp(time).strftime("%H:%M:%S")
|
||||
plan_log_str += f"{time}:{reasoning}|你使用了{','.join([action.action_type for action in content])}\n"
|
||||
else:
|
||||
time = datetime.fromtimestamp(time).strftime("%H:%M:%S")
|
||||
plan_log_str += f"{time}:{content}\n"
|
||||
|
||||
return plan_log_str
|
||||
|
||||
return actions, target_message
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
|
|
@ -302,18 +330,8 @@ class ActionPlanner:
|
|||
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
# 获取最近执行过的动作
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 600,
|
||||
timestamp_end=time.time(),
|
||||
limit=6,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
if actions_before_now_block:
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
else:
|
||||
actions_before_now_block = ""
|
||||
|
||||
actions_before_now_block=self.get_plan_log_str()
|
||||
|
||||
# 构建聊天上下文描述
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
|
|
@ -343,7 +361,6 @@ class ActionPlanner:
|
|||
interest=interest,
|
||||
plan_style=global_config.personality.plan_style,
|
||||
)
|
||||
|
||||
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
|
|
@ -421,6 +438,11 @@ class ActionPlanner:
|
|||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
if not action_info.parallel_action:
|
||||
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
|
||||
else:
|
||||
parallel_text = ""
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
|
|
@ -429,6 +451,7 @@ class ActionPlanner:
|
|||
action_description=action_info.description,
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
parallel_text=parallel_text,
|
||||
)
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
|
@ -442,7 +465,7 @@ class ActionPlanner:
|
|||
filtered_actions: Dict[str, ActionInfo],
|
||||
available_actions: Dict[str, ActionInfo],
|
||||
loop_start_time: float,
|
||||
) -> List[ActionPlannerInfo]:
|
||||
) -> Tuple[str,List[ActionPlannerInfo]]:
|
||||
"""执行主规划器"""
|
||||
llm_content = None
|
||||
actions: List[ActionPlannerInfo] = []
|
||||
|
|
@ -451,8 +474,8 @@ class ActionPlanner:
|
|||
# 调用LLM
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
|
|
@ -467,7 +490,7 @@ class ActionPlanner:
|
|||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
return [
|
||||
return f"LLM 请求失败,模型出现问题: {req_e}",[
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
|
||||
|
|
@ -478,38 +501,41 @@ class ActionPlanner:
|
|||
]
|
||||
|
||||
# 解析LLM响应
|
||||
extracted_reasoning = ""
|
||||
if llm_content:
|
||||
try:
|
||||
if json_objects := self._extract_json_from_markdown(llm_content):
|
||||
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
|
||||
if json_objects:
|
||||
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||
filtered_actions_list = list(filtered_actions.items())
|
||||
for json_obj in json_objects:
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
|
||||
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list, extracted_reasoning))
|
||||
else:
|
||||
# 尝试解析为直接的JSON
|
||||
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
|
||||
extracted_reasoning = "LLM没有返回可用动作"
|
||||
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
|
||||
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
|
||||
traceback.print_exc()
|
||||
else:
|
||||
extracted_reasoning = "规划器没有获得LLM响应"
|
||||
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
|
||||
|
||||
# 添加循环开始时间到所有非no_action动作
|
||||
# 添加循环开始时间到所有非no_reply动作
|
||||
for action in actions:
|
||||
action.action_data = action.action_data or {}
|
||||
action.action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||
|
||||
return actions
|
||||
return extracted_reasoning,actions
|
||||
|
||||
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
||||
"""创建no_action"""
|
||||
"""创建no_reply"""
|
||||
return [
|
||||
ActionPlannerInfo(
|
||||
action_type="no_reply",
|
||||
|
|
@ -520,15 +546,26 @@ class ActionPlanner:
|
|||
)
|
||||
]
|
||||
|
||||
def _extract_json_from_markdown(self, content: str) -> List[dict]:
|
||||
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""从Markdown格式的内容中提取JSON对象"""
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, content, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = content.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = content[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
|
|
@ -546,7 +583,7 @@ class ActionPlanner:
|
|||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import re
|
|||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
|
@ -19,16 +20,17 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
|||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
|
|
@ -43,6 +45,7 @@ init_rewrite_prompt()
|
|||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
class DefaultReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -69,6 +72,7 @@ class DefaultReplyer:
|
|||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
|
|
@ -102,6 +106,7 @@ class DefaultReplyer:
|
|||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
reply_time_point=reply_time_point,
|
||||
)
|
||||
llm_response.prompt = prompt
|
||||
llm_response.selected_expressions = selected_expressions
|
||||
|
|
@ -216,30 +221,6 @@ class DefaultReplyer:
|
|||
traceback.print_exc()
|
||||
return False, llm_response
|
||||
|
||||
async def build_relation_info(self, chat_content: str, sender: str, person_list: List[Person]):
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
if sender == global_config.bot.nickname:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person = Person(person_name=sender)
|
||||
if not is_person_known(person_name=sender):
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
sender_relation = await person.build_relationship(chat_content)
|
||||
others_relation = ""
|
||||
for person in person_list:
|
||||
person_relation = await person.build_relationship()
|
||||
others_relation += person_relation
|
||||
|
||||
return f"{sender_relation}\n{others_relation}"
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
|
@ -257,8 +238,8 @@ class DefaultReplyer:
|
|||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
|
|
@ -277,44 +258,42 @@ class DefaultReplyer:
|
|||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = (
|
||||
"在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
)
|
||||
expression_habits_title = "在回复时,你可以参考以下的语言习惯,不要生硬使用:"
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
async def build_memory_block(self) -> str:
|
||||
"""构建记忆块
|
||||
"""
|
||||
# if not global_config.memory.enable_memory:
|
||||
# return ""
|
||||
|
||||
# async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
|
||||
# """构建记忆块
|
||||
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
|
||||
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
# Args:
|
||||
# chat_history: 聊天历史记录
|
||||
# target: 目标消息内容
|
||||
|
||||
# Returns:
|
||||
# str: 记忆信息字符串
|
||||
# """
|
||||
|
||||
# if not global_config.memory.enable_memory:
|
||||
# return ""
|
||||
|
||||
# instant_memory = None
|
||||
|
||||
# running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
# target_message=target, chat_history=chat_history
|
||||
# )
|
||||
# if not running_memories:
|
||||
# return ""
|
||||
|
||||
# memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
# for running_memory in running_memories:
|
||||
# keywords, content = running_memory
|
||||
# memory_str += f"- {keywords}:{content}\n"
|
||||
|
||||
# if instant_memory:
|
||||
# memory_str += f"- {instant_memory}\n"
|
||||
|
||||
# return memory_str
|
||||
async def build_question_block(self) -> str:
|
||||
"""构建问题块"""
|
||||
# if not global_config.question.enable_question:
|
||||
# return ""
|
||||
questions = global_conflict_tracker.get_questions_by_chat_id(self.chat_stream.stream_id)
|
||||
questions_str = ""
|
||||
for question in questions:
|
||||
questions_str += f"- {question.question}\n"
|
||||
if questions_str:
|
||||
return f"你在聊天中,有以下问题想要得到解答:\n{questions_str}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
|
@ -344,7 +323,7 @@ class DefaultReplyer:
|
|||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
tool_info_str += f"- 【{tool_name}】: {content}\n"
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
|
|
@ -380,6 +359,64 @@ class DefaultReplyer:
|
|||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
def _replace_picids_with_descriptions(self, text: str) -> str:
|
||||
"""将文本中的[picid:xxx]替换为具体的图片描述
|
||||
|
||||
Args:
|
||||
text: 包含picid标记的文本
|
||||
|
||||
Returns:
|
||||
替换后的文本
|
||||
"""
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
pic_id = match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, text)
|
||||
|
||||
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
|
||||
"""分析target内容类型(基于原始picid格式)
|
||||
|
||||
Args:
|
||||
target: 目标消息内容(包含[picid:xxx]格式)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
|
||||
"""
|
||||
if not target or not target.strip():
|
||||
return False, False, "", ""
|
||||
|
||||
# 检查是否只包含picid标记
|
||||
picid_pattern = r"\[picid:[^\]]+\]"
|
||||
picid_matches = re.findall(picid_pattern, target)
|
||||
|
||||
# 移除所有picid标记后检查是否还有文字内容
|
||||
text_without_picids = re.sub(picid_pattern, "", target).strip()
|
||||
|
||||
has_only_pics = len(picid_matches) > 0 and not text_without_picids
|
||||
has_text = bool(text_without_picids)
|
||||
|
||||
# 提取图片部分(转换为[图片:描述]格式)
|
||||
pic_part = ""
|
||||
if picid_matches:
|
||||
pic_descriptions = []
|
||||
for picid_match in picid_matches:
|
||||
pic_id = picid_match[7:-1] # 提取picid:xxx中的xxx部分(从第7个字符开始)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
logger.info(f"图片ID: {pic_id}, 描述: {description}")
|
||||
# 如果description已经是[图片]格式,直接使用;否则包装为[图片:描述]格式
|
||||
if description == "[图片]":
|
||||
pic_descriptions.append(description)
|
||||
else:
|
||||
pic_descriptions.append(f"[图片:{description}]")
|
||||
pic_part = "".join(pic_descriptions)
|
||||
|
||||
return has_only_pics, has_text, pic_part, text_without_picids
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
|
|
@ -438,11 +475,10 @@ class DefaultReplyer:
|
|||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
def build_s4u_chat_history_prompts(
|
||||
def build_chat_history_prompts(
|
||||
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的分离对话 prompt
|
||||
|
||||
Args:
|
||||
message_list_before_now: 历史消息列表
|
||||
|
|
@ -498,7 +534,6 @@ class DefaultReplyer:
|
|||
--------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
|
|
@ -516,51 +551,6 @@ class DefaultReplyer:
|
|||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
|
||||
def build_mai_think_context(
|
||||
self,
|
||||
chat_id: str,
|
||||
memory_block: str,
|
||||
relation_info: str,
|
||||
time_block: str,
|
||||
chat_target_1: str,
|
||||
chat_target_2: str,
|
||||
mood_prompt: str,
|
||||
identity_block: str,
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_info: str,
|
||||
) -> Any:
|
||||
"""构建 mai_think 上下文信息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_block: 记忆块内容
|
||||
relation_info: 关系信息
|
||||
time_block: 时间块内容
|
||||
chat_target_1: 聊天目标1
|
||||
chat_target_2: 聊天目标2
|
||||
mood_prompt: 情绪提示
|
||||
identity_block: 身份块内容
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_info: 聊天信息
|
||||
|
||||
Returns:
|
||||
Any: mai_think 实例
|
||||
"""
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = chat_info
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
return mai_think
|
||||
|
||||
async def build_actions_prompt(
|
||||
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
|
||||
) -> str:
|
||||
|
|
@ -615,6 +605,7 @@ class DefaultReplyer:
|
|||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[str, List[int]]:
|
||||
"""
|
||||
构建回复器上下文
|
||||
|
|
@ -649,23 +640,23 @@ class DefaultReplyer:
|
|||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
|
||||
mood_prompt: str = ""
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
timestamp=reply_time_point,
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
timestamp=reply_time_point,
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
|
||||
|
|
@ -686,8 +677,8 @@ class DefaultReplyer:
|
|||
if person.is_known:
|
||||
person_list_short.append(person)
|
||||
|
||||
for person in person_list_short:
|
||||
print(person.person_name)
|
||||
# for person in person_list_short:
|
||||
# print(person.person_name)
|
||||
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
message_list_before_short,
|
||||
|
|
@ -702,16 +693,15 @@ class DefaultReplyer:
|
|||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
),
|
||||
# self._time_and_run_task(
|
||||
# self.build_relation_info(chat_talking_prompt_short, sender, person_list_short), "relation_info"
|
||||
# ),
|
||||
# self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(self.build_memory_block(), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
self._time_and_run_task(self.build_question_block(), "question_block"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
|
|
@ -719,10 +709,13 @@ class DefaultReplyer:
|
|||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
# "memory_block": "回忆",
|
||||
"memory_block": "记忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
"question_block": "问题",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
|
|
@ -747,11 +740,14 @@ class DefaultReplyer:
|
|||
selected_expressions: List[int]
|
||||
# relation_info: str = results_dict["relation_info"]
|
||||
# memory_block: str = results_dict["memory_block"]
|
||||
memory_block: str = results_dict["memory_block"]
|
||||
tool_info: str = results_dict["tool_info"]
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
question_block: str = results_dict["question_block"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
|
|
@ -763,63 +759,49 @@ class DefaultReplyer:
|
|||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
if sender:
|
||||
if is_group_chat:
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意"
|
||||
)
|
||||
else: # private chat
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意"
|
||||
)
|
||||
# 使用预先分析的内容类型结果
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
|
||||
elif has_text:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意"
|
||||
else:
|
||||
# 其他情况(空内容等)
|
||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意"
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(
|
||||
message_list_before_now_long, user_id, sender
|
||||
)
|
||||
|
||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_self_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
# memory_block=memory_block,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
target=target,
|
||||
reason=reply_reason,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
), selected_expressions
|
||||
else:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
# memory_block=memory_block,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
sender_name=sender,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
), selected_expressions
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"replyer_prompt",
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
memory_block=memory_block,
|
||||
knowledge_prompt=prompt_info,
|
||||
mood_state=mood_state_prompt,
|
||||
# memory_block=memory_block,
|
||||
# relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
sender_name=sender,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
question_block=question_block,
|
||||
), selected_expressions
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
|
|
@ -833,14 +815,12 @@ class DefaultReplyer:
|
|||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 添加情绪状态获取
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
|
|
@ -858,7 +838,6 @@ class DefaultReplyer:
|
|||
# 并行执行2个构建任务
|
||||
(expression_habits_block, _), personality_prompt = await asyncio.gather(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||
# self.build_relation_info(chat_talking_prompt_half, sender, []),
|
||||
self.build_personality_prompt(),
|
||||
)
|
||||
|
||||
|
|
@ -871,18 +850,39 @@ class DefaultReplyer:
|
|||
)
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
|
|
@ -918,7 +918,6 @@ class DefaultReplyer:
|
|||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
mood_state=mood_prompt, # 添加情绪状态参数
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
|
|
@ -966,13 +965,13 @@ class DefaultReplyer:
|
|||
if global_config.debug.show_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
|
|
@ -1059,6 +1058,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
|||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import re
|
|||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
|
@ -24,10 +24,12 @@ from src.chat.utils.chat_message_builder import (
|
|||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
|
@ -69,6 +71,7 @@ class PrivateReplyer:
|
|||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[DatabaseMessages] = None,
|
||||
reply_time_point: Optional[float] = time.time(),
|
||||
) -> Tuple[bool, LLMGenerationDataModel]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
|
|
@ -253,8 +256,8 @@ class PrivateReplyer:
|
|||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
|
|
@ -280,38 +283,22 @@ class PrivateReplyer:
|
|||
|
||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||
|
||||
# async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
|
||||
# """构建记忆块
|
||||
async def build_mood_state_prompt(self) -> str:
|
||||
"""构建情绪状态提示"""
|
||||
if not global_config.mood.enable_mood:
|
||||
return ""
|
||||
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
|
||||
return f"你现在的心情是:{mood_state}"
|
||||
|
||||
# Args:
|
||||
# chat_history: 聊天历史记录
|
||||
# target: 目标消息内容
|
||||
|
||||
# Returns:
|
||||
# str: 记忆信息字符串
|
||||
# """
|
||||
|
||||
# if not global_config.memory.enable_memory:
|
||||
# return ""
|
||||
|
||||
# instant_memory = None
|
||||
|
||||
# running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
# target_message=target, chat_history=chat_history
|
||||
# )
|
||||
# if not running_memories:
|
||||
# return ""
|
||||
|
||||
# memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
# for running_memory in running_memories:
|
||||
# keywords, content = running_memory
|
||||
# memory_str += f"- {keywords}:{content}\n"
|
||||
|
||||
# if instant_memory:
|
||||
# memory_str += f"- {instant_memory}\n"
|
||||
|
||||
# return memory_str
|
||||
|
||||
async def build_memory_block(self) -> str:
|
||||
"""构建记忆块
|
||||
"""
|
||||
if global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id):
|
||||
return f"你有以下记忆:\n{global_memory_chest.get_chat_memories_as_string(self.chat_stream.stream_id)}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
|
|
@ -376,6 +363,64 @@ class PrivateReplyer:
|
|||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
def _replace_picids_with_descriptions(self, text: str) -> str:
|
||||
"""将文本中的[picid:xxx]替换为具体的图片描述
|
||||
|
||||
Args:
|
||||
text: 包含picid标记的文本
|
||||
|
||||
Returns:
|
||||
替换后的文本
|
||||
"""
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
pic_id = match.group(1)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
return f"[图片:{description}]"
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, text)
|
||||
|
||||
def _analyze_target_content(self, target: str) -> Tuple[bool, bool, str, str]:
|
||||
"""分析target内容类型(基于原始picid格式)
|
||||
|
||||
Args:
|
||||
target: 目标消息内容(包含[picid:xxx]格式)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, str, str]: (是否只包含图片, 是否包含文字, 图片部分, 文字部分)
|
||||
"""
|
||||
if not target or not target.strip():
|
||||
return False, False, "", ""
|
||||
|
||||
# 检查是否只包含picid标记
|
||||
picid_pattern = r"\[picid:[^\]]+\]"
|
||||
picid_matches = re.findall(picid_pattern, target)
|
||||
|
||||
# 移除所有picid标记后检查是否还有文字内容
|
||||
text_without_picids = re.sub(picid_pattern, "", target).strip()
|
||||
|
||||
has_only_pics = len(picid_matches) > 0 and not text_without_picids
|
||||
has_text = bool(text_without_picids)
|
||||
|
||||
# 提取图片部分(转换为[图片:描述]格式)
|
||||
pic_part = ""
|
||||
if picid_matches:
|
||||
pic_descriptions = []
|
||||
for picid_match in picid_matches:
|
||||
pic_id = picid_match[7:-1] # 提取picid:xxx中的xxx部分(从第7个字符开始)
|
||||
description = translate_pid_to_description(pic_id)
|
||||
logger.debug(f"图片ID: {pic_id}, 描述: {description}")
|
||||
# 如果description已经是[图片]格式,直接使用;否则包装为[图片:描述]格式
|
||||
if description == "[图片]":
|
||||
pic_descriptions.append(description)
|
||||
else:
|
||||
pic_descriptions.append(f"[图片:{description}]")
|
||||
pic_part = "".join(pic_descriptions)
|
||||
|
||||
return has_only_pics, has_text, pic_part, text_without_picids
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
|
|
@ -521,13 +566,15 @@ class PrivateReplyer:
|
|||
sender = person_name
|
||||
target = reply_message.processed_plain_text
|
||||
|
||||
mood_prompt: str = ""
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
|
||||
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
|
|
@ -566,8 +613,8 @@ class PrivateReplyer:
|
|||
if person.is_known:
|
||||
person_list_short.append(person)
|
||||
|
||||
for person in person_list_short:
|
||||
print(person.person_name)
|
||||
# for person in person_list_short:
|
||||
# print(person.person_name)
|
||||
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
message_list_before_short,
|
||||
|
|
@ -585,6 +632,7 @@ class PrivateReplyer:
|
|||
self._time_and_run_task(
|
||||
self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"
|
||||
),
|
||||
self._time_and_run_task(self.build_memory_block(), "memory_block"),
|
||||
# self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
|
|
@ -592,17 +640,19 @@ class PrivateReplyer:
|
|||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
|
||||
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
# "memory_block": "回忆",
|
||||
"memory_block": "回忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识",
|
||||
"actions_info": "动作信息",
|
||||
"personality_prompt": "人格信息",
|
||||
"mood_state_prompt": "情绪状态",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
|
|
@ -626,11 +676,12 @@ class PrivateReplyer:
|
|||
expression_habits_block: str
|
||||
selected_expressions: List[int]
|
||||
relation_info: str = results_dict["relation_info"]
|
||||
# memory_block: str = results_dict["memory_block"]
|
||||
memory_block: str = results_dict["memory_block"]
|
||||
tool_info: str = results_dict["tool_info"]
|
||||
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
actions_info: str = results_dict["actions_info"]
|
||||
personality_prompt: str = results_dict["personality_prompt"]
|
||||
mood_state_prompt: str = results_dict["mood_state_prompt"]
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
|
||||
if extra_info:
|
||||
|
|
@ -642,9 +693,19 @@ class PrivateReplyer:
|
|||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
reply_target_block = (
|
||||
f"现在对方说的:{target}。引起了你的注意"
|
||||
)
|
||||
# 使用预先分析的内容类型结果
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在对方发送的图片:{pic_part}。引起了你的注意"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在对方发送了图片:{pic_part},并说:{text_part}。引起了你的注意"
|
||||
elif has_text:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在对方说的:{text_part}。引起了你的注意"
|
||||
else:
|
||||
# 其他情况(空内容等)
|
||||
reply_target_block = f"现在对方说的:{target}。引起了你的注意"
|
||||
|
||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
|
|
@ -652,12 +713,12 @@ class PrivateReplyer:
|
|||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
# memory_block=memory_block,
|
||||
mood_state=mood_state_prompt,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
mood_state=mood_prompt,
|
||||
dialogue_prompt=dialogue_prompt,
|
||||
time_block=time_block,
|
||||
target=target,
|
||||
|
|
@ -673,12 +734,12 @@ class PrivateReplyer:
|
|||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
# memory_block=memory_block,
|
||||
mood_state=mood_state_prompt,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=personality_prompt,
|
||||
action_descriptions=actions_info,
|
||||
mood_state=mood_prompt,
|
||||
dialogue_prompt=dialogue_prompt,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
|
|
@ -700,14 +761,14 @@ class PrivateReplyer:
|
|||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
target = re.sub(r"\\[picid:[^\\]]+\\]", "[图片]", target)
|
||||
|
||||
# 在picid替换之前分析内容类型(防止prompt注入)
|
||||
has_only_pics, has_text, pic_part, text_part = self._analyze_target_content(target)
|
||||
|
||||
# 将[picid:xxx]替换为具体的图片描述
|
||||
target = self._replace_picids_with_descriptions(target)
|
||||
|
||||
|
||||
# 添加情绪状态获取
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
|
|
@ -738,18 +799,39 @@ class PrivateReplyer:
|
|||
)
|
||||
|
||||
if sender and target:
|
||||
# 使用预先分析的内容类型结果
|
||||
if is_group_chat:
|
||||
if sender:
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = (
|
||||
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
else: # private chat
|
||||
if sender:
|
||||
reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,针对这条消息回复。"
|
||||
if has_only_pics and not has_text:
|
||||
# 只包含图片
|
||||
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif has_text and pic_part:
|
||||
# 既有图片又有文字
|
||||
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
# 只包含文字
|
||||
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
|
||||
else:
|
||||
|
|
@ -785,7 +867,6 @@ class PrivateReplyer:
|
|||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
mood_state=mood_prompt, # 添加情绪状态参数
|
||||
reply_style=global_config.personality.reply_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
|
|
@ -839,7 +920,7 @@ class PrivateReplyer:
|
|||
prompt
|
||||
)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
logger.info(f"使用 {model_name} 生成回复内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
def init_lpmm_prompt():
|
||||
|
|
@ -20,5 +18,3 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||
""",
|
||||
name="lpmm_get_knowledge_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def init_replyer_prompt():
|
|||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}
|
||||
{expression_habits_block}{memory_block}{question_block}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
{time_block}
|
||||
|
|
@ -22,40 +22,19 @@ def init_replyer_prompt():
|
|||
|
||||
{reply_target_block}。
|
||||
{identity}
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出一句回复内容就好。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
现在,你说:""",
|
||||
"replyer_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}
|
||||
|
||||
你正在qq群里聊天,下面是群里正在聊的内容:
|
||||
{time_block}
|
||||
{background_dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。
|
||||
{identity}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"replyer_self_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
|
|
@ -63,7 +42,7 @@ def init_replyer_prompt():
|
|||
|
||||
{reply_target_block}。
|
||||
{identity}
|
||||
你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,
|
||||
你正在和{sender_name}聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,{mood_state}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
|
|
@ -74,19 +53,19 @@ def init_replyer_prompt():
|
|||
|
||||
Prompt(
|
||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||
{expression_habits_block}
|
||||
{expression_habits_block}{memory_block}
|
||||
|
||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||
{time_block}
|
||||
{dialogue_prompt}
|
||||
|
||||
你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason}
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。
|
||||
请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。{mood_state}
|
||||
{identity}
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。
|
||||
{reply_style}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。
|
||||
{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。
|
||||
""",
|
||||
"private_replyer_self_prompt",
|
||||
)
|
||||
|
|
@ -1,7 +1,5 @@
|
|||
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
# from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
|
||||
|
||||
def init_rewrite_prompt():
|
||||
|
|
@ -14,13 +12,11 @@ def init_rewrite_prompt():
|
|||
"""
|
||||
{expression_habits_block}
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你现在的心情是:{mood_state}
|
||||
你正在{chat_target_2},{reply_target_block}
|
||||
你想要对上述的发言进行回复,回复的具体内容(原句)是:{raw_reply}
|
||||
现在请你对这句内容进行改写,请你参考上述内容进行改写,原句是:{raw_reply}:
|
||||
原因是:{reason}
|
||||
现在请你将这条具体内容改写成一条适合在群聊中发送的回复消息。
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
|
|
@ -28,8 +24,8 @@ def init_rewrite_prompt():
|
|||
你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,emoji,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
不要输出多余内容(包括冒号和引号,表情包,emoji,at或 @等 ),只输出一条回复就好。
|
||||
改写后的回复:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import time
|
|||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
|
|
@ -124,6 +124,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
|||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
# print(f"get_raw_msg_by_timestamp_with_chat: {chat_id}, {timestamp_start}, {timestamp_end}, {limit}, {limit_mode}, {filter_bot}, {filter_command}")
|
||||
return find_messages(
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
|
|
@ -215,6 +216,7 @@ def get_actions_by_timestamp_with_chat(
|
|||
chat_id=action.chat_id,
|
||||
chat_info_stream_id=action.chat_info_stream_id,
|
||||
chat_info_platform=action.chat_info_platform,
|
||||
action_reasoning=action.action_reasoning,
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
|
|
@ -417,12 +419,6 @@ def _build_readable_messages_internal(
|
|||
timestamp = message.time
|
||||
content = message.display_message or message.processed_plain_text or ""
|
||||
|
||||
# 向下兼容
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
if show_pic:
|
||||
content = process_pic_ids(content)
|
||||
|
|
@ -564,14 +560,12 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
|||
output_lines = []
|
||||
current_time = time.time()
|
||||
|
||||
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
|
||||
# sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True)
|
||||
|
||||
for action in actions:
|
||||
action_time = action.time or current_time
|
||||
action_name = action.action_name or "未知动作"
|
||||
# action_reason = action.get(action_data")
|
||||
if action_name in ["no_action", "no_action"]:
|
||||
if action_name in ["no_reply", "no_reply"]:
|
||||
continue
|
||||
|
||||
action_prompt_display = action.action_prompt_display or "无具体内容"
|
||||
|
|
@ -593,6 +587,7 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
|
|||
|
||||
line = f"{time_ago_str},你使用了“{action_name}”,具体内容是:“{action_prompt_display}”"
|
||||
output_lines.append(line)
|
||||
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
|
|
@ -628,6 +623,7 @@ def build_readable_messages_with_id(
|
|||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
remove_emoji_stickers: bool = False,
|
||||
) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
|
|
@ -644,6 +640,7 @@ def build_readable_messages_with_id(
|
|||
show_pic=show_pic,
|
||||
read_mark=read_mark,
|
||||
message_id_list=message_id_list,
|
||||
remove_emoji_stickers=remove_emoji_stickers,
|
||||
)
|
||||
|
||||
return formatted_string, message_id_list
|
||||
|
|
@ -658,6 +655,7 @@ def build_readable_messages(
|
|||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
|
||||
remove_emoji_stickers: bool = False,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
|
|
@ -672,13 +670,40 @@ def build_readable_messages(
|
|||
read_mark: 已读标记时间戳
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
remove_emoji_stickers: 是否移除表情包并过滤空消息
|
||||
"""
|
||||
# WIP HERE and BELOW ----------------------------------------------
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
|
||||
# 如果启用移除表情包,先过滤消息
|
||||
if remove_emoji_stickers:
|
||||
filtered_messages = []
|
||||
for msg in messages:
|
||||
# 获取消息内容
|
||||
content = msg.processed_plain_text
|
||||
# 移除表情包
|
||||
emoji_pattern = r"\[表情包:[^\]]+\]"
|
||||
content = re.sub(emoji_pattern, "", content)
|
||||
|
||||
# 如果移除表情包后内容不为空,则保留消息
|
||||
if content.strip():
|
||||
filtered_messages.append(msg)
|
||||
|
||||
messages = filtered_messages
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = []
|
||||
for msg in messages:
|
||||
if remove_emoji_stickers:
|
||||
# 创建 MessageAndActionModel 但移除表情包
|
||||
model = MessageAndActionModel.from_DatabaseMessages(msg)
|
||||
# 移除表情包
|
||||
if model.processed_plain_text:
|
||||
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
|
||||
copy_messages.append(model)
|
||||
else:
|
||||
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
|
|
@ -862,17 +887,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||
user_id = msg.user_info.user_id
|
||||
content = msg.display_message or msg.processed_plain_text or ""
|
||||
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
|
||||
# if not all([platform, user_id, timestamp is not None]):
|
||||
# continue
|
||||
|
||||
anon_name = get_anon_name(platform, user_id)
|
||||
# print(f"anon_name:{anon_name}")
|
||||
|
||||
|
|
@ -909,6 +926,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||
return formatted_string
|
||||
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
|
@ -937,3 +955,45 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||
person_ids_set.add(person_id)
|
||||
|
||||
return list(person_ids_set) # 将集合转换为列表返回
|
||||
|
||||
|
||||
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
|
||||
"""
|
||||
构建简化版消息字符串,只包含processed_plain_text内容,不考虑用户名和时间戳
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
只包含消息内容的字符串
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
output_lines = []
|
||||
|
||||
for msg in messages:
|
||||
# 获取纯文本内容
|
||||
content = msg.processed_plain_text or ""
|
||||
|
||||
# 处理图片ID
|
||||
pic_pattern = r"\[picid:[^\]]+\]"
|
||||
|
||||
def replace_pic_id(match):
|
||||
return "[图片]"
|
||||
|
||||
content = re.sub(pic_pattern, replace_pic_id, content)
|
||||
|
||||
# 处理用户引用格式,移除回复和@标记
|
||||
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
|
||||
content = re.sub(reply_pattern, "回复[某人]", content)
|
||||
|
||||
at_pattern = r"@<[^:<>]+:[^:<>]+>"
|
||||
content = re.sub(at_pattern, "@[某人]", content)
|
||||
|
||||
# 清理并添加到输出
|
||||
content = content.strip()
|
||||
if content:
|
||||
output_lines.append(content)
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ class Prompt(str):
|
|||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
|
||||
"""处理模板中的转义花括号,替换为临时标记""" # type: ignore
|
||||
# 如果传入的是列表,将其转换为字符串
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
|
|
|
|||
|
|
@ -383,10 +383,6 @@ def calculate_typing_time(
|
|||
- 在所有输入结束后,额外加上回车时间0.3秒
|
||||
- 如果is_emoji为True,将使用固定1秒的输入时间
|
||||
"""
|
||||
# # 将0-1的唤醒度映射到-1到1
|
||||
# mood_arousal = mood_manager.current_mood.arousal
|
||||
# # 映射到0.5到2倍的速度系数
|
||||
# typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
|
||||
# chinese_time *= 1 / typing_speed_multiplier
|
||||
# english_time *= 1 / typing_speed_multiplier
|
||||
# 计算中文字符数
|
||||
|
|
|
|||
|
|
@ -623,3 +623,41 @@ def image_path_to_base64(image_path: str) -> str:
|
|||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
|
||||
|
||||
def base64_to_image(image_base64: str, output_path: str) -> bool:
|
||||
"""将base64编码的图片保存为文件
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
bool: 是否成功保存
|
||||
|
||||
Raises:
|
||||
ValueError: 当base64编码无效时
|
||||
IOError: 当保存文件失败时
|
||||
"""
|
||||
try:
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
|
||||
# 解码base64
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存base64图片失败: {e}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -220,6 +220,7 @@ class DatabaseActionRecords(BaseDataModel):
|
|||
chat_id: str,
|
||||
chat_info_stream_id: str,
|
||||
chat_info_platform: str,
|
||||
action_reasoning:str
|
||||
):
|
||||
self.action_id = action_id
|
||||
self.time = time
|
||||
|
|
@ -234,3 +235,4 @@ class DatabaseActionRecords(BaseDataModel):
|
|||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.action_reasoning = action_reasoning
|
||||
|
|
@ -24,3 +24,4 @@ class ActionPlannerInfo(BaseDataModel):
|
|||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
loop_start_time: Optional[float] = None
|
||||
action_reasoning: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -16,4 +16,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
|||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional["ReplySetModel"] = None
|
||||
reply_set: Optional["ReplySetModel"] = None
|
||||
|
|
|
|||
|
|
@ -1,64 +1,9 @@
|
|||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("MONGODB_PORT", "27017"))
|
||||
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
|
||||
username = os.getenv("MONGODB_USERNAME")
|
||||
password = os.getenv("MONGODB_PASSWORD")
|
||||
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
|
||||
|
||||
if uri:
|
||||
# 支持标准mongodb://和mongodb+srv://连接字符串
|
||||
if uri.startswith(("mongodb://", "mongodb+srv://")):
|
||||
return MongoClient(uri)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
|
||||
"For MongoDB Atlas, use 'mongodb+srv://' format. "
|
||||
"See: https://www.mongodb.com/docs/manual/reference/connection-string/"
|
||||
)
|
||||
|
||||
if username and password:
|
||||
# 如果有用户名和密码,使用认证连接
|
||||
return MongoClient(host, port, username=username, password=password, authSource=auth_source)
|
||||
|
||||
# 否则使用无认证连接
|
||||
return MongoClient(host, port)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库连接实例,延迟初始化。"""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
|
||||
return _db
|
||||
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(get_db(), name)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
|
|
|||
|
|
@ -185,6 +185,8 @@ class ActionRecords(BaseModel):
|
|||
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
|
||||
time = DoubleField() # 消息时间戳
|
||||
|
||||
action_reasoning = TextField(null=True)
|
||||
|
||||
action_name = TextField()
|
||||
action_data = TextField()
|
||||
action_done = BooleanField(default=False)
|
||||
|
|
@ -301,46 +303,47 @@ class Expression(BaseModel):
|
|||
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
|
||||
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
|
||||
class GraphNodes(BaseModel):
|
||||
class MemoryChest(BaseModel):
|
||||
"""
|
||||
用于存储记忆图节点的模型
|
||||
用于存储记忆仓库的模型
|
||||
"""
|
||||
|
||||
concept = TextField(unique=True, index=True) # 节点概念
|
||||
memory_items = TextField() # JSON格式存储的记忆列表
|
||||
weight = FloatField(default=0.0) # 节点权重
|
||||
hash = TextField() # 节点哈希值
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
title = TextField() # 标题
|
||||
content = TextField() # 内容
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
locked = BooleanField(default=False) # 是否锁定
|
||||
|
||||
class Meta:
|
||||
table_name = "graph_nodes"
|
||||
table_name = "memory_chest"
|
||||
|
||||
|
||||
class GraphEdges(BaseModel):
|
||||
class MemoryConflict(BaseModel):
|
||||
"""
|
||||
用于存储记忆图边的模型
|
||||
用于存储记忆整合过程中冲突内容的模型
|
||||
"""
|
||||
|
||||
source = TextField(index=True) # 源节点
|
||||
target = TextField(index=True) # 目标节点
|
||||
strength = IntegerField() # 连接强度
|
||||
hash = TextField() # 边哈希值
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
conflict_content = TextField() # 冲突内容
|
||||
answer = TextField(null=True) # 回答内容
|
||||
create_time = FloatField() # 创建时间
|
||||
update_time = FloatField() # 更新时间
|
||||
context = TextField(null=True) # 上下文
|
||||
chat_id = TextField(null=True) # 聊天ID
|
||||
raise_time = FloatField(null=True) # 触发次数
|
||||
|
||||
class Meta:
|
||||
table_name = "graph_edges"
|
||||
table_name = "memory_conflicts"
|
||||
|
||||
|
||||
|
||||
def create_tables():
|
||||
|
|
@ -359,9 +362,9 @@ def create_tables():
|
|||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes, # 添加图节点表
|
||||
GraphEdges, # 添加图边表
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict, # 添加记忆冲突表
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -386,9 +389,9 @@ def initialize_database(sync_constraints=False):
|
|||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
|
|
@ -483,9 +486,9 @@ def sync_field_constraints():
|
|||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
try:
|
||||
|
|
@ -667,9 +670,9 @@ def check_field_constraints():
|
|||
OnlineTime,
|
||||
PersonInfo,
|
||||
Expression,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords,
|
||||
MemoryChest,
|
||||
MemoryConflict,
|
||||
]
|
||||
|
||||
inconsistencies = {}
|
||||
|
|
@ -725,11 +728,14 @@ def check_field_constraints():
|
|||
logger.exception(f"检查字段约束时出错: {e}")
|
||||
|
||||
return inconsistencies
|
||||
|
||||
|
||||
def fix_image_id():
|
||||
"""
|
||||
修复表情包的 image_id 字段
|
||||
"""
|
||||
import uuid
|
||||
|
||||
try:
|
||||
with db:
|
||||
for img in Images.select():
|
||||
|
|
@ -740,6 +746,7 @@ def fix_image_id():
|
|||
except Exception as e:
|
||||
logger.exception(f"修复 image_id 时出错: {e}")
|
||||
|
||||
|
||||
# 模块加载时调用初始化函数
|
||||
initialize_database(sync_constraints=True)
|
||||
fix_image_id()
|
||||
fix_image_id()
|
||||
|
|
|
|||
|
|
@ -363,8 +363,8 @@ MODULE_COLORS = {
|
|||
"planner": "\033[36m",
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
||||
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"hfc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"bc": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"sub_heartflow": "\033[38;5;207m", # 粉紫色
|
||||
"subheartflow_manager": "\033[38;5;201m", # 深粉色
|
||||
"background_tasks": "\033[38;5;240m", # 灰色
|
||||
|
|
@ -372,8 +372,6 @@ MODULE_COLORS = {
|
|||
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||
"expressor": "\033[38;5;166m", # 橙色
|
||||
# 专注聊天模块
|
||||
"memory_activator": "\033[38;5;117m", # 天蓝色
|
||||
# 插件系统
|
||||
"plugins": "\033[31m", # 红色
|
||||
"plugin_api": "\033[33m", # 黄色
|
||||
|
|
@ -408,7 +406,7 @@ MODULE_COLORS = {
|
|||
"tts_action": "\033[38;5;58m", # 深黄色
|
||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||
# Action组件
|
||||
"no_action_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
|
|
@ -421,9 +419,7 @@ MODULE_COLORS = {
|
|||
"model_utils": "\033[38;5;164m", # 紫红色
|
||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||
# s4u
|
||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||
"S4U_chat": "\033[92m", # 深灰色
|
||||
"conflict_tracker": "\033[38;5;82m", # 柔和的粉色,不显眼但保持粉色系
|
||||
}
|
||||
|
||||
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
|
||||
|
|
|
|||
|
|
@ -81,7 +81,8 @@ def find_messages(
|
|||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
|
||||
query = query.where(~Messages.is_command)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from src.config.official_configs import (
|
|||
ExpressionConfig,
|
||||
ChatConfig,
|
||||
EmojiConfig,
|
||||
MoodConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
ResponsePostProcessConfig,
|
||||
|
|
@ -31,6 +30,8 @@ from src.config.official_configs import (
|
|||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
MoodConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
)
|
||||
|
||||
|
|
@ -54,7 +55,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.10.3-snapshot.4"
|
||||
MMC_VERSION = "0.11.0-snapshot.3"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
|
|
@ -173,13 +174,8 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
|
|||
_update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
# 统一使用 tomlkit.item 来保持原生类型与转义,不对列表做字符串化处理
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
|
@ -345,7 +341,6 @@ class Config(ConfigBase):
|
|||
message_receive: MessageReceiveConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
mood: MoodConfig
|
||||
keyword_reaction: KeywordReactionConfig
|
||||
chinese_typo: ChineseTypoConfig
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
|
|
@ -355,7 +350,9 @@ class Config(ConfigBase):
|
|||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
memory: MemoryConfig
|
||||
debug: DebugConfig
|
||||
mood: MoodConfig
|
||||
voice: VoiceConfig
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import re
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
import time
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
|
|
@ -38,21 +39,18 @@ class PersonalityConfig(ConfigBase):
|
|||
personality: str
|
||||
"""人格"""
|
||||
|
||||
emotion_style: str
|
||||
"""情感特征"""
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
|
||||
interest: str = ""
|
||||
"""兴趣"""
|
||||
|
||||
|
||||
plan_style: str = ""
|
||||
"""说话规则,行为风格"""
|
||||
|
||||
|
||||
visual_style: str = ""
|
||||
"""图片提示词"""
|
||||
|
||||
|
||||
private_plan_style: str = ""
|
||||
"""私聊说话规则,行为风格"""
|
||||
|
||||
|
|
@ -81,48 +79,212 @@ class ChatConfig(ConfigBase):
|
|||
mentioned_bot_reply: bool = True
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
auto_chat_value: float = 1
|
||||
"""自动聊天,越小,麦麦主动聊天的概率越低"""
|
||||
|
||||
at_bot_inevitable_reply: float = 1
|
||||
"""@bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
talk_frequency: float = 0.5
|
||||
"""回复频率阈值"""
|
||||
|
||||
planner_smooth: float = 3
|
||||
"""规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐2-5,0为关闭,必须大于等于0"""
|
||||
|
||||
talk_value: float = 1
|
||||
"""思考频率"""
|
||||
|
||||
# 合并后的时段频率配置
|
||||
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
|
||||
focus_value: float = 0.5
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
|
||||
talk_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
统一的活跃度和专注度配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
|
||||
全局配置示例:
|
||||
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||
|
||||
特定聊天流配置示例:
|
||||
思考频率规则列表,支持按聊天流/按日内时段配置。
|
||||
规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
|
||||
["", "09:00-22:59", 1.0], # 全局规则:白天正常
|
||||
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
说明:
|
||||
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||
- 后续元素是"时间,频率"格式,表示从该时间开始使用该频率,直到下一个时间点
|
||||
- 优先级:特定聊天流配置 > 全局配置 > 默认值
|
||||
|
||||
注意:
|
||||
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
|
||||
- focus_value_adjust 控制专注思考能力,数值越低越容易专注,消耗token也越多
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
auto_chat_value_rules: list[dict] = field(default_factory=lambda: [])
|
||||
"""
|
||||
自动聊天频率规则列表,支持按聊天流/按日内时段配置。
|
||||
规则格式:{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "00:00-08:59", 0.2], # 全局规则:凌晨到早上更安静
|
||||
["", "09:00-22:59", 1.0], # 全局规则:白天正常
|
||||
["qq:1919810:group", "20:00-23:59", 0.6], # 指定群在晚高峰降低发言
|
||||
["qq:114514:private", "00:00-23:59", 0.3],# 指定私聊全时段较安静
|
||||
]
|
||||
|
||||
匹配优先级: 先匹配指定 chat 流规则,再匹配全局规则(\"\").
|
||||
时间区间支持跨夜,例如 "23:00-02:00"。
|
||||
"""
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
is_group = stream_type == "group"
|
||||
|
||||
import hashlib
|
||||
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def _now_minutes(self) -> int:
|
||||
"""返回本地时间的分钟数(0-1439)。"""
|
||||
lt = time.localtime()
|
||||
return lt.tm_hour * 60 + lt.tm_min
|
||||
|
||||
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
|
||||
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||
try:
|
||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||
sh, sm = [int(x) for x in start_str.split(":")]
|
||||
eh, em = [int(x) for x in end_str.split(":")]
|
||||
return sh * 60 + sm, eh * 60 + em
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
|
||||
"""
|
||||
判断 now_min 是否在 [start_min, end_min] 区间内。
|
||||
支持跨夜:如果 start > end,则表示跨越午夜。
|
||||
"""
|
||||
if start_min <= end_min:
|
||||
return start_min <= now_min <= end_min
|
||||
# 跨夜:例如 23:00-02:00
|
||||
return now_min >= start_min or now_min <= end_min
|
||||
|
||||
def get_talk_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 talk_value,未匹配则回退到基础值。"""
|
||||
if not self.talk_value_rules:
|
||||
return self.talk_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
# 1) 先尝试匹配指定 chat 的规则
|
||||
if chat_id:
|
||||
for rule in self.talk_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", "")
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
# 跳过全局
|
||||
if target == "":
|
||||
continue
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
|
||||
if config_chat_id is None or config_chat_id != chat_id:
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) 再匹配全局规则("")
|
||||
for rule in self.talk_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", None)
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if target != "" or not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.talk_value
|
||||
|
||||
def get_auto_chat_value(self, chat_id: Optional[str]) -> float:
|
||||
"""根据规则返回当前 chat 的动态 auto_chat_value,未匹配则回退到基础值。"""
|
||||
if not self.auto_chat_value_rules:
|
||||
return self.auto_chat_value
|
||||
|
||||
now_min = self._now_minutes()
|
||||
|
||||
# 1) 先尝试匹配指定 chat 的规则
|
||||
if chat_id:
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", "")
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if not isinstance(time_range, str):
|
||||
continue
|
||||
# 跳过全局
|
||||
if target == "":
|
||||
continue
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(str(target))
|
||||
if config_chat_id is None or config_chat_id != chat_id:
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) 再匹配全局规则("")
|
||||
for rule in self.auto_chat_value_rules:
|
||||
if not isinstance(rule, dict):
|
||||
continue
|
||||
target = rule.get("target", None)
|
||||
time_range = rule.get("time", "")
|
||||
value = rule.get("value", None)
|
||||
if target != "" or not isinstance(time_range, str):
|
||||
continue
|
||||
parsed = self._parse_range(time_range)
|
||||
if not parsed:
|
||||
continue
|
||||
start_min, end_min = parsed
|
||||
if self._in_range(now_min, start_min, end_min):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 3) 未命中规则返回基础值
|
||||
return self.auto_chat_value
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
|
|
@ -134,11 +296,23 @@ class MessageReceiveConfig(ConfigBase):
|
|||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
max_memory_number: int = 100
|
||||
"""记忆最大数量"""
|
||||
|
||||
memory_build_frequency: int = 1
|
||||
"""记忆构建频率"""
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: str = "classic"
|
||||
"""表达方式模式,可选:classic经典模式,exp_model 表达模型模式"""
|
||||
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
|
|
@ -299,6 +473,19 @@ class ToolConfig(ConfigBase):
|
|||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = True
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
|
||||
"""情感特征,影响情绪的变化情况"""
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
|
@ -333,17 +520,6 @@ class EmojiConfig(ConfigBase):
|
|||
"""表情包过滤要求"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoodConfig(ConfigBase):
|
||||
"""情绪配置类"""
|
||||
|
||||
enable_mood: bool = False
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_threshold: float = 1.0
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordRuleConfig(ConfigBase):
|
||||
"""关键词规则配置类"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,580 @@
|
|||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
import difflib
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
build_anonymous_messages,
|
||||
build_bare_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
|
||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要涉及具体的人名,但是可以涉及具体名词
|
||||
3. 思考有没有特殊的梗,一并总结成语言风格
|
||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
|
||||
当"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
现在请你概括
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
match_expression_context_prompt = """
|
||||
**聊天内容**
|
||||
{chat_str}
|
||||
|
||||
**从聊天内容总结的表达方式pairs**
|
||||
{expression_pairs}
|
||||
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。
|
||||
如果找不到原句,就不输出该句的匹配结果。
|
||||
以json格式输出:
|
||||
格式如下:
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
...
|
||||
|
||||
现在请你输出匹配结果:
|
||||
"""
|
||||
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.embedding_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习参数
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
self.chat_id
|
||||
)
|
||||
self.min_messages_for_learning = 30 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 300 / self.learning_intensity
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发学习
|
||||
"""
|
||||
# 检查是否允许学习
|
||||
if not self.enable_learning:
|
||||
return False
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = time.time() - self.last_learning_time
|
||||
if time_diff < self.min_learning_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self):
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
|
||||
# 更新学习时间
|
||||
self.last_learning_time = time.time()
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
learnt_expressions = await self.learn_expression(num)
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
continue
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
has_new_expressions = True
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
learner.add_style(style, situation)
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 为expression_pairs逐个条目赋予编号,并构建成字符串
|
||||
numbered_pairs = []
|
||||
for i, (situation, style) in enumerate(expression_pairs, 1):
|
||||
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
|
||||
|
||||
expression_pairs_str = "\n".join(numbered_pairs)
|
||||
|
||||
prompt = "match_expression_context_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
expression_pairs=expression_pairs_str,
|
||||
chat_str=random_msg_match_str,
|
||||
)
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
# print(f"match_expression_context_prompt: {prompt}")
|
||||
# print(f"{response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
try:
|
||||
response = response.strip()
|
||||
# 检查是否已经是标准JSON数组格式
|
||||
if response.startswith("[") and response.endswith("]"):
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 尝试直接解析多个JSON对象
|
||||
try:
|
||||
# 如果是多个JSON对象用逗号分隔,包装成数组
|
||||
if response.startswith("{") and not response.startswith("["):
|
||||
response = "[" + response + "]"
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 使用repair_json处理响应
|
||||
repaired_content = repair_json(response)
|
||||
|
||||
# 确保repaired_content是列表格式
|
||||
if isinstance(repaired_content, str):
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [parsed_data]
|
||||
elif isinstance(parsed_data, list):
|
||||
match_responses = parsed_data
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
match_responses = []
|
||||
elif isinstance(repaired_content, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [repaired_content]
|
||||
elif isinstance(repaired_content, list):
|
||||
match_responses = repaired_content
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
# 如果还是失败,尝试repair_json
|
||||
repaired_content = repair_json(response)
|
||||
if isinstance(repaired_content, str):
|
||||
parsed_data = json.loads(repaired_content)
|
||||
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
|
||||
else:
|
||||
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
|
||||
return []
|
||||
|
||||
# 确保 match_responses 是一个列表
|
||||
if not isinstance(match_responses, list):
|
||||
if isinstance(match_responses, dict):
|
||||
match_responses = [match_responses]
|
||||
else:
|
||||
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
|
||||
return []
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"match_responses 内容: {match_responses}")
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
# 检查 match_response 的类型
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
|
||||
situation, style = expression_pairs[pair_index]
|
||||
context = match_response.get("context", "")
|
||||
matched_expressions.append((situation, style, context))
|
||||
used_pair_indices.add(pair_index) # 标记该索引已使用
|
||||
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
|
||||
elif pair_index in used_pair_indices:
|
||||
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
|
||||
except (ValueError, KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
|
||||
continue
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
# 为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
# 在 bare_lines 中找到第一处相似度达到85%的行
|
||||
pos = None
|
||||
for i, (_, c) in enumerate(bare_lines):
|
||||
similarity = calculate_similarity(c, context)
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
continue
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
continue
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = self._filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
# 上一句为空,跳过该表达
|
||||
continue
|
||||
filtered_with_up.append((situation, style, context, up_content))
|
||||
|
||||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
return filtered_with_up
|
||||
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line.find('当"')
|
||||
if idx_when == -1:
|
||||
continue
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
continue
|
||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
||||
# 查找"使用"
|
||||
idx_use = line.find('使用"', idx_quote2)
|
||||
if idx_use == -1:
|
||||
continue
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((situation, style))
|
||||
return expressions
|
||||
|
||||
def _filter_message_content(self, content: str) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r'@<[^>]*>', '', content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r'\[picid:[^\]]*\]', '', content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = self._filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
|
|
@ -0,0 +1,462 @@
|
|||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""随机抽样"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用随机抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
# 随机选择一个元素
|
||||
chosen_idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 style_learner 模型预测最合适的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
target_message: 目标消息内容
|
||||
total_num: 需要预测的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 预测的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
|
||||
predicted_expressions = []
|
||||
|
||||
# 为每个相关的chat_id进行预测
|
||||
for related_chat_id in related_chat_ids:
|
||||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
# 获取预测风格的完整信息
|
||||
learner = style_learner_manager.get_learner(related_chat_id)
|
||||
style_id, situation = learner.get_style_info(best_style)
|
||||
|
||||
if style_id and situation:
|
||||
# 从数据库查找对应的表达记录
|
||||
expr_query = Expression.select().where(
|
||||
(Expression.chat_id == related_chat_id) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == best_style)
|
||||
)
|
||||
|
||||
if expr_query.exists():
|
||||
expr = expr_query.get()
|
||||
predicted_expressions.append({
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
# 按预测分数排序,取前 total_num 个
|
||||
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
|
||||
selected_expressions = predicted_expressions[:total_num]
|
||||
|
||||
logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型预测表达方式失败: {e}")
|
||||
# 如果预测失败,回退到随机选择
|
||||
return self._random_expressions(chat_id, total_num)
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
if style_exprs:
|
||||
selected_style = weighted_sample(style_exprs, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
根据配置模式选择适合的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 获取配置模式
|
||||
expression_mode = global_config.expression.mode
|
||||
|
||||
if expression_mode == "exp_model":
|
||||
# exp_model模式:直接使用模型预测,不经过LLM
|
||||
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_model_only(chat_id, target_message, max_num)
|
||||
elif expression_mode == "classic":
|
||||
# classic模式:随机选择+LLM选择
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
else:
|
||||
logger.warning(f"未知的表达模式: {expression_mode},回退到classic模式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
target_message: str,
|
||||
max_num: int = 10,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
exp_model模式:直接使用模型预测,不经过LLM
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
target_message: 目标消息内容
|
||||
max_num: 最大选择数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 使用模型预测最合适的表达方式
|
||||
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
|
||||
selected_ids = [expr["id"] for expr in selected_expressions]
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
|
||||
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"exp_model模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,更新last_active_time
|
||||
if valid_expressions:
|
||||
self.update_expressions_last_active_time(valid_expressions)
|
||||
|
||||
logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
"表达方式激活: 更新last_active_time in db"
|
||||
)
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
from typing import Dict, Optional, Tuple, List
|
||||
from collections import Counter, defaultdict
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
|
||||
class ExpressorModel:
|
||||
"""
|
||||
直接使用朴素贝叶斯精排(可在线学习)
|
||||
支持存储situation字段,不参与计算,仅与style对应
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha: float = 0.5,
|
||||
beta: float = 0.5,
|
||||
gamma: float = 1.0,
|
||||
vocab_size: int = 200000,
|
||||
use_jieba: bool = True):
|
||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: str = None):
|
||||
"""添加候选文本和对应的situation"""
|
||||
self._candidates[cid] = text
|
||||
if situation is not None:
|
||||
self._situations[cid] = situation
|
||||
|
||||
# 确保在nb模型中初始化该候选的计数
|
||||
if cid not in self.nb.cls_counts:
|
||||
self.nb.cls_counts[cid] = 0.0
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
|
||||
"""批量添加候选文本和对应的situations"""
|
||||
for i, (cid, text) in enumerate(items):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
self.add_candidate(cid, text, situation)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""直接对所有候选进行朴素贝叶斯评分"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return None, {}
|
||||
|
||||
if not self._candidates:
|
||||
return None, {}
|
||||
|
||||
# 对所有候选进行评分
|
||||
tf = Counter(toks)
|
||||
all_cids = list(self._candidates.keys())
|
||||
scores = self.nb.score_batch(tf, all_cids)
|
||||
|
||||
# 取最高分
|
||||
if not scores:
|
||||
return None, {}
|
||||
|
||||
# 根据k参数限制返回的候选数量
|
||||
if k is not None and k > 0:
|
||||
# 按分数降序排序,取前k个
|
||||
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
limited_scores = dict(sorted_scores[:k])
|
||||
best = sorted_scores[0][0] if sorted_scores else None
|
||||
return best, limited_scores
|
||||
else:
|
||||
# 如果没有指定k,返回所有分数
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""更新正反馈学习"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: float):
|
||||
self.nb.decay(factor=factor)
|
||||
|
||||
def get_situation(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的situation"""
|
||||
return self._situations.get(cid)
|
||||
|
||||
def get_style(self, cid: str) -> Optional[str]:
|
||||
"""获取候选对应的style"""
|
||||
return self._candidates.get(cid)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""获取候选的style和situation信息"""
|
||||
return self._candidates.get(cid), self._situations.get(cid)
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""获取所有候选的style和situation信息"""
|
||||
return {cid: (style, self._situations.get(cid))
|
||||
for cid, style in self._candidates.items()}
|
||||
|
||||
def save(self, path: str):
|
||||
"""保存模型"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump({
|
||||
"candidates": self._candidates,
|
||||
"situations": self._situations,
|
||||
"nb": {
|
||||
"cls_counts": dict(self.nb.cls_counts),
|
||||
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
|
||||
"alpha": self.nb.alpha,
|
||||
"beta": self.nb.beta,
|
||||
"gamma": self.nb.gamma,
|
||||
"V": self.nb.V,
|
||||
}
|
||||
}, f)
|
||||
|
||||
def load(self, path: str):
|
||||
"""加载模型"""
|
||||
with open(path, "rb") as f:
|
||||
obj = pickle.load(f)
|
||||
# 还原候选文本
|
||||
self._candidates = obj["candidates"]
|
||||
# 还原situations(兼容旧版本)
|
||||
self._situations = obj.get("situations", {})
|
||||
# 还原朴素贝叶斯模型
|
||||
self.nb.cls_counts = obj["nb"]["cls_counts"]
|
||||
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
|
||||
self.nb.alpha = obj["nb"]["alpha"]
|
||||
self.nb.beta = obj["nb"]["beta"]
|
||||
self.nb.gamma = obj["nb"]["gamma"]
|
||||
self.nb.V = obj["nb"]["V"]
|
||||
self.nb._logZ.clear()
|
||||
|
||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
|
||||
from collections import defaultdict
|
||||
outer = defaultdict(lambda: defaultdict(float))
|
||||
for k, inner in d.items():
|
||||
outer[k].update(inner)
|
||||
return outer
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
class OnlineNaiveBayes:
|
||||
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
self.V = vocab_size
|
||||
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def _invalidate(self, cid: str):
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def _logZ_c(self, cid: str) -> float:
|
||||
if cid not in self._logZ:
|
||||
Z = self.cls_counts[cid] + self.V * self.alpha
|
||||
self._logZ[cid] = math.log(max(Z, 1e-12))
|
||||
return self._logZ[cid]
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
total_cls = sum(self.cls_counts.values())
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
for cid in cids:
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
s = prior
|
||||
logZ = self._logZ_c(cid)
|
||||
tc = self.token_counts[cid]
|
||||
for term, qtf in tf.items():
|
||||
num = tc.get(term, 0.0) + self.alpha
|
||||
s += qtf * (math.log(num) - logZ)
|
||||
out[cid] = s
|
||||
return out
|
||||
|
||||
def update_positive(self, tf: Counter, cid: str):
|
||||
inc = 0.0
|
||||
tc = self.token_counts[cid]
|
||||
for term, c in tf.items():
|
||||
tc[term] += float(c)
|
||||
inc += float(c)
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: float = None):
|
||||
g = self.gamma if factor is None else factor
|
||||
if g >= 1.0:
|
||||
return
|
||||
for cid in list(self.cls_counts.keys()):
|
||||
self.cls_counts[cid] *= g
|
||||
for term in list(self.token_counts[cid].keys()):
|
||||
self.token_counts[cid][term] *= g
|
||||
self._invalidate(cid)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import re
|
||||
from typing import List, Optional, Set
|
||||
|
||||
try:
|
||||
import jieba
|
||||
_HAS_JIEBA = True
|
||||
except Exception:
|
||||
_HAS_JIEBA = False
|
||||
|
||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
|
||||
def simple_en_tokenize(text: str) -> List[str]:
|
||||
return _WORD_RE.findall(text.lower())
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
|
||||
self.stopwords = stopwords or set()
|
||||
self.use_jieba = use_jieba and _HAS_JIEBA
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if self.use_jieba:
|
||||
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
|
||||
else:
|
||||
toks = simple_en_tokenize(text)
|
||||
return [t for t in toks if t not in self.stopwords]
|
||||
|
|
@ -0,0 +1,628 @@
|
|||
"""
|
||||
多聊天室表达风格学习系统
|
||||
支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import traceback
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .expressor_model.model import ExpressorModel
|
||||
|
||||
logger = get_logger("style_learner")
|
||||
|
||||
|
||||
class StyleLearner:
|
||||
"""
|
||||
单个聊天室的表达风格学习器
|
||||
学习从up_content到style的映射关系
|
||||
支持动态管理风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
self.chat_id = chat_id
|
||||
self.model_config = model_config or {
|
||||
"alpha": 0.5,
|
||||
"beta": 0.5,
|
||||
"gamma": 0.99, # 衰减因子,支持遗忘
|
||||
"vocab_size": 200000,
|
||||
"use_jieba": True
|
||||
}
|
||||
|
||||
# 初始化表达模型
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0 # 下一个可用的style_id
|
||||
|
||||
# 学习统计
|
||||
self.learning_stats = {
|
||||
"total_samples": 0,
|
||||
"style_counts": defaultdict(int),
|
||||
"last_update": None,
|
||||
"style_usage_frequency": defaultdict(int) # 风格使用频率
|
||||
}
|
||||
|
||||
def add_style(self, style: str, situation: str = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
situation: 对应的situation文本(可选)
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查是否已存在
|
||||
if style in self.style_to_id:
|
||||
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
|
||||
return True
|
||||
|
||||
# 检查是否超过最大限制
|
||||
if len(self.style_to_id) >= self.max_styles:
|
||||
logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
|
||||
return False
|
||||
|
||||
# 生成新的style_id
|
||||
style_id = f"style_{self.next_style_id}"
|
||||
self.next_style_id += 1
|
||||
|
||||
# 添加到映射
|
||||
self.style_to_id[style] = style_id
|
||||
self.id_to_style[style_id] = style
|
||||
if situation:
|
||||
self.id_to_situation[style_id] = situation
|
||||
|
||||
# 添加到expressor模型
|
||||
self.expressor.add_candidate(style_id, style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
|
||||
(f", situation: '{situation}'" if situation else ""))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_style(self, style: str) -> bool:
|
||||
"""
|
||||
删除一个风格
|
||||
|
||||
Args:
|
||||
style: 要删除的风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
try:
|
||||
if style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 从映射中删除
|
||||
del self.style_to_id[style]
|
||||
del self.id_to_style[style_id]
|
||||
if style_id in self.id_to_situation:
|
||||
del self.id_to_situation[style_id]
|
||||
|
||||
# 从expressor模型中删除(通过重新构建)
|
||||
self._rebuild_expressor()
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
|
||||
return False
|
||||
|
||||
def update_style(self, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
更新一个风格
|
||||
|
||||
Args:
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
try:
|
||||
if old_style not in self.style_to_id:
|
||||
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
|
||||
return False
|
||||
|
||||
if new_style in self.style_to_id and new_style != old_style:
|
||||
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
|
||||
return False
|
||||
|
||||
style_id = self.style_to_id[old_style]
|
||||
|
||||
# 更新映射
|
||||
del self.style_to_id[old_style]
|
||||
self.style_to_id[new_style] = style_id
|
||||
self.id_to_style[style_id] = new_style
|
||||
|
||||
# 更新expressor模型(保留原有的situation)
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, new_style, situation)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
|
||||
return False
|
||||
|
||||
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
|
||||
"""
|
||||
批量添加风格
|
||||
|
||||
Args:
|
||||
styles: 风格文本列表
|
||||
situations: 对应的situation文本列表(可选)
|
||||
|
||||
Returns:
|
||||
int: 成功添加的数量
|
||||
"""
|
||||
success_count = 0
|
||||
for i, style in enumerate(styles):
|
||||
situation = situations[i] if situations and i < len(situations) else None
|
||||
if self.add_style(style, situation):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
|
||||
return success_count
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
"""获取所有已注册的风格"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def get_style_count(self) -> int:
|
||||
"""获取当前风格数量"""
|
||||
return len(self.style_to_id)
|
||||
|
||||
def get_situation(self, style: str) -> Optional[str]:
|
||||
"""
|
||||
获取风格对应的situation
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Optional[str]: 对应的situation,如果不存在则返回None
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
return self.id_to_situation.get(style_id)
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
获取风格的完整信息
|
||||
|
||||
Args:
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: (style_id, situation)
|
||||
"""
|
||||
if style not in self.style_to_id:
|
||||
return None, None
|
||||
|
||||
style_id = self.style_to_id[style]
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
|
||||
"""
|
||||
获取所有风格的完整信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
|
||||
"""
|
||||
result = {}
|
||||
for style, style_id in self.style_to_id.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
result[style] = (style_id, situation)
|
||||
return result
|
||||
|
||||
def _rebuild_expressor(self):
|
||||
"""重新构建expressor模型(删除风格后使用)"""
|
||||
try:
|
||||
# 重新创建expressor
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
|
||||
# 重新添加所有风格和situation
|
||||
for style_id, style_text in self.id_to_style.items():
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
self.expressor.add_candidate(style_id, style_text, situation)
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
|
||||
|
||||
def learn_mapping(self, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个up_content到style的映射
|
||||
如果style不存在,会自动添加
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
style: 对应的style文本
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
try:
|
||||
# 如果style不存在,先添加它
|
||||
if style not in self.style_to_id:
|
||||
if not self.add_style(style):
|
||||
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
|
||||
return False
|
||||
|
||||
# 获取style_id
|
||||
style_id = self.style_to_id[style]
|
||||
|
||||
# 使用正反馈学习
|
||||
self.expressor.update_positive(up_content, style_id)
|
||||
|
||||
# 更新统计
|
||||
self.learning_stats["total_samples"] += 1
|
||||
self.learning_stats["style_counts"][style_id] += 1
|
||||
self.learning_stats["style_usage_frequency"][style] += 1
|
||||
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
Args:
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style文本, 所有候选的分数]
|
||||
"""
|
||||
try:
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
return None, {}
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
# 转换所有分数
|
||||
style_scores = {}
|
||||
for sid, score in scores.items():
|
||||
style_text = self.id_to_style.get(sid)
|
||||
if style_text:
|
||||
style_scores[style_text] = score
|
||||
|
||||
return best_style, style_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
|
||||
traceback.print_exc()
|
||||
return None, {}
|
||||
|
||||
def decay_learning(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对学习到的知识进行衰减(遗忘)
|
||||
|
||||
Args:
|
||||
factor: 衰减因子,None则使用配置中的gamma
|
||||
"""
|
||||
self.expressor.decay(factor)
|
||||
logger.debug(f"[{self.chat_id}] 执行知识衰减")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""获取学习统计信息"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"total_samples": self.learning_stats["total_samples"],
|
||||
"style_count": len(self.style_to_id),
|
||||
"max_styles": self.max_styles,
|
||||
"style_counts": dict(self.learning_stats["style_counts"]),
|
||||
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
|
||||
"last_update": self.learning_stats["last_update"],
|
||||
"all_styles": list(self.style_to_id.keys())
|
||||
}
|
||||
|
||||
def save(self, base_path: str) -> bool:
|
||||
"""
|
||||
保存模型到文件
|
||||
|
||||
Args:
|
||||
base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl
|
||||
"""
|
||||
try:
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
|
||||
# 保存模型和统计信息
|
||||
save_data = {
|
||||
"model_config": self.model_config,
|
||||
"style_to_id": self.style_to_id,
|
||||
"id_to_style": self.id_to_style,
|
||||
"id_to_situation": self.id_to_situation,
|
||||
"next_style_id": self.next_style_id,
|
||||
"max_styles": self.max_styles,
|
||||
"learning_stats": self.learning_stats
|
||||
}
|
||||
|
||||
# 先保存expressor模型
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
self.expressor.save(expressor_path)
|
||||
|
||||
# 保存其他数据
|
||||
with open(file_path, "wb") as f:
|
||||
pickle.dump(save_data, f)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
|
||||
return False
|
||||
|
||||
def load(self, base_path: str) -> bool:
|
||||
"""
|
||||
从文件加载模型
|
||||
|
||||
Args:
|
||||
base_path: 基础路径
|
||||
"""
|
||||
try:
|
||||
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
|
||||
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
|
||||
|
||||
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
|
||||
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
|
||||
return False
|
||||
|
||||
# 加载其他数据
|
||||
with open(file_path, "rb") as f:
|
||||
save_data = pickle.load(f)
|
||||
|
||||
# 恢复配置和状态
|
||||
self.model_config = save_data["model_config"]
|
||||
self.style_to_id = save_data["style_to_id"]
|
||||
self.id_to_style = save_data["id_to_style"]
|
||||
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
|
||||
self.next_style_id = save_data["next_style_id"]
|
||||
self.max_styles = save_data.get("max_styles", 2000)
|
||||
self.learning_stats = save_data["learning_stats"]
|
||||
|
||||
# 重新创建expressor并加载
|
||||
self.expressor = ExpressorModel(**self.model_config)
|
||||
self.expressor.load(expressor_path)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class StyleLearnerManager:
|
||||
"""
|
||||
多聊天室表达风格学习管理器
|
||||
为每个chat_id维护独立的StyleLearner实例
|
||||
每个chat_id可以动态管理自己的风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, model_save_path: str = "data/style_models"):
|
||||
self.model_save_path = model_save_path
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
|
||||
# 自动保存配置
|
||||
self.auto_save_interval = 300 # 5分钟
|
||||
self._auto_save_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.info("StyleLearnerManager 已初始化")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
model_config: 模型配置,None则使用默认配置
|
||||
|
||||
Returns:
|
||||
StyleLearner实例
|
||||
"""
|
||||
if chat_id not in self.learners:
|
||||
# 创建新的学习器
|
||||
learner = StyleLearner(chat_id, model_config)
|
||||
|
||||
# 尝试加载已保存的模型
|
||||
learner.load(self.model_save_path)
|
||||
|
||||
self.learners[chat_id] = learner
|
||||
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
|
||||
|
||||
return self.learners[chat_id]
|
||||
|
||||
def add_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id添加风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.add_style(style)
|
||||
|
||||
def remove_style(self, chat_id: str, style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id删除风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
style: 风格文本
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.remove_style(style)
|
||||
|
||||
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
|
||||
"""
|
||||
为指定chat_id更新风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
old_style: 原风格文本
|
||||
new_style: 新风格文本
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.update_style(old_style, new_style)
|
||||
|
||||
def get_chat_styles(self, chat_id: str) -> List[str]:
|
||||
"""
|
||||
获取指定chat_id的所有风格
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
List[str]: 风格列表
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.get_all_styles()
|
||||
|
||||
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
|
||||
"""
|
||||
学习一个映射关系
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
style: 对应的style
|
||||
|
||||
Returns:
|
||||
bool: 学习是否成功
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
预测最合适的style
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
up_content: 输入内容
|
||||
top_k: 返回前k个候选
|
||||
|
||||
Returns:
|
||||
Tuple[最佳style, 所有候选分数]
|
||||
"""
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.predict_style(up_content, top_k)
|
||||
|
||||
def decay_all_learners(self, factor: Optional[float] = None) -> None:
|
||||
"""
|
||||
对所有学习器执行衰减
|
||||
|
||||
Args:
|
||||
factor: 衰减因子
|
||||
"""
|
||||
for learner in self.learners.values():
|
||||
learner.decay_learning(factor)
|
||||
logger.info("已对所有学习器执行衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
"""获取所有学习器的统计信息"""
|
||||
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
|
||||
|
||||
def save_all_models(self) -> bool:
|
||||
"""保存所有模型"""
|
||||
success_count = 0
|
||||
for learner in self.learners.values():
|
||||
if learner.save(self.model_save_path):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
|
||||
return success_count == len(self.learners)
|
||||
|
||||
def load_all_models(self) -> int:
|
||||
"""加载所有已保存的模型"""
|
||||
if not os.path.exists(self.model_save_path):
|
||||
return 0
|
||||
|
||||
loaded_count = 0
|
||||
for filename in os.listdir(self.model_save_path):
|
||||
if filename.endswith("_style_model.pkl"):
|
||||
chat_id = filename.replace("_style_model.pkl", "")
|
||||
learner = StyleLearner(chat_id)
|
||||
if learner.load(self.model_save_path):
|
||||
self.learners[chat_id] = learner
|
||||
loaded_count += 1
|
||||
|
||||
logger.info(f"已加载 {loaded_count} 个模型")
|
||||
return loaded_count
|
||||
|
||||
async def start_auto_save(self) -> None:
|
||||
"""启动自动保存任务"""
|
||||
if self._auto_save_task is None or self._auto_save_task.done():
|
||||
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
|
||||
logger.info("已启动自动保存任务")
|
||||
|
||||
async def stop_auto_save(self) -> None:
|
||||
"""停止自动保存任务"""
|
||||
if self._auto_save_task and not self._auto_save_task.done():
|
||||
self._auto_save_task.cancel()
|
||||
try:
|
||||
await self._auto_save_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("已停止自动保存任务")
|
||||
|
||||
async def _auto_save_loop(self) -> None:
|
||||
"""自动保存循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.auto_save_interval)
|
||||
self.save_all_models()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"自动保存失败: {e}")
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
style_learner_manager = StyleLearnerManager()
|
||||
|
|
@ -85,4 +85,4 @@ class ModelAttemptFailed(Exception):
|
|||
self.original_exception = original_exception
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
return self.message
|
||||
|
|
|
|||
|
|
@ -72,8 +72,8 @@ class BaseClient(ABC):
|
|||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
|
|
@ -117,6 +117,7 @@ class BaseClient(ABC):
|
|||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import (
|
||||
|
|
@ -17,6 +17,7 @@ from google.genai.types import (
|
|||
EmbedContentResponse,
|
||||
EmbedContentConfig,
|
||||
SafetySetting,
|
||||
HttpOptions,
|
||||
HarmCategory,
|
||||
HarmBlockThreshold,
|
||||
)
|
||||
|
|
@ -182,6 +183,14 @@ def _process_delta(
|
|||
if delta.text:
|
||||
fc_delta_buffer.write(delta.text)
|
||||
|
||||
# 处理 thought(Gemini 的特殊字段)
|
||||
for c in getattr(delta, "candidates", []):
|
||||
if c.content and getattr(c.content, "parts", None):
|
||||
for p in c.content.parts:
|
||||
if getattr(p, "thought", False) and getattr(p, "text", None):
|
||||
# 把 thought 写入 buffer,避免 resp.content 永远为空
|
||||
fc_delta_buffer.write(p.text)
|
||||
|
||||
if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的
|
||||
for call in delta.function_calls:
|
||||
try:
|
||||
|
|
@ -203,6 +212,7 @@ def _process_delta(
|
|||
def _build_stream_api_resp(
|
||||
_fc_delta_buffer: io.StringIO,
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]],
|
||||
last_resp: GenerateContentResponse | None = None, # 传入 last_resp
|
||||
) -> APIResponse:
|
||||
# sourcery skip: simplify-len-comparison, use-assigned-variable
|
||||
resp = APIResponse()
|
||||
|
|
@ -227,6 +237,21 @@ def _build_stream_api_resp(
|
|||
|
||||
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
|
||||
|
||||
# 检查是否因为 max_tokens 截断
|
||||
reason = None
|
||||
if last_resp and getattr(last_resp, "candidates", None):
|
||||
c0 = last_resp.candidates[0]
|
||||
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
|
||||
|
||||
if str(reason).endswith("MAX_TOKENS"):
|
||||
if resp.content and resp.content.strip():
|
||||
logger.warning(
|
||||
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
|
||||
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
|
||||
)
|
||||
else:
|
||||
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
|
||||
|
||||
if not resp.content and not resp.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
|
||||
|
|
@ -245,12 +270,14 @@ async def _default_stream_response_handler(
|
|||
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
|
||||
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
|
||||
_usage_record = None # 使用情况记录
|
||||
last_resp: GenerateContentResponse | None = None # 保存最后一个 chunk
|
||||
|
||||
def _insure_buffer_closed():
|
||||
if _fc_delta_buffer and not _fc_delta_buffer.closed:
|
||||
_fc_delta_buffer.close()
|
||||
|
||||
async for chunk in resp_stream:
|
||||
last_resp = chunk # 保存最后一个响应
|
||||
# 检查是否有中断量
|
||||
if interrupt_flag and interrupt_flag.is_set():
|
||||
# 如果中断量被设置,则抛出ReqAbortException
|
||||
|
|
@ -269,10 +296,12 @@ async def _default_stream_response_handler(
|
|||
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
|
||||
chunk.usage_metadata.total_token_count or 0,
|
||||
)
|
||||
|
||||
try:
|
||||
return _build_stream_api_resp(
|
||||
_fc_delta_buffer,
|
||||
_tool_calls_buffer,
|
||||
last_resp=last_resp,
|
||||
), _usage_record
|
||||
except Exception:
|
||||
# 确保缓冲区被关闭
|
||||
|
|
@ -332,6 +361,35 @@ def _default_normal_response_parser(
|
|||
|
||||
api_response.raw_data = resp
|
||||
|
||||
# 检查是否因为 max_tokens 截断
|
||||
try:
|
||||
if resp.candidates:
|
||||
c0 = resp.candidates[0]
|
||||
reason = getattr(c0, "finish_reason", None) or getattr(c0, "finishReason", None)
|
||||
if reason and "MAX_TOKENS" in str(reason):
|
||||
# 检查第二个及之后的 parts 是否有内容
|
||||
has_real_output = False
|
||||
if getattr(c0, "content", None) and getattr(c0.content, "parts", None):
|
||||
for p in c0.content.parts[1:]: # 跳过第一个 thought
|
||||
if getattr(p, "text", None) and p.text.strip():
|
||||
has_real_output = True
|
||||
break
|
||||
|
||||
if not has_real_output and getattr(resp, "text", None):
|
||||
has_real_output = True
|
||||
|
||||
if has_real_output:
|
||||
logger.warning(
|
||||
"⚠ Gemini 响应因达到 max_tokens 限制被部分截断,\n"
|
||||
" 可能会对回复内容造成影响,建议修改模型 max_tokens 配置!"
|
||||
)
|
||||
else:
|
||||
logger.warning("⚠ Gemini 响应因达到 max_tokens 限制被截断,\n 请修改模型 max_tokens 配置!")
|
||||
|
||||
return api_response, _usage_record
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 MAX_TOKENS 截断时异常: {e}")
|
||||
|
||||
# 最终的、唯一的空响应检查
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException("响应中既无文本内容也无工具调用")
|
||||
|
|
@ -345,17 +403,45 @@ class GeminiClient(BaseClient):
|
|||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
|
||||
# 增加传入参数处理
|
||||
http_options_kwargs: Dict[str, Any] = {}
|
||||
|
||||
# 秒转换为毫秒传入
|
||||
if api_provider.timeout is not None:
|
||||
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
|
||||
|
||||
# 传入并处理地址和版本(必须为Gemini格式)
|
||||
if api_provider.base_url:
|
||||
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("v"):
|
||||
http_options_kwargs["base_url"] = f"{parts[0]}/"
|
||||
http_options_kwargs["api_version"] = parts[1]
|
||||
else:
|
||||
http_options_kwargs["base_url"] = api_provider.base_url
|
||||
http_options_kwargs["api_version"] = None
|
||||
self.client = genai.Client(
|
||||
http_options=HttpOptions(**http_options_kwargs),
|
||||
api_key=api_provider.api_key,
|
||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
||||
|
||||
@staticmethod
|
||||
def clamp_thinking_budget(tb: int, model_id: str) -> int:
|
||||
def clamp_thinking_budget(extra_params: dict[str, Any] | None, model_id: str) -> int:
|
||||
"""
|
||||
按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
|
||||
"""
|
||||
limits = None
|
||||
|
||||
# 参数传入处理
|
||||
tb = THINKING_BUDGET_AUTO
|
||||
if extra_params and "thinking_budget" in extra_params:
|
||||
try:
|
||||
tb = int(extra_params["thinking_budget"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
|
||||
)
|
||||
|
||||
# 优先尝试精确匹配
|
||||
if model_id in THINKING_BUDGET_LIMITS:
|
||||
limits = THINKING_BUDGET_LIMITS[model_id]
|
||||
|
|
@ -368,20 +454,29 @@ class GeminiClient(BaseClient):
|
|||
limits = THINKING_BUDGET_LIMITS[key]
|
||||
break
|
||||
|
||||
# 特殊值处理
|
||||
# 预算值处理
|
||||
if tb == THINKING_BUDGET_AUTO:
|
||||
return THINKING_BUDGET_AUTO
|
||||
if tb == THINKING_BUDGET_DISABLED:
|
||||
if limits and limits.get("can_disable", False):
|
||||
return THINKING_BUDGET_DISABLED
|
||||
return limits["min"] if limits else THINKING_BUDGET_AUTO
|
||||
if limits:
|
||||
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
|
||||
return limits["min"]
|
||||
return THINKING_BUDGET_AUTO
|
||||
|
||||
# 已知模型裁剪到范围
|
||||
# 已知模型范围裁剪 + 提示
|
||||
if limits:
|
||||
return max(limits["min"], min(tb, limits["max"]))
|
||||
if tb < limits["min"]:
|
||||
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
|
||||
return limits["min"]
|
||||
if tb > limits["max"]:
|
||||
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
|
||||
return limits["max"]
|
||||
return tb
|
||||
|
||||
# 未知模型,返回动态模式
|
||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
|
||||
# 未知模型 → 默认自动模式
|
||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容")
|
||||
return THINKING_BUDGET_AUTO
|
||||
|
||||
async def get_response(
|
||||
|
|
@ -389,8 +484,8 @@ class GeminiClient(BaseClient):
|
|||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: Optional[int] = 1024,
|
||||
temperature: Optional[float] = 0.4,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
|
|
@ -429,16 +524,8 @@ class GeminiClient(BaseClient):
|
|||
messages = _convert_messages(message_list)
|
||||
# 将tool_options转换为Gemini API所需的格式
|
||||
tools = _convert_tool_options(tool_options) if tool_options else None
|
||||
|
||||
tb = THINKING_BUDGET_AUTO
|
||||
# 空处理
|
||||
if extra_params and "thinking_budget" in extra_params:
|
||||
try:
|
||||
tb = int(extra_params["thinking_budget"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
|
||||
# 裁剪到模型支持的范围
|
||||
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
||||
# 解析并裁剪 thinking_budget
|
||||
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
|
||||
|
||||
# 将response_format转换为Gemini API所需的格式
|
||||
generation_config_dict = {
|
||||
|
|
@ -497,15 +584,20 @@ class GeminiClient(BaseClient):
|
|||
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
# 重封装 ClientError 和 ServerError 为 RespNotOkException
|
||||
raise RespNotOkException(e.code, e.message) from None
|
||||
except (
|
||||
UnknownFunctionCallArgumentError,
|
||||
UnsupportedFunctionError,
|
||||
FunctionInvocationError,
|
||||
) as e:
|
||||
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
|
||||
# 工具调用相关错误
|
||||
raise RespParseException(None, f"工具调用参数错误: {str(e)}") from None
|
||||
except EmptyResponseException as e:
|
||||
# 保持原始异常,便于区分“空响应”和网络异常
|
||||
raise e
|
||||
except Exception as e:
|
||||
# 其他未预料的错误,才归为网络连接类
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
if usage_record:
|
||||
|
|
@ -561,41 +653,51 @@ class GeminiClient(BaseClient):
|
|||
|
||||
return response
|
||||
|
||||
def get_audio_transcriptions(
|
||||
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = 2048,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录
|
||||
:param model_info: 模型信息
|
||||
:param audio_base64: 音频文件的Base64编码字符串
|
||||
:param max_tokens: 最大输出token数(默认2048)
|
||||
:param extra_params: 额外参数(可选)
|
||||
:return: 转录响应
|
||||
"""
|
||||
# 解析并裁剪 thinking_budget
|
||||
tb = self.clamp_thinking_budget(extra_params, model_info.model_identifier)
|
||||
|
||||
# 构造 prompt + 音频输入
|
||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||
contents = [
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
generation_config_dict = {
|
||||
"max_output_tokens": 2048,
|
||||
"max_output_tokens": max_tokens,
|
||||
"response_modalities": ["TEXT"],
|
||||
"thinking_config": ThinkingConfig(
|
||||
include_thoughts=True,
|
||||
thinking_budget=(
|
||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
||||
),
|
||||
thinking_budget=tb,
|
||||
),
|
||||
"safety_settings": gemini_safe_settings,
|
||||
}
|
||||
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||
|
||||
try:
|
||||
raw_response: GenerateContentResponse = self.client.models.generate_content(
|
||||
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||
model=model_info.model_identifier,
|
||||
contents=[
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||
],
|
||||
)
|
||||
],
|
||||
contents=contents,
|
||||
config=generate_content_config,
|
||||
)
|
||||
resp, usage_record = _default_normal_response_parser(raw_response)
|
||||
|
|
|
|||
|
|
@ -403,8 +403,8 @@ class OpenaiClient(BaseClient):
|
|||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = 1024,
|
||||
temperature: Optional[float] = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
|
|
@ -488,6 +488,9 @@ class OpenaiClient(BaseClient):
|
|||
raise ReqAbortException("请求被外部信号中断")
|
||||
await asyncio.sleep(0.1) # 等待0.5秒后再次检查任务&中断信号量状态
|
||||
|
||||
# logger.
|
||||
logger.debug(f"OpenAI API响应(非流式): {req_task.result()}")
|
||||
|
||||
# logger.info(f"OpenAI请求时间: {model_info.model_identifier} {time.time() - start_time} \n{messages}")
|
||||
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
|
|
@ -507,6 +510,8 @@ class OpenaiClient(BaseClient):
|
|||
total_tokens=usage_record[2],
|
||||
)
|
||||
|
||||
# logger.debug(f"OpenAI API响应: {resp}")
|
||||
|
||||
return resp
|
||||
|
||||
async def get_embedding(
|
||||
|
|
|
|||
|
|
@ -26,18 +26,6 @@ install(extra_lines=3)
|
|||
|
||||
logger = get_logger("model_utils")
|
||||
|
||||
# 常见Error Code Mapping
|
||||
error_code_mapping = {
|
||||
400: "参数不正确",
|
||||
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
|
||||
402: "账号余额不足",
|
||||
403: "需要实名,或余额不足",
|
||||
404: "Not Found",
|
||||
429: "请求过于频繁,请稍后再试",
|
||||
500: "服务器内部故障",
|
||||
503: "服务器负载过高",
|
||||
}
|
||||
|
||||
|
||||
class RequestType(Enum):
|
||||
"""请求类型枚举"""
|
||||
|
|
@ -160,6 +148,8 @@ class LLMRequest:
|
|||
)
|
||||
|
||||
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
||||
logger.debug(f"LLM生成内容: {response}")
|
||||
|
||||
content = response.content
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
|
|
@ -267,14 +257,14 @@ class LLMRequest:
|
|||
extra_params=model_info.extra_params,
|
||||
)
|
||||
elif request_type == RequestType.EMBEDDING:
|
||||
assert embedding_input is not None
|
||||
assert embedding_input is not None, "嵌入输入不能为空"
|
||||
return await client.get_embedding(
|
||||
model_info=model_info,
|
||||
embedding_input=embedding_input,
|
||||
extra_params=model_info.extra_params,
|
||||
)
|
||||
elif request_type == RequestType.AUDIO:
|
||||
assert audio_base64 is not None
|
||||
assert audio_base64 is not None, "音频Base64不能为空"
|
||||
return await client.get_audio_transcriptions(
|
||||
model_info=model_info,
|
||||
audio_base64=audio_base64,
|
||||
|
|
@ -365,24 +355,23 @@ class LLMRequest:
|
|||
embedding_input=embedding_input,
|
||||
audio_base64=audio_base64,
|
||||
)
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
if response_usage := response.usage:
|
||||
total_tokens += response_usage.total_tokens
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||
return response, model_info
|
||||
|
||||
except ModelAttemptFailed as e:
|
||||
last_exception = e.original_exception or e
|
||||
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
|
||||
failed_models_this_request.add(model_info.name)
|
||||
|
||||
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
|
||||
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
|
||||
raise last_exception from e
|
||||
|
||||
finally:
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
if usage_penalty > 0:
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||
|
||||
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
|
|
|||
20
src/main.py
20
src/main.py
|
|
@ -13,8 +13,8 @@ from src.common.logger import get_logger
|
|||
from src.common.server import get_global_server, Server
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.memory_system.memory_management_task import MemoryManagementTask
|
||||
from rich.traceback import install
|
||||
from src.migrate_helper.migrate import check_and_run_migrations
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
|
|
@ -83,22 +83,19 @@ class MainSystem:
|
|||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
# 启动情绪管理器
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
if global_config.mood.enable_mood:
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
|
||||
# 初始化聊天管理器
|
||||
await get_chat_manager()._initialize()
|
||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
|
||||
logger.info("聊天管理器初始化成功")
|
||||
|
||||
# # 根据配置条件性地初始化记忆系统
|
||||
# if global_config.memory.enable_memory:
|
||||
# if self.hippocampus_manager:
|
||||
# self.hippocampus_manager.initialize()
|
||||
# logger.info("记忆系统初始化成功")
|
||||
# else:
|
||||
# logger.info("记忆系统已禁用,跳过初始化")
|
||||
|
||||
# 添加记忆管理任务
|
||||
await async_task_manager.add_task(MemoryManagementTask())
|
||||
logger.info("记忆管理任务已启动")
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
|
|
@ -106,7 +103,6 @@ class MainSystem:
|
|||
self.app.register_message_handler(chat_bot.message_process)
|
||||
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
|
||||
|
||||
await check_and_run_migrations()
|
||||
|
||||
# 触发 ON_START 事件
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
[inner]
|
||||
version = "1.0.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
[inner]
|
||||
version = "1.2.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
enable_s4u = false
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
|
||||
enable_streaming_output = true # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-32b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import time
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你之前的内心想法是:{mind}
|
||||
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
|
||||
---------------------
|
||||
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
|
||||
你刚刚选择回复的内容是:{reponse}
|
||||
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
|
||||
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出想法:""",
|
||||
"after_response_think_prompt",
|
||||
)
|
||||
|
||||
|
||||
class MaiThinking:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.platform = self.chat_stream.platform
|
||||
|
||||
if self.chat_stream.group_info:
|
||||
self.is_group = True
|
||||
else:
|
||||
self.is_group = False
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
self.mind = ""
|
||||
|
||||
self.memory_block = ""
|
||||
self.relation_info_block = ""
|
||||
self.time_block = ""
|
||||
self.chat_target = ""
|
||||
self.chat_target_2 = ""
|
||||
self.chat_info = ""
|
||||
self.mood_state = ""
|
||||
self.identity = ""
|
||||
self.sender = ""
|
||||
self.target = ""
|
||||
|
||||
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
|
||||
|
||||
async def do_think_before_response(self):
|
||||
pass
|
||||
|
||||
async def do_think_after_response(self, reponse: str):
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"after_response_think_prompt",
|
||||
mind=self.mind,
|
||||
reponse=reponse,
|
||||
memory_block=self.memory_block,
|
||||
relation_info_block=self.relation_info_block,
|
||||
time_block=self.time_block,
|
||||
chat_target=self.chat_target,
|
||||
chat_target_2=self.chat_target_2,
|
||||
chat_info=self.chat_info,
|
||||
mood_state=self.mood_state,
|
||||
identity=self.identity,
|
||||
sender=self.sender,
|
||||
target=self.target,
|
||||
)
|
||||
|
||||
result, _ = await self.thinking_model.generate_response_async(prompt)
|
||||
self.mind = result
|
||||
|
||||
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||
|
||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||
await self.s4u_message_processor.process_message(msg_recv)
|
||||
internal_manager.set_internal_state(self.mind)
|
||||
|
||||
async def do_think_when_receive_message(self):
|
||||
pass
|
||||
|
||||
async def build_internal_message_recv(self, message_text: str):
|
||||
msg_id = f"internal_{time.time()}"
|
||||
|
||||
message_dict = {
|
||||
"message_info": {
|
||||
"message_id": msg_id,
|
||||
"time": time.time(),
|
||||
"user_info": {
|
||||
"user_id": "internal", # 内部用户ID
|
||||
"user_nickname": "内心", # 内部昵称
|
||||
"platform": self.platform, # 平台标记为 internal
|
||||
# 其他 user_info 字段按需补充
|
||||
},
|
||||
"platform": self.platform, # 平台
|
||||
# 其他 message_info 字段按需补充
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "text", # 消息类型
|
||||
"data": message_text, # 消息内容
|
||||
# 其他 segment 字段按需补充
|
||||
},
|
||||
"raw_message": message_text, # 原始消息内容
|
||||
"processed_plain_text": message_text, # 处理后的纯文本
|
||||
# 下面这些字段可选,根据 MessageRecv 需要
|
||||
"is_emoji": False,
|
||||
"has_emoji": False,
|
||||
"is_picid": False,
|
||||
"has_picid": False,
|
||||
"is_voice": False,
|
||||
"is_mentioned": False,
|
||||
"is_command": False,
|
||||
"is_internal": True,
|
||||
"priority_mode": "interest",
|
||||
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
||||
"interest_value": 1.0,
|
||||
}
|
||||
|
||||
if self.is_group:
|
||||
message_dict["message_info"]["group_info"] = {
|
||||
"platform": self.platform,
|
||||
"group_id": self.chat_stream.group_info.group_id,
|
||||
"group_name": self.chat_stream.group_info.group_name,
|
||||
}
|
||||
|
||||
msg_recv = MessageRecvS4U(message_dict)
|
||||
msg_recv.chat_info = self.chat_info
|
||||
msg_recv.chat_stream = self.chat_stream
|
||||
msg_recv.is_internal = True
|
||||
|
||||
return msg_recv
|
||||
|
||||
|
||||
class MaiThinkingManager:
|
||||
def __init__(self):
|
||||
self.mai_think_list = []
|
||||
|
||||
def get_mai_think(self, chat_id):
|
||||
for mai_think in self.mai_think_list:
|
||||
if mai_think.chat_id == chat_id:
|
||||
return mai_think
|
||||
mai_think = MaiThinking(chat_id)
|
||||
self.mai_think_list.append(mai_think)
|
||||
return mai_think
|
||||
|
||||
|
||||
mai_thinking_manager = MaiThinkingManager()
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
|
@ -1,342 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
logger = get_logger("action")
|
||||
|
||||
# 使用字典作为默认值,但通过Prompt来注册以便外部重载
|
||||
DEFAULT_HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
DEFAULT_BODY_CODE = {
|
||||
"双手背后向前弯腰": "010_0070",
|
||||
"歪头双手合十": "010_0100",
|
||||
"标准文静站立": "010_0101",
|
||||
"双手交叠腹部站立": "010_0150",
|
||||
"帅气的姿势": "010_0190",
|
||||
"另一个帅气的姿势": "010_0191",
|
||||
"手掌朝前可爱": "010_0210",
|
||||
"平静,双手后放": "平静,双手后放",
|
||||
"思考": "思考",
|
||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||
"一般": "一般",
|
||||
"可爱,双手前放": "可爱,双手前放",
|
||||
}
|
||||
|
||||
|
||||
async def get_head_code() -> dict:
|
||||
"""获取头部动作代码字典"""
|
||||
head_code_str = await global_prompt_manager.format_prompt("head_code_prompt")
|
||||
if not head_code_str:
|
||||
return DEFAULT_HEAD_CODE
|
||||
try:
|
||||
return json.loads(head_code_str)
|
||||
except Exception as e:
|
||||
logger.error(f"解析head_code_prompt失败,使用默认值: {e}")
|
||||
return DEFAULT_HEAD_CODE
|
||||
|
||||
|
||||
async def get_body_code() -> dict:
|
||||
"""获取身体动作代码字典"""
|
||||
body_code_str = await global_prompt_manager.format_prompt("body_code_prompt")
|
||||
if not body_code_str:
|
||||
return DEFAULT_BODY_CODE
|
||||
try:
|
||||
return json.loads(body_code_str)
|
||||
except Exception as e:
|
||||
logger.error(f"解析body_code_prompt失败,使用默认值: {e}")
|
||||
return DEFAULT_BODY_CODE
|
||||
|
||||
|
||||
def init_prompt():
|
||||
# 注册头部动作代码
|
||||
Prompt(
|
||||
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
|
||||
"head_code_prompt",
|
||||
)
|
||||
|
||||
# 注册身体动作代码
|
||||
Prompt(
|
||||
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
|
||||
"body_code_prompt",
|
||||
)
|
||||
|
||||
# 注册原有提示模板
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里正在进行的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你现在的动作状态是:
|
||||
- 身体动作:{body_action}
|
||||
|
||||
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"change_action_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里最近的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你之前的动作状态是
|
||||
- 身体动作:{body_action}
|
||||
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"regress_action_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatAction:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.body_action: str = "一般"
|
||||
self.head_action: str = "注视摄像机"
|
||||
|
||||
self.regression_count: int = 0
|
||||
# 新增:body_action冷却池,key为动作名,value为剩余冷却次数
|
||||
self.body_action_cooldown: dict[str, int] = {}
|
||||
|
||||
print(s4u_config.models.motion)
|
||||
print(model_config.model_task_config.emotion)
|
||||
|
||||
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def send_action_update(self):
|
||||
"""发送动作更新到前端"""
|
||||
|
||||
body_code = (await get_body_code()).get(self.body_action, "")
|
||||
await send_api.custom_to_stream(
|
||||
message_type="body_action",
|
||||
content=body_code,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
async def update_action_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
# 记录原动作,切换后进入冷却
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 3
|
||||
self.body_action = new_body_action
|
||||
self.head_action = action_data.get("head_action", self.head_action)
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"update_action_by_message error: {e}")
|
||||
|
||||
async def regress_action(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in (await get_body_code()).keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 6
|
||||
self.body_action = new_body_action
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.regression_count += 1
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"regress_action error: {e}")
|
||||
|
||||
# 新增:冷却池维护方法
|
||||
def _update_body_action_cooldown(self):
|
||||
remove_keys = []
|
||||
for k in self.body_action_cooldown:
|
||||
self.body_action_cooldown[k] -= 1
|
||||
if self.body_action_cooldown[k] <= 0:
|
||||
remove_keys.append(k)
|
||||
for k in remove_keys:
|
||||
del self.body_action_cooldown[k]
|
||||
|
||||
|
||||
class ActionRegressionTask(AsyncTask):
|
||||
def __init__(self, action_manager: "ActionManager"):
|
||||
super().__init__(task_name="ActionRegressionTask", run_interval=3)
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Running action regression task...")
|
||||
now = time.time()
|
||||
for action_state in self.action_manager.action_state_list:
|
||||
if action_state.last_change_time == 0:
|
||||
continue
|
||||
|
||||
if now - action_state.last_change_time > 10:
|
||||
if action_state.regression_count >= 3:
|
||||
continue
|
||||
|
||||
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1} 次")
|
||||
await action_state.regress_action()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
def __init__(self):
|
||||
self.action_state_list: list[ChatAction] = []
|
||||
"""当前动作状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动动作回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动动作回归任务...")
|
||||
task = ActionRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
logger.info("动作回归任务已启动")
|
||||
|
||||
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
|
||||
for action_state in self.action_state_list:
|
||||
if action_state.chat_id == chat_id:
|
||||
return action_state
|
||||
|
||||
new_action_state = ChatAction(chat_id)
|
||||
self.action_state_list.append(new_action_state)
|
||||
return new_action_state
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
action_manager = ActionManager()
|
||||
"""全局动作管理器"""
|
||||
|
|
@ -1,692 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from aiohttp import web, WSMsgType
|
||||
import aiohttp_cors
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("context_web")
|
||||
|
||||
|
||||
class ContextMessage:
|
||||
"""上下文消息类"""
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
self.user_name = message.message_info.user_info.user_nickname
|
||||
self.user_id = message.message_info.user_info.user_id
|
||||
self.content = message.processed_plain_text
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, "is_gift", False)
|
||||
self.is_superchat = getattr(message, "is_superchat", False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, "gift_name", "")
|
||||
self.gift_count = getattr(message, "gift_count", "1")
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, "superchat_price", "0")
|
||||
self.superchat_message = getattr(message, "superchat_message_text", "")
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat,
|
||||
}
|
||||
|
||||
|
||||
class ContextWebManager:
|
||||
"""上下文网页管理器"""
|
||||
|
||||
def __init__(self, max_messages: int = 10, port: int = 8765):
|
||||
self.max_messages = max_messages
|
||||
self.port = port
|
||||
self.contexts: Dict[str, deque] = {} # chat_id -> deque of ContextMessage
|
||||
self.websockets: List[web.WebSocketResponse] = []
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False # 添加启动标志防止并发
|
||||
|
||||
async def start_server(self):
|
||||
"""启动web服务器"""
|
||||
if self.site is not None:
|
||||
logger.debug("Web服务器已经启动,跳过重复启动")
|
||||
return
|
||||
|
||||
if self._server_starting:
|
||||
logger.debug("Web服务器正在启动中,等待启动完成...")
|
||||
# 等待启动完成
|
||||
while self._server_starting and self.site is None:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
self._server_starting = True
|
||||
|
||||
try:
|
||||
self.app = web.Application()
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(
|
||||
self.app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get("/", self.index_handler)
|
||||
self.app.router.add_get("/ws", self.websocket_handler)
|
||||
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
|
||||
self.app.router.add_get("/debug", self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||||
# 清理部分启动的资源
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
raise
|
||||
finally:
|
||||
self._server_starting = False
|
||||
|
||||
async def stop_server(self):
|
||||
"""停止web服务器"""
|
||||
if self.site:
|
||||
await self.site.stop()
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>聊天上下文</title>
|
||||
<style>
|
||||
html, body {
|
||||
background: transparent !important;
|
||||
background-color: transparent !important;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
font-family: 'Microsoft YaHei', Arial, sans-serif;
|
||||
color: #ffffff;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
|
||||
}
|
||||
.container {
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
background: transparent !important;
|
||||
}
|
||||
.message {
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
margin: 10px 0;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #00ff88;
|
||||
backdrop-filter: blur(5px);
|
||||
animation: slideIn 0.3s ease-out;
|
||||
transform: translateY(0);
|
||||
transition: transform 0.5s ease, opacity 0.5s ease;
|
||||
}
|
||||
.message:hover {
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
transform: translateX(5px);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.message.gift {
|
||||
border-left: 4px solid #ff8800;
|
||||
background: rgba(255, 136, 0, 0.2);
|
||||
}
|
||||
.message.gift:hover {
|
||||
background: rgba(255, 136, 0, 0.3);
|
||||
}
|
||||
.message.gift .username {
|
||||
color: #ff8800;
|
||||
}
|
||||
.message.superchat {
|
||||
border-left: 4px solid #ff6b6b;
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
|
||||
background-size: 200% 200%;
|
||||
animation: rainbow 3s ease infinite;
|
||||
}
|
||||
.message.superchat:hover {
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
|
||||
background-size: 200% 200%;
|
||||
}
|
||||
.message.superchat .username {
|
||||
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
|
||||
background-size: 300% 300%;
|
||||
animation: rainbow-text 2s ease infinite;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
@keyframes rainbow {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
@keyframes rainbow-text {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
.message-line {
|
||||
line-height: 1.4;
|
||||
word-wrap: break-word;
|
||||
font-size: 24px;
|
||||
}
|
||||
.username {
|
||||
color: #00ff88;
|
||||
}
|
||||
.content {
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.new-message {
|
||||
animation: slideInNew 0.6s ease-out;
|
||||
}
|
||||
|
||||
.debug-btn {
|
||||
position: fixed;
|
||||
bottom: 20px;
|
||||
right: 20px;
|
||||
background: rgba(0, 0, 0, 0.7);
|
||||
color: #00ff88;
|
||||
font-size: 12px;
|
||||
padding: 8px 12px;
|
||||
border-radius: 20px;
|
||||
backdrop-filter: blur(10px);
|
||||
z-index: 1000;
|
||||
text-decoration: none;
|
||||
border: 1px solid #00ff88;
|
||||
}
|
||||
.debug-btn:hover {
|
||||
background: rgba(0, 255, 136, 0.2);
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
@keyframes slideInNew {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(50px) scale(0.95);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
.no-messages {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
margin-top: 50px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<a href="/debug" class="debug-btn">🔧 调试</a>
|
||||
<div id="messages">
|
||||
<div class="no-messages">暂无消息</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws;
|
||||
let reconnectInterval;
|
||||
let currentMessages = []; // 存储当前显示的消息
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:"""
|
||||
+ str(self.port)
|
||||
+ """/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
if (reconnectInterval) {
|
||||
clearInterval(reconnectInterval);
|
||||
reconnectInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
console.log('收到WebSocket消息:', event.data);
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
updateMessages(data.contexts);
|
||||
} catch (e) {
|
||||
console.error('解析消息失败:', e, event.data);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = function(event) {
|
||||
console.log('WebSocket连接关闭:', event.code, event.reason);
|
||||
|
||||
if (!reconnectInterval) {
|
||||
reconnectInterval = setInterval(connectWebSocket, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
console.error('WebSocket错误:', error);
|
||||
};
|
||||
}
|
||||
|
||||
function updateMessages(contexts) {
|
||||
const messagesDiv = document.getElementById('messages');
|
||||
|
||||
if (!contexts || contexts.length === 0) {
|
||||
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
|
||||
currentMessages = [];
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
|
||||
if (currentMessages.length === 0) {
|
||||
console.log('首次加载消息,数量:', contexts.length);
|
||||
messagesDiv.innerHTML = '';
|
||||
|
||||
contexts.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg);
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
});
|
||||
|
||||
currentMessages = [...contexts];
|
||||
window.scrollTo(0, document.body.scrollHeight);
|
||||
return;
|
||||
}
|
||||
|
||||
// 检测新消息 - 使用更可靠的方法
|
||||
const newMessages = findNewMessages(contexts, currentMessages);
|
||||
|
||||
if (newMessages.length > 0) {
|
||||
console.log('添加新消息,数量:', newMessages.length);
|
||||
|
||||
// 先检查是否需要移除老消息(保持DOM清洁)
|
||||
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
|
||||
const currentMessageElements = messagesDiv.querySelectorAll('.message');
|
||||
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
|
||||
|
||||
if (willExceedLimit) {
|
||||
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
|
||||
console.log('需要移除老消息数量:', removeCount);
|
||||
|
||||
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
|
||||
const oldMessage = currentMessageElements[i];
|
||||
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
|
||||
oldMessage.style.opacity = '0';
|
||||
oldMessage.style.transform = 'translateY(-20px)';
|
||||
|
||||
setTimeout(() => {
|
||||
if (oldMessage.parentNode) {
|
||||
oldMessage.parentNode.removeChild(oldMessage);
|
||||
}
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新消息
|
||||
newMessages.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg, true); // true表示是新消息
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
|
||||
// 移除动画类,避免重复动画
|
||||
setTimeout(() => {
|
||||
messageDiv.classList.remove('new-message');
|
||||
}, 600);
|
||||
});
|
||||
|
||||
// 更新当前消息列表
|
||||
currentMessages = [...contexts];
|
||||
|
||||
// 平滑滚动到底部
|
||||
setTimeout(() => {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function findNewMessages(contexts, currentMessages) {
|
||||
// 如果当前消息为空,所有消息都是新的
|
||||
if (currentMessages.length === 0) {
|
||||
return contexts;
|
||||
}
|
||||
|
||||
// 找到最后一条当前消息在新消息列表中的位置
|
||||
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
|
||||
let lastIndex = -1;
|
||||
|
||||
// 从后往前找,因为新消息通常在末尾
|
||||
for (let i = contexts.length - 1; i >= 0; i--) {
|
||||
const msg = contexts[i];
|
||||
if (msg.user_id === lastCurrentMsg.user_id &&
|
||||
msg.content === lastCurrentMsg.content &&
|
||||
msg.timestamp === lastCurrentMsg.timestamp) {
|
||||
lastIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
|
||||
if (lastIndex >= 0) {
|
||||
return contexts.slice(lastIndex + 1);
|
||||
} else {
|
||||
console.log('未找到匹配的最后消息,可能需要完全刷新');
|
||||
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
function createMessageElement(msg, isNew = false) {
|
||||
const messageDiv = document.createElement('div');
|
||||
let className = 'message';
|
||||
|
||||
// 根据消息类型添加对应的CSS类
|
||||
if (msg.is_gift) {
|
||||
className += ' gift';
|
||||
} else if (msg.is_superchat) {
|
||||
className += ' superchat';
|
||||
}
|
||||
|
||||
if (isNew) {
|
||||
className += ' new-message';
|
||||
}
|
||||
|
||||
messageDiv.className = className;
|
||||
messageDiv.innerHTML = `
|
||||
<div class="message-line">
|
||||
<span class="username">${escapeHtml(msg.user_name)}:</span><span class="content">${escapeHtml(msg.content)}</span>
|
||||
</div>
|
||||
`;
|
||||
return messageDiv;
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 初始加载数据
|
||||
fetch('/api/contexts')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('初始数据加载成功:', data);
|
||||
updateMessages(data.contexts);
|
||||
})
|
||||
.catch(err => console.error('加载初始数据失败:', err));
|
||||
|
||||
// 连接WebSocket
|
||||
connectWebSocket();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.websockets.append(ws)
|
||||
logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")
|
||||
|
||||
# 发送初始数据
|
||||
await self.send_contexts_to_websocket(ws)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
# 清理断开的连接
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def get_contexts_handler(self, request):
|
||||
"""获取上下文API"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
|
||||
async def debug_handler(self, request):
|
||||
"""调试信息处理器"""
|
||||
debug_info = {
|
||||
"server_status": "running",
|
||||
"websocket_connections": len(self.websockets),
|
||||
"total_chats": len(self.contexts),
|
||||
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
|
||||
}
|
||||
|
||||
# 构建聊天详情HTML
|
||||
chats_html = ""
|
||||
for chat_id, contexts in self.contexts.items():
|
||||
messages_html = ""
|
||||
for msg in contexts:
|
||||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f"""
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>调试信息</title>
|
||||
<style>
|
||||
body {{ font-family: monospace; margin: 20px; }}
|
||||
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
|
||||
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
|
||||
.message {{ margin: 5px 0; padding: 5px; background: white; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>上下文网页管理器调试信息</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>服务器状态</h2>
|
||||
<p>状态: {debug_info["server_status"]}</p>
|
||||
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
|
||||
<p>聊天总数: {debug_info["total_chats"]}</p>
|
||||
<p>消息总数: {debug_info["total_messages"]}</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>聊天详情</h2>
|
||||
{chats_html}
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>操作</h2>
|
||||
<button onclick="location.reload()">刷新页面</button>
|
||||
<button onclick="window.location.href='/'">返回主页</button>
|
||||
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
console.log('调试信息:', {json.dumps(debug_info, ensure_ascii=False, indent=2)});
|
||||
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
if chat_id not in self.contexts:
|
||||
self.contexts[chat_id] = deque(maxlen=self.max_messages)
|
||||
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
|
||||
|
||||
context_msg = ContextMessage(message)
|
||||
self.contexts[chat_id].append(context_msg)
|
||||
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(
|
||||
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
|
||||
)
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info("📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(
|
||||
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
|
||||
)
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
|
||||
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
|
||||
"""向单个WebSocket发送上下文数据"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
async def broadcast_contexts(self):
|
||||
"""向所有WebSocket连接广播上下文更新"""
|
||||
if not self.websockets:
|
||||
logger.debug("没有WebSocket连接,跳过广播")
|
||||
return
|
||||
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
|
||||
|
||||
# 创建WebSocket列表的副本,避免在遍历时修改
|
||||
websockets_copy = self.websockets.copy()
|
||||
removed_count = 0
|
||||
|
||||
for ws in websockets_copy:
|
||||
if ws.closed:
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
else:
|
||||
try:
|
||||
await ws.send_str(message)
|
||||
logger.debug("消息发送成功")
|
||||
except Exception as e:
|
||||
logger.error(f"发送WebSocket消息失败: {e}")
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
|
||||
|
||||
|
||||
# 全局实例
|
||||
_context_web_manager: Optional[ContextWebManager] = None
|
||||
|
||||
|
||||
def get_context_web_manager() -> ContextWebManager:
|
||||
"""获取上下文网页管理器实例"""
|
||||
global _context_web_manager
|
||||
if _context_web_manager is None:
|
||||
_context_web_manager = ContextWebManager()
|
||||
return _context_web_manager
|
||||
|
||||
|
||||
async def init_context_web_manager():
|
||||
"""初始化上下文网页管理器"""
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
import asyncio
|
||||
from typing import Dict, Tuple, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("gift_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
callback: Callable[[MessageRecvS4U], None]
|
||||
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 5.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(
|
||||
self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
|
||||
) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
async def _create_pending_gift(
|
||||
self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
initial_count = int(message.gift_count)
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
pending_gift = self.pending_gifts.get(gift_key)
|
||||
if pending_gift and not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
await self._gift_timeout(gift_key)
|
||||
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
class InternalManager:
|
||||
def __init__(self):
|
||||
self.now_internal_state = str()
|
||||
|
||||
def set_internal_state(self, internal_state: str):
|
||||
self.now_internal_state = internal_state
|
||||
|
||||
def get_internal_state(self):
|
||||
return self.now_internal_state
|
||||
|
||||
def get_internal_state_str(self):
|
||||
return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}"
|
||||
|
||||
|
||||
internal_manager = InternalManager()
|
||||
|
|
@ -1,584 +0,0 @@
|
|||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, Dict, Tuple, List # 导入类型提示
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
|
||||
from src.config.config import global_config
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from .s4u_watching_manager import watching_manager
|
||||
import json
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import get_person_id
|
||||
from .yes_or_no import yes_or_no_head
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
|
||||
class MessageSenderContainer:
|
||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||
self.chat_stream = chat_stream
|
||||
self.original_message = original_message
|
||||
self.queue = asyncio.Queue()
|
||||
self.storage = MessageStorage()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
|
||||
async def close(self):
|
||||
"""表示没有更多消息了,关闭队列。"""
|
||||
await self.queue.put(None) # Sentinel
|
||||
|
||||
def pause(self):
|
||||
"""暂停发送。"""
|
||||
self._paused_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""恢复发送。"""
|
||||
self._paused_event.set()
|
||||
|
||||
def _calculate_typing_delay(self, text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = s4u_config.chars_per_second
|
||||
min_delay = s4u_config.min_typing_delay
|
||||
max_delay = s4u_config.max_typing_delay
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
|
||||
async def _send_worker(self):
|
||||
"""从队列中取出消息并发送。"""
|
||||
while True:
|
||||
try:
|
||||
# This structure ensures that task_done() is called for every item retrieved,
|
||||
# even if the worker is cancelled while processing the item.
|
||||
chunk = await self.queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
# 根据配置选择延迟模式
|
||||
if s4u_config.enable_dynamic_typing_delay:
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
else:
|
||||
delay = s4u_config.typing_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
|
||||
await bot_message.process()
|
||||
|
||||
await get_global_api().send_message(bot_message)
|
||||
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
|
||||
|
||||
message_segment = Seg(type="text", data=chunk)
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
await bot_message.process()
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
self.queue.task_done()
|
||||
|
||||
def start(self):
|
||||
"""启动发送任务。"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._send_worker())
|
||||
|
||||
async def join(self):
|
||||
"""等待所有消息发送完毕。"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
|
||||
class S4UChatManager:
|
||||
def __init__(self):
|
||||
self.s4u_chats: Dict[str, "S4UChat"] = {}
|
||||
|
||||
def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
|
||||
if chat_stream.stream_id not in self.s4u_chats:
|
||||
stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
|
||||
logger.info(f"Creating new S4UChat for stream: {stream_name}")
|
||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
|
||||
if not s4u_config.enable_s4u:
|
||||
s4u_chat_manager = None
|
||||
else:
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
return s4u_chat_manager
|
||||
|
||||
|
||||
class S4UChat:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
|
||||
# 两个消息队列
|
||||
self._vip_queue = asyncio.PriorityQueue()
|
||||
self._normal_queue = asyncio.PriorityQueue()
|
||||
|
||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: Optional[asyncio.Task] = None
|
||||
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
|
||||
self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None
|
||||
|
||||
self._is_replying = False
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.gpt.chat_stream = self.chat_stream
|
||||
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
|
||||
|
||||
self.internal_message: List[MessageRecvS4U] = []
|
||||
|
||||
self.msg_id = ""
|
||||
self.voice_done = ""
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
|
||||
def _get_priority_info(self, message: MessageRecv) -> dict:
|
||||
"""安全地从消息中提取和解析 priority_info"""
|
||||
priority_info_raw = message.priority_info
|
||||
priority_info = {}
|
||||
if isinstance(priority_info_raw, str):
|
||||
try:
|
||||
priority_info = json.loads(priority_info_raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
|
||||
elif isinstance(priority_info_raw, dict):
|
||||
priority_info = priority_info_raw
|
||||
return priority_info
|
||||
|
||||
def _is_vip(self, priority_info: dict) -> bool:
|
||||
"""检查消息是否来自VIP用户。"""
|
||||
return priority_info.get("message_type") == "vip"
|
||||
|
||||
def _get_interest_score(self, user_id: str) -> float:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
def go_processing(self):
|
||||
if self.voice_done == self.last_msg_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
|
||||
"""
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
|
||||
# 加上用户的固有兴趣分
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
def decay_interest_score(self):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
self.interest_dict[person_id] = score * 0.95
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
|
||||
self.decay_interest_score()
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
_person_id = get_person_id(platform, user_id)
|
||||
|
||||
# try:
|
||||
# is_gift = message.is_gift
|
||||
# is_superchat = message.is_superchat
|
||||
# # print(is_gift)
|
||||
# # print(is_superchat)
|
||||
# if is_gift:
|
||||
# await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
# current_score = self.interest_dict.get(person_id, 1.0)
|
||||
# self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
|
||||
# elif is_superchat:
|
||||
# await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
# current_score = self.interest_dict.get(person_id, 1.0)
|
||||
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
# # 添加SuperChat到管理器
|
||||
# super_chat_manager = get_super_chat_manager()
|
||||
# await super_chat_manager.add_superchat(message)
|
||||
# else:
|
||||
# await self.relationship_builder.build_relation(20)
|
||||
# except Exception:
|
||||
# traceback.print_exc()
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if (
|
||||
s4u_config.enable_message_interruption
|
||||
and self._current_generation_task
|
||||
and not self._current_generation_task.done()
|
||||
):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
# 规则:VIP从不被打断
|
||||
if current_queue == "vip":
|
||||
pass # Do nothing
|
||||
|
||||
# 规则:普通消息可以被打断
|
||||
elif current_queue == "normal":
|
||||
# VIP消息可以打断普通消息
|
||||
if is_vip:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||
# 普通消息的内部打断逻辑
|
||||
else:
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
current_sender_id = current_msg.message_info.user_info.user_id
|
||||
# 新消息优先级更高
|
||||
if new_priority_score > current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||
# 同用户,新消息的优先级不能更低
|
||||
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||
|
||||
if should_interrupt:
|
||||
if self.gpt.partial_response:
|
||||
logger.warning(
|
||||
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
|
||||
)
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
|
||||
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
|
||||
item = (-new_priority_score, self._entry_counter, time.time(), message)
|
||||
|
||||
if is_vip and s4u_config.vip_queue_priority:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
await self._normal_queue.put(item)
|
||||
|
||||
self._entry_counter += 1
|
||||
self._new_message_event.set() # 唤醒处理器
|
||||
|
||||
def _cleanup_old_normal_messages(self):
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
removed_count = 0
|
||||
|
||||
# 取出所有普通队列中的消息
|
||||
while not self._normal_queue.empty():
|
||||
try:
|
||||
item = self._normal_queue.get_nowait()
|
||||
neg_priority, entry_count, timestamp, message = item
|
||||
|
||||
# 如果消息在最近N条消息范围内,保留它
|
||||
logger.info(
|
||||
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
|
||||
)
|
||||
|
||||
if entry_count >= cutoff_counter:
|
||||
temp_messages.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
self._normal_queue.task_done() # 标记被移除的任务为完成
|
||||
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# 将保留的消息重新放入队列
|
||||
for item in temp_messages:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
|
||||
)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 清理普通队列中的过旧消息
|
||||
self._cleanup_old_normal_messages()
|
||||
|
||||
# 优先处理VIP队列
|
||||
if not self._vip_queue.empty():
|
||||
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > s4u_config.message_timeout_seconds:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
self._normal_queue.task_done()
|
||||
continue # 处理下一条
|
||||
queue_name = "normal"
|
||||
else:
|
||||
if self.internal_message:
|
||||
message = self.internal_message[-1]
|
||||
self.internal_message = []
|
||||
|
||||
priority = 0
|
||||
neg_priority = 0
|
||||
entry_count = 0
|
||||
queue_name = "internal"
|
||||
|
||||
logger.info(
|
||||
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
|
||||
)
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
|
||||
)
|
||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
self._current_message_being_replied = None
|
||||
# 标记任务完成
|
||||
if queue_name == "vip":
|
||||
self._vip_queue.task_done()
|
||||
elif queue_name == "internal":
|
||||
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
|
||||
pass
|
||||
else:
|
||||
self._normal_queue.task_done()
|
||||
|
||||
# 检查是否还有任务,有则立即再次触发事件
|
||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||
self._new_message_event.set()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def get_processing_message_id(self):
|
||||
self.last_msg_id = self.msg_id
|
||||
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
self.get_processing_message_id()
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
|
||||
if message.is_internal:
|
||||
await chat_watching.on_internal_message_start()
|
||||
else:
|
||||
await chat_watching.on_reply_start()
|
||||
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
|
||||
async def generate_and_send_inner():
|
||||
nonlocal total_chars_sent
|
||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||
|
||||
if s4u_config.enable_streaming_output:
|
||||
logger.info("[S4U] 开始流式输出")
|
||||
# 流式输出,边生成边发送
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
else:
|
||||
logger.info("[S4U] 开始一次性输出")
|
||||
# 一次性输出,先收集所有chunk
|
||||
all_chunks = []
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
all_chunks.append(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
# 一次性发送
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("".join(all_chunks))
|
||||
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("麦麦不知道哦")
|
||||
total_chars_sent = len("麦麦不知道哦")
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
|
||||
await yes_or_no_head(
|
||||
text=total_chars_sent,
|
||||
emotion=mood.mood_state,
|
||||
chat_history=message.processed_plain_text,
|
||||
chat_id=self.stream_id,
|
||||
)
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
await chat_watching.on_thinking_finished()
|
||||
|
||||
start_time = time.time()
|
||||
logged = False
|
||||
while not self.go_processing():
|
||||
if time.time() - start_time > 60:
|
||||
logger.warning(f"[{self.stream_name}] 等待消息发送超时(60秒),强制跳出循环。")
|
||||
break
|
||||
if not logged:
|
||||
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
|
||||
logged = True
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_finished()
|
||||
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container._task.done():
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||
|
||||
async def shutdown(self):
|
||||
"""平滑关闭处理任务。"""
|
||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||
|
||||
# 取消正在运行的任务
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
if self._processing_task and not self._processing_task.done():
|
||||
self._processing_task.cancel()
|
||||
|
||||
# 等待任务响应取消
|
||||
try:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
|
@ -1,456 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
"""
|
||||
情绪管理系统使用说明:
|
||||
|
||||
1. 情绪数值系统:
|
||||
- 情绪包含四个维度:joy(喜), anger(怒), sorrow(哀), fear(惧)
|
||||
- 每个维度的取值范围为1-10
|
||||
- 当情绪发生变化时,会自动发送到ws端处理
|
||||
|
||||
2. 情绪更新机制:
|
||||
- 接收到新消息时会更新情绪状态
|
||||
- 定期进行情绪回归(冷静下来)
|
||||
- 每次情绪变化都会发送到ws端,格式为:
|
||||
type: "emotion"
|
||||
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
3. ws端处理:
|
||||
- 本地只负责情绪计算和发送情绪数值
|
||||
- 表情渲染和动作由ws端根据情绪数值处理
|
||||
"""
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"change_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"regress_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"change_mood_numerical_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"regress_mood_numerical_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatMood:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.mood_state: str = "感觉很平静"
|
||||
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
|
||||
self.mood_model_numerical = LLMRequest(
|
||||
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
|
||||
)
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||
|
||||
def _parse_numerical_mood(self, response: str) -> dict[str, int] | None:
|
||||
try:
|
||||
# The LLM might output markdown with json inside
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
data = json.loads(response)
|
||||
|
||||
# Validate
|
||||
required_keys = {"joy", "anger", "sorrow", "fear"}
|
||||
if not required_keys.issubset(data.keys()):
|
||||
logger.warning(f"Numerical mood response missing keys: {response}")
|
||||
return None
|
||||
|
||||
for key in required_keys:
|
||||
value = data[key]
|
||||
if not isinstance(value, int) or not (1 <= value <= 10):
|
||||
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
|
||||
return None
|
||||
|
||||
return {key: data[key] for key in required_keys}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse numerical mood JSON: {response}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
|
||||
return None
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _update_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text mood response: {response}")
|
||||
logger.debug(f"text mood reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _update_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt, temperature=0.4
|
||||
)
|
||||
logger.info(f"numerical mood response: {response}")
|
||||
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.last_change_time = message_time
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=5,
|
||||
limit_mode="last",
|
||||
)
|
||||
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _regress_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text regress response: {response}")
|
||||
logger.debug(f"text regress reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _regress_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.4,
|
||||
)
|
||||
logger.info(f"numerical regress response: {response}")
|
||||
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.regression_count += 1
|
||||
|
||||
async def send_emotion_update(self, mood_values: dict[str, int]):
|
||||
"""发送情绪更新到ws端"""
|
||||
emotion_data = {
|
||||
"joy": mood_values.get("joy", 5),
|
||||
"anger": mood_values.get("anger", 1),
|
||||
"sorrow": mood_values.get("sorrow", 1),
|
||||
"fear": mood_values.get("fear", 1),
|
||||
}
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="emotion",
|
||||
content=emotion_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
||||
|
||||
|
||||
class MoodRegressionTask(AsyncTask):
|
||||
def __init__(self, mood_manager: "MoodManager"):
|
||||
super().__init__(task_name="MoodRegressionTask", run_interval=30)
|
||||
self.mood_manager = mood_manager
|
||||
self.run_count = 0
|
||||
|
||||
async def run(self):
|
||||
self.run_count += 1
|
||||
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
||||
|
||||
now = time.time()
|
||||
regression_executed = 0
|
||||
|
||||
for mood in self.mood_manager.mood_list:
|
||||
chat_info = f"chat {mood.chat_id}"
|
||||
|
||||
if mood.last_change_time == 0:
|
||||
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
||||
continue
|
||||
|
||||
time_since_last_change = now - mood.last_change_time
|
||||
|
||||
# 检查是否有极端情绪需要快速回归
|
||||
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
||||
has_extreme_emotion = len(high_emotions) > 0
|
||||
|
||||
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
||||
should_regress = False
|
||||
regress_reason = ""
|
||||
|
||||
if time_since_last_change > 120:
|
||||
should_regress = True
|
||||
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
||||
elif has_extreme_emotion and time_since_last_change > 30:
|
||||
should_regress = True
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
||||
|
||||
if should_regress:
|
||||
if mood.regression_count >= 3:
|
||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||
)
|
||||
await mood.regress_mood()
|
||||
regression_executed += 1
|
||||
else:
|
||||
if has_extreme_emotion:
|
||||
remaining_time = 5 - time_since_last_change
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
logger.debug(
|
||||
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||
)
|
||||
else:
|
||||
remaining_time = 120 - time_since_last_change
|
||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||
|
||||
if regression_executed > 0:
|
||||
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
||||
else:
|
||||
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
|
||||
|
||||
|
||||
class MoodManager:
|
||||
def __init__(self):
|
||||
self.mood_list: list[ChatMood] = []
|
||||
"""当前情绪状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动情绪回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动情绪管理任务...")
|
||||
|
||||
# 启动情绪回归任务
|
||||
regression_task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(regression_task)
|
||||
|
||||
self.task_started = True
|
||||
logger.info("情绪管理任务已启动(情绪回归)")
|
||||
|
||||
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
return mood
|
||||
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
return new_mood
|
||||
|
||||
def reset_mood_by_chat_id(self, chat_id: str):
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
mood.mood_state = "感觉很平静"
|
||||
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
mood.regression_count = 0
|
||||
# 发送重置后的情绪状态到ws端
|
||||
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
||||
return
|
||||
|
||||
# 如果没有找到现有的mood,创建新的
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
if s4u_config.enable_s4u:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
else:
|
||||
mood_manager = None
|
||||
|
||||
"""全局情绪管理器"""
|
||||
|
|
@ -1,255 +0,0 @@
|
|||
import asyncio
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from maim_message.message_base import GroupInfo
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
|
||||
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
|
||||
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
|
||||
from src.mais4u.mais4u_chat.gift_manager import gift_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned
|
||||
|
||||
|
||||
class S4UMessageProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
message_info = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message_info.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
if await self.handle_internal_message(message):
|
||||
return
|
||||
|
||||
if await self.hadle_if_voice_done(message):
|
||||
return
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
# 处理屏幕消息
|
||||
if await self.handle_screen_message(message):
|
||||
return
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_action.update_action_by_message(message))
|
||||
# 视线管理:收到消息时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
|
||||
await chat_watching.on_message_received()
|
||||
|
||||
# 上下文网页管理:启动独立task处理消息上下文
|
||||
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
|
||||
|
||||
# 日志记录
|
||||
if message.is_gift:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
async def handle_internal_message(self, message: MessageRecvS4U):
|
||||
if message.is_internal:
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
|
||||
)
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
message.message_info.group_info = s4u_chat.chat_stream.group_info
|
||||
message.message_info.platform = s4u_chat.chat_stream.platform
|
||||
|
||||
s4u_chat.internal_message.append(message)
|
||||
s4u_chat._new_message_event.set()
|
||||
|
||||
logger.info(
|
||||
f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
async def handle_screen_message(self, message: MessageRecvS4U):
|
||||
if message.is_screen:
|
||||
screen_manager.set_screen(message.screen_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def hadle_if_voice_done(self, message: MessageRecvS4U):
|
||||
if message.voice_done:
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
|
||||
s4u_chat.voice_done = message.voice_done
|
||||
return True
|
||||
return False
|
||||
|
||||
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.is_fake_gift = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
if message.is_gift:
|
||||
# 定义防抖完成后的回调函数
|
||||
def gift_callback(merged_message: MessageRecvS4U):
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
context_manager = get_context_web_manager()
|
||||
|
||||
# 只在服务器未启动时启动(避免重复启动)
|
||||
if context_manager.site is None:
|
||||
logger.info("🚀 首次启动上下文网页服务器...")
|
||||
await context_manager.start_server()
|
||||
|
||||
# 添加消息到上下文并更新网页
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
await context_manager.add_message(chat_id, message)
|
||||
|
||||
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)
|
||||
|
|
@ -1,400 +0,0 @@
|
|||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
import random
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.person_info.person_info import Person, get_person_id
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from typing import List
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
|
||||
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
{screen_info}
|
||||
{internal_state}
|
||||
|
||||
{relation_info_block}
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{sc_info}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
对方最新发送的内容:{message_txt}
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
你可以看见面前的屏幕,目前屏幕的内容是:
|
||||
{screen_info}
|
||||
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
{sc_info}
|
||||
|
||||
{time_block}
|
||||
{chat_info_danmu}
|
||||
--------------------------------
|
||||
以上是你和弹幕的对话,与此同时,你在与QQ群友聊天,聊天记录如下:
|
||||
{chat_info_qq}
|
||||
--------------------------------
|
||||
你刚刚回复了QQ群,你内心的想法是:{mind}
|
||||
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt_internal", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
|
||||
style_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
|
||||
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
async def build_relation_info(self, chat_stream) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
who_chat_in_group = []
|
||||
if is_group_chat:
|
||||
who_chat_in_group = get_recent_group_speaker(
|
||||
chat_stream.stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
elif chat_stream.user_info:
|
||||
who_chat_in_group.append(
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.relationship.enable_relationship and who_chat_in_group:
|
||||
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||
person_ids = []
|
||||
for person in who_chat_in_group:
|
||||
person_id = get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
|
||||
relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
|
||||
if relation_info := "".join(relation_info_list):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=relation_info
|
||||
)
|
||||
return relation_prompt
|
||||
|
||||
async def build_memory_block(self, text: str) -> str:
|
||||
# 待更新记忆系统
|
||||
return ""
|
||||
|
||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
|
||||
return ""
|
||||
|
||||
def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
|
||||
limit=300,
|
||||
)
|
||||
|
||||
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
|
||||
|
||||
core_dialogue_list: List[DatabaseMessages] = []
|
||||
background_dialogue_list: List[DatabaseMessages] = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||
|
||||
for msg in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg.user_info.user_id)
|
||||
if msg_user_id == bot_id:
|
||||
if msg.reply_to and talk_type == msg.reply_to:
|
||||
core_dialogue_list.append(msg)
|
||||
elif msg.reply_to and talk_type != msg.reply_to:
|
||||
background_dialogue_list.append(msg)
|
||||
# else:
|
||||
# background_dialogue_list.append(msg_dict)
|
||||
elif msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg)
|
||||
else:
|
||||
background_dialogue_list.append(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
|
||||
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||
|
||||
core_msg_str = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.user_info.user_id
|
||||
if start_speaking_user_id == bot_id:
|
||||
last_speaking_user_id = bot_id
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
start_speaking_user_id = target_user_id
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += (
|
||||
f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
|
||||
)
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.user_info.user_id
|
||||
if speaker == last_speaking_user_id:
|
||||
msg_seg_str += (
|
||||
f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
|
||||
)
|
||||
else:
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
if speaker == bot_id:
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
last_speaking_user_id = speaker
|
||||
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
for msg in all_msg_seg_list:
|
||||
core_msg_str += msg
|
||||
|
||||
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20,
|
||||
)
|
||||
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
all_dialogue_history,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
|
||||
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
|
||||
|
||||
def build_gift_info(self, message: MessageRecvS4U):
|
||||
if message.is_gift:
|
||||
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
|
||||
else:
|
||||
if message.is_fake_gift:
|
||||
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
|
||||
|
||||
return ""
|
||||
|
||||
def build_sc_info(self, message: MessageRecvS4U):
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message: MessageRecvS4U,
|
||||
message_txt: str,
|
||||
) -> str:
|
||||
chat_stream = message.chat_stream
|
||||
|
||||
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
|
||||
person_name = person.person_name
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
if person_name:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
||||
self.build_relation_info(chat_stream),
|
||||
self.build_memory_block(message_txt),
|
||||
self.build_expression_habits(chat_stream, message_txt, sender_name),
|
||||
)
|
||||
|
||||
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
|
||||
chat_stream, message
|
||||
)
|
||||
|
||||
gift_info = self.build_gift_info(message)
|
||||
|
||||
sc_info = self.build_sc_info(message)
|
||||
|
||||
screen_info = screen_manager.get_screen_str()
|
||||
|
||||
internal_state = internal_manager.get_internal_state_str()
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
|
||||
if not message.is_internal:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
internal_state=internal_state,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
sender_name=sender_name,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
else:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"s4u_prompt_internal",
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
chat_info_danmu=all_dialogue_prompt,
|
||||
chat_info_qq=message.chat_info,
|
||||
mind=message.processed_plain_text,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights, strict=False))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
prompt_builder = PromptBuilder()
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
from typing import AsyncGenerator
|
||||
from src.llm_models.utils_model import LLMRequest, RequestType
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.common.logger import get_logger
|
||||
import re
|
||||
|
||||
|
||||
logger = get_logger("s4u_stream_generator")
|
||||
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
# 使用LLMRequest替代AsyncOpenAIClient
|
||||
self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
self.partial_response = ""
|
||||
|
||||
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||
self.sentence_split_pattern = re.compile(
|
||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
self.chat_stream = None
|
||||
|
||||
async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
|
||||
# person_id = PersonInfoManager.get_person_id(
|
||||
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
# )
|
||||
# person_info_manager = get_person_info_manager()
|
||||
# person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# if message.chat_stream.user_info.user_nickname:
|
||||
# if person_name:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
# else:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
# else:
|
||||
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
message_txt = f"""
|
||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||
---
|
||||
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
|
||||
{message.processed_plain_text}
|
||||
"""
|
||||
return True, message_txt
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
return False, message_txt
|
||||
|
||||
async def generate_response(
|
||||
self, message: MessageRecvS4U, previous_reply_context: str = ""
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
self.partial_response = ""
|
||||
message_txt = message.processed_plain_text
|
||||
if not message.is_internal:
|
||||
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
|
||||
if interupted:
|
||||
message_txt = message_txt_added
|
||||
|
||||
message.chat_stream = self.chat_stream
|
||||
prompt = await prompt_builder.build_prompt_normal(
|
||||
message=message,
|
||||
message_txt=message_txt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
|
||||
) # noqa: E501
|
||||
|
||||
# 使用LLMRequest进行流式生成
|
||||
async for chunk in self._generate_response_with_llm_request(prompt):
|
||||
yield chunk
|
||||
|
||||
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
|
||||
"""使用LLMRequest进行流式响应生成"""
|
||||
|
||||
# 构建消息
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
# 选择模型
|
||||
model_info, api_provider, client = self.llm_request._select_model()
|
||||
self.current_model_name = model_info.name
|
||||
|
||||
# 如果模型支持强制流式模式,使用真正的流式处理
|
||||
if model_info.force_stream_mode:
|
||||
# 简化流式处理:直接使用LLMRequest的流式功能
|
||||
try:
|
||||
# 直接调用LLMRequest的流式处理
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
# 处理响应内容
|
||||
content = response.content or ""
|
||||
if content:
|
||||
# 将内容按句子分割并输出
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式请求执行失败: {e}")
|
||||
# 如果流式请求失败,回退到普通模式
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
else:
|
||||
# 如果不支持流式,使用普通方式然后模拟流式输出
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
async def _process_buffer_streaming(self, buffer: str) -> AsyncGenerator[str, None]:
|
||||
"""实时处理缓冲区内容,输出完整句子"""
|
||||
# 使用正则表达式匹配完整句子
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence and match.end(0) <= len(buffer):
|
||||
# 检查句子是否完整(以标点符号结尾)
|
||||
if sentence.endswith(("。", "!", "?", ".", "!", "?")):
|
||||
if sentence not in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
self.partial_response += sentence
|
||||
yield sentence
|
||||
|
||||
async def _process_content_streaming(self, content: str) -> AsyncGenerator[str, None]:
|
||||
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
|
||||
buffer = content
|
||||
punctuation_buffer = ""
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 检查是否只是一个标点符号
|
||||
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
punctuation_buffer += sentence
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
remaining = buffer[last_match_end:].strip()
|
||||
to_yield = (punctuation_buffer + remaining).strip()
|
||||
if to_yield:
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
if to_yield:
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""保留原有方法签名以保持兼容性,但重定向到新的实现"""
|
||||
async for chunk in self._generate_response_with_llm_request(prompt):
|
||||
yield chunk
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
"""
|
||||
视线管理系统使用说明:
|
||||
|
||||
1. 视线状态:
|
||||
- wandering: 随意看
|
||||
- danmu: 看弹幕
|
||||
- lens: 看镜头
|
||||
|
||||
2. 状态切换逻辑:
|
||||
- 收到消息时 → 切换为看弹幕,立即发送更新
|
||||
- 开始生成回复时 → 切换为看镜头或随意,立即发送更新
|
||||
- 生成完毕后 → 看弹幕1秒,然后回到看镜头直到有新消息,状态变化时立即发送更新
|
||||
|
||||
3. 使用方法:
|
||||
# 获取视线管理器
|
||||
watching = watching_manager.get_watching_by_chat_id(chat_id)
|
||||
|
||||
# 收到消息时调用
|
||||
await watching.on_message_received()
|
||||
|
||||
# 开始生成回复时调用
|
||||
await watching.on_reply_start()
|
||||
|
||||
# 生成回复完毕时调用
|
||||
await watching.on_reply_finished()
|
||||
|
||||
4. 自动更新系统:
|
||||
- 状态变化时立即发送type为"watching",data为状态值的websocket消息
|
||||
- 使用定时器自动处理状态转换(如看弹幕时间结束后自动切换到看镜头)
|
||||
- 无需定期检查,所有状态变化都是事件驱动的
|
||||
"""
|
||||
|
||||
logger = get_logger("watching")
|
||||
|
||||
HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
|
||||
class ChatWatching:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
|
||||
async def on_reply_start(self):
|
||||
"""开始生成回复时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_reply_finished(self):
|
||||
"""生成回复完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_thinking_finished(self):
|
||||
"""思考完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_message_received(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_internal_message_start(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
class WatchingManager:
|
||||
def __init__(self):
|
||||
self.watching_list: list[ChatWatching] = []
|
||||
"""当前视线状态列表"""
|
||||
self.task_started: bool = False
|
||||
|
||||
def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching:
|
||||
"""获取或创建聊天对应的视线管理器"""
|
||||
for watching in self.watching_list:
|
||||
if watching.chat_id == chat_id:
|
||||
return watching
|
||||
|
||||
new_watching = ChatWatching(chat_id)
|
||||
self.watching_list.append(new_watching)
|
||||
logger.info(f"为chat {chat_id}创建新的视线管理器")
|
||||
|
||||
return new_watching
|
||||
|
||||
|
||||
# 全局视线管理器实例
|
||||
watching_manager = WatchingManager()
|
||||
"""全局视线管理器"""
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
class ScreenManager:
|
||||
def __init__(self):
|
||||
self.now_screen = str()
|
||||
|
||||
def set_screen(self, screen_str: str):
|
||||
self.now_screen = screen_str
|
||||
|
||||
def get_screen(self):
|
||||
return self.now_screen
|
||||
|
||||
def get_screen_str(self):
|
||||
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
|
||||
|
||||
|
||||
screen_manager = ScreenManager()
|
||||
|
|
@ -1,303 +0,0 @@
|
|||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
logger = get_logger("super_chat_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuperChatRecord:
|
||||
"""SuperChat记录数据类"""
|
||||
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
platform: str
|
||||
chat_id: str
|
||||
price: float
|
||||
message_text: str
|
||||
timestamp: float
|
||||
expire_time: float
|
||||
group_name: Optional[str] = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查SuperChat是否已过期"""
|
||||
return time.time() > self.expire_time
|
||||
|
||||
def remaining_time(self) -> float:
|
||||
"""获取剩余时间(秒)"""
|
||||
return max(0, self.expire_time - time.time())
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_nickname": self.user_nickname,
|
||||
"platform": self.platform,
|
||||
"chat_id": self.chat_id,
|
||||
"price": self.price,
|
||||
"message_text": self.message_text,
|
||||
"timestamp": self.timestamp,
|
||||
"expire_time": self.expire_time,
|
||||
"group_name": self.group_name,
|
||||
"remaining_time": self.remaining_time(),
|
||||
}
|
||||
|
||||
|
||||
class SuperChatManager:
|
||||
"""SuperChat管理器,负责管理和跟踪SuperChat消息"""
|
||||
|
||||
def __init__(self):
|
||||
self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._is_initialized = False
|
||||
logger.info("SuperChat管理器已初始化")
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保清理任务已启动(延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
|
||||
self._is_initialized = True
|
||||
logger.info("SuperChat清理任务已启动")
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后再启动
|
||||
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动清理任务(已弃用,保留向后兼容)"""
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
async def _cleanup_expired_superchats(self):
|
||||
"""定期清理过期的SuperChat"""
|
||||
while True:
|
||||
try:
|
||||
total_removed = 0
|
||||
|
||||
for chat_id in list(self.super_chats.keys()):
|
||||
original_count = len(self.super_chats[chat_id])
|
||||
# 移除过期的SuperChat
|
||||
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
|
||||
removed_count = original_count - len(self.super_chats[chat_id])
|
||||
total_removed += removed_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
|
||||
|
||||
# 如果列表为空,删除该聊天的记录
|
||||
if not self.super_chats[chat_id]:
|
||||
del self.super_chats[chat_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
|
||||
|
||||
# 每30秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(60) # 出错时等待更长时间
|
||||
|
||||
def _calculate_expire_time(self, price: float) -> float:
|
||||
"""根据SuperChat金额计算过期时间"""
|
||||
current_time = time.time()
|
||||
|
||||
# 根据金额阶梯设置不同的存活时间
|
||||
if price >= 500:
|
||||
# 500元以上:保持4小时
|
||||
duration = 4 * 3600
|
||||
elif price >= 200:
|
||||
# 200-499元:保持2小时
|
||||
duration = 2 * 3600
|
||||
elif price >= 100:
|
||||
# 100-199元:保持1小时
|
||||
duration = 1 * 3600
|
||||
elif price >= 50:
|
||||
# 50-99元:保持30分钟
|
||||
duration = 30 * 60
|
||||
elif price >= 20:
|
||||
# 20-49元:保持15分钟
|
||||
duration = 15 * 60
|
||||
elif price >= 10:
|
||||
# 10-19元:保持10分钟
|
||||
duration = 10 * 60
|
||||
else:
|
||||
# 10元以下:保持5分钟
|
||||
duration = 5 * 60
|
||||
|
||||
return current_time + duration
|
||||
|
||||
async def add_superchat(self, message: MessageRecvS4U) -> None:
|
||||
"""添加新的SuperChat记录"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if not message.is_superchat or not message.superchat_price:
|
||||
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
|
||||
return
|
||||
|
||||
try:
|
||||
price = float(message.superchat_price)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
|
||||
return
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
group_info = message.message_info.group_info
|
||||
chat_id = getattr(message, "chat_stream", None)
|
||||
if chat_id:
|
||||
chat_id = chat_id.stream_id
|
||||
else:
|
||||
# 生成chat_id的备用方法
|
||||
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
|
||||
if group_info:
|
||||
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
|
||||
|
||||
expire_time = self._calculate_expire_time(price)
|
||||
|
||||
record = SuperChatRecord(
|
||||
user_id=user_info.user_id,
|
||||
user_nickname=user_info.user_nickname,
|
||||
platform=message.message_info.platform,
|
||||
chat_id=chat_id,
|
||||
price=price,
|
||||
message_text=message.superchat_message_text or "",
|
||||
timestamp=message.message_info.time,
|
||||
expire_time=expire_time,
|
||||
group_name=group_info.group_name if group_info else None,
|
||||
)
|
||||
|
||||
# 添加到对应聊天的SuperChat列表
|
||||
if chat_id not in self.super_chats:
|
||||
self.super_chats[chat_id] = []
|
||||
|
||||
self.super_chats[chat_id].append(record)
|
||||
|
||||
# 按价格降序排序(价格高的在前)
|
||||
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
|
||||
|
||||
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
|
||||
|
||||
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
|
||||
"""获取指定聊天的所有有效SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if chat_id not in self.super_chats:
|
||||
return []
|
||||
|
||||
# 过滤掉过期的SuperChat
|
||||
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
return valid_superchats
|
||||
|
||||
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
|
||||
"""获取所有有效的SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
result = {}
|
||||
for chat_id, superchats in self.super_chats.items():
|
||||
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
|
||||
if valid_superchats:
|
||||
result[chat_id] = valid_superchats
|
||||
return result
|
||||
|
||||
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
|
||||
"""构建SuperChat显示字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return ""
|
||||
|
||||
# 限制显示数量
|
||||
display_superchats = superchats[:max_count]
|
||||
|
||||
lines = ["📢 当前有效超级弹幕:"]
|
||||
for i, sc in enumerate(display_superchats, 1):
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
|
||||
time_display = (
|
||||
f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
)
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
line = f"{line[:97]}..."
|
||||
line += f" (剩余{time_display})"
|
||||
lines.append(line)
|
||||
|
||||
if len(superchats) > max_count:
|
||||
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_superchat_summary_string(self, chat_id: str) -> str:
|
||||
"""构建SuperChat摘要字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return "当前没有有效的超级弹幕"
|
||||
lines = []
|
||||
for sc in superchats:
|
||||
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
|
||||
if len(single_sc_str) > 100:
|
||||
single_sc_str = f"{single_sc_str[:97]}..."
|
||||
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
|
||||
lines.append(single_sc_str)
|
||||
|
||||
total_amount = sum(sc.price for sc in superchats)
|
||||
count = len(superchats)
|
||||
highest_amount = max(sc.price for sc in superchats)
|
||||
|
||||
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
|
||||
if lines:
|
||||
final_str += "\n" + "\n".join(lines)
|
||||
return final_str
|
||||
|
||||
def get_superchat_statistics(self, chat_id: str) -> dict:
|
||||
"""获取SuperChat统计信息"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
|
||||
|
||||
amounts = [sc.price for sc in superchats]
|
||||
|
||||
return {
|
||||
"count": len(superchats),
|
||||
"total_amount": sum(amounts),
|
||||
"average_amount": sum(amounts) / len(amounts),
|
||||
"highest_amount": max(amounts),
|
||||
"lowest_amount": min(amounts),
|
||||
}
|
||||
|
||||
async def shutdown(self): # sourcery skip: use-contextlib-suppress
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
# sourcery skip: assign-if-exp
|
||||
if s4u_config.enable_s4u:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
super_chat_manager = None
|
||||
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
|
||||
return super_chat_manager
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
|
||||
|
||||
|
||||
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
|
||||
prompt = f"""
|
||||
{chat_history}
|
||||
以上是对方的发言:
|
||||
|
||||
对这个发言,你的心情是:{emotion}
|
||||
对上面的发言,你的回复是:{text}
|
||||
请判断时是否要伴随回复做头部动作,你可以选择:
|
||||
|
||||
不做额外动作
|
||||
点头一次
|
||||
点头两次
|
||||
摇头
|
||||
歪脑袋
|
||||
低头望向一边
|
||||
|
||||
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
|
||||
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
try:
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
|
||||
logger.info(f"response: {response}")
|
||||
|
||||
head_action = response if response in head_actions_list else "不做额外动作"
|
||||
await send_api.custom_to_stream(
|
||||
message_type="head_action",
|
||||
content=head_action,
|
||||
stream_id=chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"yes_or_no_head error: {e}")
|
||||
return "不做额外动作"
|
||||
|
|
@ -1,368 +0,0 @@
|
|||
import os
|
||||
import tomlkit
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from dataclasses import dataclass, fields, MISSING, field
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
|
||||
|
||||
# 新增:兼容dict和tomlkit Table
|
||||
def is_dict_like(obj):
|
||||
return isinstance(obj, (dict, Table))
|
||||
|
||||
|
||||
# 新增:递归将Table转为dict
|
||||
def table_to_dict(obj):
|
||||
if isinstance(obj, Table):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, dict):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [table_to_dict(i) for i in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
# 获取mais4u模块目录
|
||||
MAIS4U_ROOT = os.path.dirname(__file__)
|
||||
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
|
||||
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
|
||||
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
|
||||
|
||||
# S4U配置版本
|
||||
S4U_VERSION = "1.1.0"
|
||||
|
||||
T = TypeVar("T", bound="S4UConfigBase")
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfigBase:
|
||||
"""S4U配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
data = table_to_dict(data) # 递归转dict,兼容tomlkit Table
|
||||
if not is_dict_like(data):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""转换字段值为指定类型"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
return field_type.from_dict(value)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_type_args = get_args(field_type)
|
||||
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
if (
|
||||
field_type_args
|
||||
and isinstance(field_type_args[0], type)
|
||||
and issubclass(field_type_args[0], S4UConfigBase)
|
||||
):
|
||||
return [field_type_args[0].from_dict(item) for item in value]
|
||||
return [cls._convert_field(item, field_type_args[0]) for item in value]
|
||||
elif field_origin_type is set:
|
||||
return {cls._convert_field(item, field_type_args[0]) for item in value}
|
||||
elif field_origin_type is tuple:
|
||||
if len(value) != len(field_type_args):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_type_args
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理基础类型,例如 int, str 等
|
||||
if field_origin_type is type(None) and value is None: # 处理Optional类型
|
||||
return None
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal or get_origin(field_type) is Literal:
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
if field_type is Any or isinstance(value, field_type):
|
||||
return value
|
||||
|
||||
# 其他类型,尝试直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UModelConfig(S4UConfigBase):
|
||||
"""S4U模型配置类"""
|
||||
|
||||
# 主要对话模型配置
|
||||
chat: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""主要对话模型配置"""
|
||||
|
||||
# 规划模型配置(原model_motion)
|
||||
motion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""规划模型配置"""
|
||||
|
||||
# 情感分析模型配置
|
||||
emotion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""情感分析模型配置"""
|
||||
|
||||
# 记忆模型配置
|
||||
memory: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""记忆模型配置"""
|
||||
|
||||
# 工具使用模型配置
|
||||
tool_use: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""工具使用模型配置"""
|
||||
|
||||
# 嵌入模型配置
|
||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""嵌入模型配置"""
|
||||
|
||||
# 视觉语言模型配置
|
||||
vlm: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
# 知识库模型配置
|
||||
knowledge: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""知识库模型配置"""
|
||||
|
||||
# 实体提取模型配置
|
||||
entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""实体提取模型配置"""
|
||||
|
||||
# 问答模型配置
|
||||
qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""问答模型配置"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfig(S4UConfigBase):
|
||||
"""S4U聊天系统配置类"""
|
||||
|
||||
enable_s4u: bool = False
|
||||
"""是否启用S4U聊天系统"""
|
||||
|
||||
message_timeout_seconds: int = 120
|
||||
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
|
||||
|
||||
at_bot_priority_bonus: float = 100.0
|
||||
"""@机器人时的优先级加成分数"""
|
||||
|
||||
recent_message_keep_count: int = 6
|
||||
"""保留最近N条消息,超出范围的普通消息将被移除"""
|
||||
|
||||
typing_delay: float = 0.1
|
||||
"""打字延迟时间(秒),模拟真实打字速度"""
|
||||
|
||||
chars_per_second: float = 15.0
|
||||
"""每秒字符数,用于计算动态打字延迟"""
|
||||
|
||||
min_typing_delay: float = 0.2
|
||||
"""最小打字延迟(秒)"""
|
||||
|
||||
max_typing_delay: float = 2.0
|
||||
"""最大打字延迟(秒)"""
|
||||
|
||||
enable_dynamic_typing_delay: bool = False
|
||||
"""是否启用基于文本长度的动态打字延迟"""
|
||||
|
||||
vip_queue_priority: bool = True
|
||||
"""是否启用VIP队列优先级系统"""
|
||||
|
||||
enable_message_interruption: bool = True
|
||||
"""是否允许高优先级消息中断当前回复"""
|
||||
|
||||
enable_old_message_cleanup: bool = True
|
||||
"""是否自动清理过旧的普通消息"""
|
||||
|
||||
enable_streaming_output: bool = True
|
||||
"""是否启用流式输出,false时全部生成后一次性发送"""
|
||||
|
||||
max_context_message_length: int = 20
|
||||
"""上下文消息最大长度"""
|
||||
|
||||
max_core_message_length: int = 30
|
||||
"""核心消息最大长度"""
|
||||
|
||||
# 模型配置
|
||||
models: S4UModelConfig = field(default_factory=S4UModelConfig)
|
||||
"""S4U模型配置"""
|
||||
|
||||
# 兼容性字段,保持向后兼容
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UGlobalConfig(S4UConfigBase):
|
||||
"""S4U总配置类"""
|
||||
|
||||
s4u: S4UConfig
|
||||
S4U_VERSION: str = S4U_VERSION
|
||||
|
||||
|
||||
def update_s4u_config():
|
||||
"""更新S4U配置文件"""
|
||||
# 创建配置目录(如果不存在)
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not os.path.exists(TEMPLATE_PATH):
|
||||
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
logger.error("请确保模板文件存在后重新运行")
|
||||
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
logger.info("S4U配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
|
||||
return
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份目录
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(CONFIG_PATH, old_backup_path)
|
||||
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||
update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并S4U新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
|
||||
logger.info("S4U配置文件更新完成")
|
||||
|
||||
|
||||
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
"""
|
||||
加载S4U配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: S4UGlobalConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建S4UGlobalConfig对象
|
||||
try:
|
||||
return S4UGlobalConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("S4U配置文件解析失败")
|
||||
raise e
|
||||
|
||||
# 初始化S4U配置
|
||||
|
||||
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
|
|
@ -0,0 +1,797 @@
|
|||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import random
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.message_api import build_readable_messages
|
||||
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
|
||||
from json_repair import repair_json
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
|
||||
from .memory_utils import (
|
||||
find_best_matching_memory,
|
||||
check_title_exists_fuzzy,
|
||||
get_all_titles,
|
||||
get_memory_titles_by_chat_id_weighted,
|
||||
|
||||
)
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
class MemoryChest:
|
||||
def __init__(self):
|
||||
|
||||
self.LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory_chest",
|
||||
)
|
||||
|
||||
self.LLMRequest_build = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="memory_chest_build",
|
||||
)
|
||||
|
||||
|
||||
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
|
||||
self.fetched_memory_list = [] # [(chat_id, (question, answer, timestamp)), ...]
|
||||
|
||||
async def build_running_content(self, chat_id: str = None) -> str:
|
||||
"""
|
||||
构建记忆仓库的运行内容
|
||||
|
||||
Args:
|
||||
message_str: 消息内容
|
||||
chat_id: 聊天ID,用于提取对应的运行内容
|
||||
|
||||
Returns:
|
||||
str: 构建后的运行内容
|
||||
"""
|
||||
# 检查是否需要更新:基于消息数量和最新消息时间差的智能更新机制
|
||||
#
|
||||
# 更新机制说明:
|
||||
# 1. 消息数量 > 100:直接触发更新(高频消息场景)
|
||||
# 2. 消息数量 > 70 且最新消息时间差 > 30秒:触发更新(中高频消息场景)
|
||||
# 3. 消息数量 > 50 且最新消息时间差 > 60秒:触发更新(中频消息场景)
|
||||
# 4. 消息数量 > 30 且最新消息时间差 > 300秒:触发更新(低频消息场景)
|
||||
#
|
||||
# 设计理念:
|
||||
# - 消息越密集,时间阈值越短,确保及时更新记忆
|
||||
# - 消息越稀疏,时间阈值越长,避免频繁无意义的更新
|
||||
# - 通过最新消息时间差判断消息活跃度,而非简单的总时间差
|
||||
# - 平衡更新频率与性能,在保证记忆及时性的同时减少计算开销
|
||||
if chat_id not in self.running_content_list:
|
||||
self.running_content_list[chat_id] = {
|
||||
"content": "",
|
||||
"last_update_time": time.time(),
|
||||
"create_time": time.time()
|
||||
}
|
||||
|
||||
should_update = True
|
||||
if chat_id and chat_id in self.running_content_list:
|
||||
last_update_time = self.running_content_list[chat_id]["last_update_time"]
|
||||
current_time = time.time()
|
||||
# 使用message_api获取消息数量
|
||||
message_list = get_raw_msg_by_timestamp_with_chat(
|
||||
timestamp_start=last_update_time,
|
||||
timestamp_end=current_time,
|
||||
chat_id=chat_id,
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
new_messages_count = len(message_list)
|
||||
|
||||
# 获取最新消息的时间戳
|
||||
latest_message_time = last_update_time
|
||||
if message_list:
|
||||
# 假设消息列表按时间排序,取最后一条消息的时间戳
|
||||
latest_message = message_list[-1]
|
||||
if hasattr(latest_message, 'timestamp'):
|
||||
latest_message_time = latest_message.timestamp
|
||||
elif isinstance(latest_message, dict) and 'timestamp' in latest_message:
|
||||
latest_message_time = latest_message['timestamp']
|
||||
|
||||
# 计算最新消息时间与现在时间的差(秒)
|
||||
latest_message_time_diff = current_time - latest_message_time
|
||||
|
||||
# 智能更新条件判断 - 按优先级从高到低检查
|
||||
should_update = False
|
||||
update_reason = ""
|
||||
|
||||
if global_config.memory.memory_build_frequency > 0:
|
||||
if new_messages_count > 100/global_config.memory.memory_build_frequency:
|
||||
# 条件1:消息数量 > 100,直接触发更新
|
||||
# 适用场景:群聊刷屏、高频讨论等消息密集场景
|
||||
# 无需时间限制,确保重要信息不被遗漏
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 100,直接触发更新"
|
||||
elif new_messages_count > 70/global_config.memory.memory_build_frequency and latest_message_time_diff > 30:
|
||||
# 条件2:消息数量 > 70 且最新消息时间差 > 30秒
|
||||
# 适用场景:中高频讨论,但需要确保消息流已稳定
|
||||
# 30秒的时间差确保不是正在进行的实时对话
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 70 且最新消息时间差 {latest_message_time_diff:.1f}s > 30s"
|
||||
elif new_messages_count > 50/global_config.memory.memory_build_frequency and latest_message_time_diff > 60:
|
||||
# 条件3:消息数量 > 50 且最新消息时间差 > 60秒
|
||||
# 适用场景:中等频率讨论,等待1分钟确保对话告一段落
|
||||
# 平衡及时性与稳定性
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 50 且最新消息时间差 {latest_message_time_diff:.1f}s > 60s"
|
||||
elif new_messages_count > 30/global_config.memory.memory_build_frequency and latest_message_time_diff > 300:
|
||||
# 条件4:消息数量 > 30 且最新消息时间差 > 300秒(5分钟)
|
||||
# 适用场景:低频但有一定信息量的讨论
|
||||
# 5分钟的时间差确保对话完全结束,避免频繁更新
|
||||
should_update = True
|
||||
update_reason = f"消息数量 {new_messages_count} > 30 且最新消息时间差 {latest_message_time_diff:.1f}s > 300s"
|
||||
|
||||
logger.debug(f"chat_id {chat_id} 更新检查: {update_reason if should_update else f'消息数量 {new_messages_count},最新消息时间差 {latest_message_time_diff:.1f}s,不满足更新条件'}")
|
||||
|
||||
|
||||
if should_update:
|
||||
# 如果有chat_id,先提取对应的running_content
|
||||
message_str = build_readable_messages(
|
||||
message_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
remove_emoji_stickers=True,
|
||||
)
|
||||
|
||||
# 随机从格式示例列表中选取若干行用于提示
|
||||
format_candidates = [
|
||||
"[概念] 是 [概念的含义(简短描述,不超过十个字)]",
|
||||
"[概念] 不是 [对概念的负面含义(简短描述,不超过十个字)]",
|
||||
"[概念1] 与 [概念2] 是 [概念1和概念2的关联(简短描述,不超过二十个字)]",
|
||||
"[概念1] 包含 [概念2] 和 [概念3]",
|
||||
"[概念1] 属于 [概念2]",
|
||||
"[概念1] 的例子是 [例子1] 和 [例子2]",
|
||||
"[概念] 的特征是 [特征1]、[特征2]",
|
||||
"[概念1] 导致 [概念2]",
|
||||
"[概念1] 需要 [条件1] 和 [条件2]",
|
||||
"[概念1] 的用途是 [用途1] 和 [用途2]",
|
||||
"[概念1] 与 [概念2] 的区别是 [区别点]",
|
||||
"[概念] 的别名是 [别名]",
|
||||
"[概念1] 包括但不限于 [概念2]、[概念3]",
|
||||
"[概念] 的反义是 [反义概念]",
|
||||
"[概念] 的组成有 [部分1]、[部分2]",
|
||||
"[概念] 出现于 [时间或场景]",
|
||||
"[概念] 的方法有 [方法1]、[方法2]",
|
||||
]
|
||||
|
||||
selected_count = random.randint(3, 6)
|
||||
selected_lines = random.sample(format_candidates, selected_count)
|
||||
format_section = "\n".join(selected_lines) + "\n......(不要包含中括号)"
|
||||
|
||||
prompt = f"""
|
||||
以下是一段你参与的聊天记录,请你在其中总结出记忆:
|
||||
|
||||
<聊天记录>
|
||||
{message_str}
|
||||
</聊天记录>
|
||||
聊天记录中可能包含有效信息,也可能信息密度很低,请你根据聊天记录中的信息,总结出记忆内容
|
||||
--------------------------------
|
||||
对[图片]的处理:
|
||||
1.除非与文本有关,不要将[图片]的内容整合到记忆中
|
||||
2.如果图片与某个概念相关,将图片中的关键内容也整合到记忆中,不要写入图片原文,例如:
|
||||
|
||||
聊天记录(与图片有关):
|
||||
用户说:[图片1:这是一个黄色的龙形状玩偶,被一只手拿着。]
|
||||
用户说:这个玩偶看起来很可爱,是我新买的奶龙
|
||||
总结的记忆内容:
|
||||
黄色的龙形状玩偶 是 奶龙
|
||||
|
||||
聊天记录(概念与图片无关):
|
||||
用户说:[图片1:这是一个台电脑,屏幕上显示了某种游戏。]
|
||||
用户说:使命召唤今天发售了新一代,有没有人玩
|
||||
总结的记忆内容:
|
||||
使命召唤新一代 是 最新发售的游戏
|
||||
|
||||
请主要关注概念和知识或者时效性较强的信息!!,而不是聊天的琐事
|
||||
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
|
||||
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
|
||||
3.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
|
||||
|
||||
记忆内容的格式,你必须仿照下面的格式,但不一定全部使用:
|
||||
{format_section}
|
||||
|
||||
请仿照上述格式输出,每个知识点一句话。输出成一段平文本
|
||||
现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库构建运行内容 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库构建运行内容 prompt: {prompt}")
|
||||
|
||||
running_content, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
print(f"prompt: {prompt}\n记忆仓库构建运行内容: {running_content}")
|
||||
|
||||
# 直接保存:每次构建后立即入库,并刷新时间戳窗口
|
||||
if chat_id and running_content:
|
||||
await self._save_to_database_and_clear(chat_id, running_content)
|
||||
|
||||
|
||||
return running_content
|
||||
|
||||
|
||||
async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
|
||||
"""
|
||||
根据问题获取答案
|
||||
"""
|
||||
logger.info(f"正在回忆问题答案: {question}")
|
||||
|
||||
title = await self.select_title_by_question(question)
|
||||
|
||||
if not title:
|
||||
return ""
|
||||
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title == title:
|
||||
content = memory.content
|
||||
|
||||
if random.random() < 0.5:
|
||||
type = "要求原文能够较为全面的回答问题"
|
||||
else:
|
||||
type = "要求提取简短的内容"
|
||||
|
||||
prompt = f"""
|
||||
目标文段:
|
||||
{content}
|
||||
|
||||
你现在需要从目标文段中找出合适的信息来回答问题:{question}
|
||||
请务必从目标文段中提取相关信息的**原文**并输出,{type}
|
||||
如果没有原文能够回答问题,输出"无有效信息"即可,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库获取答案 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库获取答案 prompt: {prompt}")
|
||||
|
||||
answer, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
if "无有效" in answer or "无有效信息" in answer or "无信息" in answer:
|
||||
logger.info(f"没有能够回答{question}的记忆")
|
||||
return ""
|
||||
|
||||
logger.info(f"记忆仓库对问题 “{question}” 获取答案: {answer}")
|
||||
|
||||
# 将问题和答案存到fetched_memory_list
|
||||
if chat_id and answer:
|
||||
self.fetched_memory_list.append((chat_id, (question, answer, time.time())))
|
||||
|
||||
# 清理fetched_memory_list
|
||||
self._cleanup_fetched_memory_list()
|
||||
|
||||
return answer
|
||||
|
||||
def get_chat_memories_as_string(self, chat_id: str) -> str:
|
||||
"""
|
||||
获取某个chat_id的所有记忆,并构建成字符串
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 格式化的记忆字符串,格式:问题:xxx,答案:xxxxx\n问题:xxx,答案:xxxxx\n...
|
||||
"""
|
||||
try:
|
||||
memories = []
|
||||
|
||||
# 从fetched_memory_list中获取该chat_id的所有记忆
|
||||
for cid, (question, answer, timestamp) in self.fetched_memory_list:
|
||||
if cid == chat_id:
|
||||
memories.append(f"问题:{question},答案:{answer}")
|
||||
|
||||
# 按时间戳排序(最新的在后面)
|
||||
memories.sort()
|
||||
|
||||
# 用换行符连接所有记忆
|
||||
result = "\n".join(memories)
|
||||
|
||||
# logger.info(f"chat_id {chat_id} 共有 {len(memories)} 条记忆")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取chat_id {chat_id} 的记忆时出错: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def select_title_by_question(self, question: str) -> str:
|
||||
"""
|
||||
根据消息内容选择最匹配的标题
|
||||
|
||||
Args:
|
||||
question: 问题
|
||||
|
||||
Returns:
|
||||
str: 选择的标题
|
||||
"""
|
||||
# 获取所有标题并构建格式化字符串(排除锁定的记忆)
|
||||
titles = get_all_titles(exclude_locked=True)
|
||||
formatted_titles = ""
|
||||
for title in titles:
|
||||
formatted_titles += f"{title}\n"
|
||||
|
||||
prompt = f"""
|
||||
所有主题:
|
||||
{formatted_titles}
|
||||
|
||||
请根据以下问题,选择一个能够回答问题的主题:
|
||||
问题:{question}
|
||||
请你输出主题,不要输出其他内容,完整输出主题名:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库选择标题 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库选择标题 prompt: {prompt}")
|
||||
|
||||
|
||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
# 根据 title 获取 titles 里的对应项
|
||||
selected_title = None
|
||||
|
||||
# 使用模糊查找匹配标题
|
||||
best_match = find_best_matching_memory(title, similarity_threshold=0.8)
|
||||
if best_match:
|
||||
selected_title = best_match[0] # 获取匹配的标题
|
||||
logger.info(f"记忆仓库选择标题: {selected_title} (相似度: {best_match[2]:.3f})")
|
||||
else:
|
||||
logger.warning(f"未找到相似度 >= 0.7 的标题匹配: {title}")
|
||||
selected_title = None
|
||||
|
||||
return selected_title
|
||||
|
||||
def _cleanup_fetched_memory_list(self):
|
||||
"""
|
||||
清理fetched_memory_list,移除超过10分钟的记忆和超过10条的最旧记忆
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
ten_minutes_ago = current_time - 600 # 10分钟 = 600秒
|
||||
|
||||
# 移除超过10分钟的记忆
|
||||
self.fetched_memory_list = [
|
||||
(chat_id, (question, answer, timestamp))
|
||||
for chat_id, (question, answer, timestamp) in self.fetched_memory_list
|
||||
if timestamp > ten_minutes_ago
|
||||
]
|
||||
|
||||
# 如果记忆条数超过10条,移除最旧的5条
|
||||
if len(self.fetched_memory_list) > 10:
|
||||
# 按时间戳排序,移除最旧的5条
|
||||
self.fetched_memory_list.sort(key=lambda x: x[1][2]) # 按timestamp排序
|
||||
self.fetched_memory_list = self.fetched_memory_list[5:] # 保留最新的5条
|
||||
|
||||
logger.debug(f"fetched_memory_list清理后,当前有 {len(self.fetched_memory_list)} 条记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理fetched_memory_list时出错: {e}")
|
||||
|
||||
async def _save_to_database_and_clear(self, chat_id: str, content: str):
|
||||
"""
|
||||
生成标题,保存到数据库,并清空对应chat_id的running_content
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
content: 要保存的内容
|
||||
"""
|
||||
try:
|
||||
# 生成标题
|
||||
title = ""
|
||||
title_prompt = f"""
|
||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
||||
{content}
|
||||
|
||||
标题不要分点,不要换行,不要输出其他内容
|
||||
请只输出标题,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"记忆仓库生成标题 prompt: {title_prompt}")
|
||||
else:
|
||||
logger.debug(f"记忆仓库生成标题 prompt: {title_prompt}")
|
||||
|
||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(title_prompt)
|
||||
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if title:
|
||||
# 保存到数据库
|
||||
MemoryChestModel.create(
|
||||
title=title.strip(),
|
||||
content=content,
|
||||
chat_id=chat_id
|
||||
)
|
||||
logger.info(f"已保存记忆仓库内容,标题: {title.strip()}, chat_id: {chat_id}")
|
||||
|
||||
# 清空内容并刷新时间戳,但保留条目用于增量计算
|
||||
if chat_id in self.running_content_list:
|
||||
current_time = time.time()
|
||||
self.running_content_list[chat_id] = {
|
||||
"content": "",
|
||||
"last_update_time": current_time,
|
||||
"create_time": current_time
|
||||
}
|
||||
logger.info(f"已保存并刷新chat_id {chat_id} 的时间戳,准备下一次增量构建")
|
||||
else:
|
||||
logger.warning(f"生成标题失败,chat_id: {chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存记忆仓库内容时出错: {e}")
|
||||
|
||||
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> list[str]:
|
||||
"""
|
||||
选择与给定记忆标题相关的记忆目标
|
||||
|
||||
Args:
|
||||
memory_title: 要匹配的记忆标题
|
||||
chat_id: 聊天ID,用于加权抽样
|
||||
|
||||
Returns:
|
||||
list[str]: 选中的记忆内容列表
|
||||
"""
|
||||
try:
|
||||
# 如果提供了chat_id,使用加权抽样
|
||||
all_titles = get_memory_titles_by_chat_id_weighted(chat_id)
|
||||
# 剔除掉输入的 memory_title 本身
|
||||
all_titles = [title for title in all_titles if title and title.strip() != (memory_title or "").strip()]
|
||||
|
||||
content = ""
|
||||
display_index = 1
|
||||
for title in all_titles:
|
||||
content += f"{display_index}. {title}\n"
|
||||
display_index += 1
|
||||
|
||||
prompt = f"""
|
||||
所有记忆列表
|
||||
{content}
|
||||
|
||||
请根据以上记忆列表,选择一个与"{memory_title}"相关的记忆,用json输出:
|
||||
如果没有相关记忆,输出:
|
||||
{{
|
||||
"selected_title": ""
|
||||
}}
|
||||
可以选择多个相关的记忆,但最多不超过5个
|
||||
例如:
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}},
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}},
|
||||
{{
|
||||
"selected_title": "选择的相关记忆标题"
|
||||
}}
|
||||
...
|
||||
注意:请返回原始标题本身作为 selected_title,不要包含前面的序号或多余字符。
|
||||
请输出JSON格式,不要输出其他内容:
|
||||
"""
|
||||
|
||||
# logger.info(f"选择合并目标 prompt: {prompt}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"选择合并目标 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"选择合并目标 prompt: {prompt}")
|
||||
|
||||
merge_target_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
# 解析JSON响应
|
||||
selected_titles = self._parse_merge_target_json(merge_target_response)
|
||||
|
||||
# 根据标题查找对应的内容
|
||||
selected_contents = self._get_memories_by_titles(selected_titles)
|
||||
|
||||
logger.info(f"选择合并目标结果: {len(selected_contents)} 条记忆:{selected_titles}")
|
||||
return selected_titles,selected_contents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选择合并目标时出错: {e}")
|
||||
return []
|
||||
|
||||
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
|
||||
"""
|
||||
根据标题列表查找对应的记忆内容
|
||||
|
||||
Args:
|
||||
titles: 记忆标题列表
|
||||
|
||||
Returns:
|
||||
list[str]: 记忆内容列表
|
||||
"""
|
||||
try:
|
||||
contents = []
|
||||
for title in titles:
|
||||
if not title or not title.strip():
|
||||
continue
|
||||
|
||||
# 使用模糊查找匹配记忆
|
||||
try:
|
||||
best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
|
||||
if best_match:
|
||||
# 检查记忆是否被锁定
|
||||
memory_title = best_match[0]
|
||||
memory_content = best_match[1]
|
||||
|
||||
# 查询数据库中的锁定状态
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title == memory_title and memory.locked:
|
||||
logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
|
||||
continue
|
||||
|
||||
contents.append(memory_content)
|
||||
logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
|
||||
else:
|
||||
logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
|
||||
except Exception as e:
|
||||
logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
|
||||
continue
|
||||
|
||||
# logger.info(f"成功找到 {len(contents)} 条记忆内容")
|
||||
return contents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"根据标题查找记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
def _parse_merged_parts(self, merged_response: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析合并记忆的part1和part2内容
|
||||
|
||||
Args:
|
||||
merged_response: LLM返回的合并记忆响应
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (part1_content, part2_content)
|
||||
"""
|
||||
try:
|
||||
# 使用正则表达式提取part1和part2内容
|
||||
import re
|
||||
|
||||
# 提取part1内容
|
||||
part1_pattern = r'<part1>(.*?)</part1>'
|
||||
part1_match = re.search(part1_pattern, merged_response, re.DOTALL)
|
||||
part1_content = part1_match.group(1).strip() if part1_match else ""
|
||||
|
||||
# 提取part2内容
|
||||
part2_pattern = r'<part2>(.*?)</part2>'
|
||||
part2_match = re.search(part2_pattern, merged_response, re.DOTALL)
|
||||
part2_content = part2_match.group(1).strip() if part2_match else ""
|
||||
|
||||
# 检查是否包含none或None(不区分大小写)
|
||||
def is_none_content(content: str) -> bool:
|
||||
if not content:
|
||||
return True
|
||||
# 检查是否只包含"none"或"None"(不区分大小写)
|
||||
return re.match(r'^\s*none\s*$', content, re.IGNORECASE) is not None
|
||||
|
||||
# 如果包含none,则设置为空字符串
|
||||
if is_none_content(part1_content):
|
||||
part1_content = ""
|
||||
logger.info("part1内容为none,设置为空")
|
||||
|
||||
if is_none_content(part2_content):
|
||||
part2_content = ""
|
||||
logger.info("part2内容为none,设置为空")
|
||||
|
||||
return part1_content, part2_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析合并记忆part1/part2时出错: {e}")
|
||||
return "", ""
|
||||
|
||||
def _parse_merge_target_json(self, json_text: str) -> list[str]:
|
||||
"""
|
||||
解析choose_merge_target生成的JSON响应
|
||||
|
||||
Args:
|
||||
json_text: LLM返回的JSON文本
|
||||
|
||||
Returns:
|
||||
list[str]: 解析出的记忆标题列表
|
||||
"""
|
||||
try:
|
||||
# 清理JSON文本,移除可能的额外内容
|
||||
repaired_content = repair_json(json_text)
|
||||
|
||||
# 尝试直接解析JSON
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, list):
|
||||
# 如果是列表,提取selected_title字段
|
||||
titles = []
|
||||
for item in parsed_data:
|
||||
if isinstance(item, dict) and "selected_title" in item:
|
||||
value = item.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
titles.append(value)
|
||||
return titles
|
||||
elif isinstance(parsed_data, dict) and "selected_title" in parsed_data:
|
||||
# 如果是单个对象
|
||||
value = parsed_data.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
return [value]
|
||||
else:
|
||||
# 空字符串表示没有相关记忆
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果直接解析失败,尝试提取JSON对象
|
||||
# 查找所有包含selected_title的JSON对象
|
||||
pattern = r'\{[^}]*"selected_title"[^}]*\}'
|
||||
matches = re.findall(pattern, repaired_content)
|
||||
|
||||
titles = []
|
||||
for match in matches:
|
||||
try:
|
||||
obj = json.loads(match)
|
||||
if "selected_title" in obj:
|
||||
value = obj.get("selected_title", "")
|
||||
if isinstance(value, str) and value.strip():
|
||||
titles.append(value)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if titles:
|
||||
return titles
|
||||
|
||||
logger.warning(f"无法解析JSON响应: {json_text[:200]}...")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析合并目标JSON时出错: {e}")
|
||||
return []
|
||||
|
||||
async def merge_memory(self,memory_list: list[str], chat_id: str = None) -> tuple[str, str]:
|
||||
"""
|
||||
合并记忆
|
||||
"""
|
||||
try:
|
||||
content = ""
|
||||
for memory in memory_list:
|
||||
content += f"{memory}\n"
|
||||
|
||||
prompt = f"""
|
||||
以下是多段记忆内容,请将它们进行整合和修改:
|
||||
{content}
|
||||
--------------------------------
|
||||
请将上面的多段记忆内容,合并成两部分内容,第一部分是可以整合,不冲突的概念和知识,第二部分是相互有冲突的概念和知识
|
||||
请主要关注概念和知识,而不是聊天的琐事
|
||||
重要!!你要关注的概念和知识必须是较为不常见的信息,或者时效性较强的信息!!
|
||||
不要!!关注常见的只是,或者已经过时的信息!!
|
||||
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
|
||||
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
|
||||
3.如果有图片,请只关注图片和文本结合的知识和概念性内容
|
||||
4.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
|
||||
**第一部分**
|
||||
1.如果两个概念在描述同一件事情,且相互之间逻辑不冲突(请你严格判断),且相互之间没有矛盾,请将它们整合成一个概念,并输出到第一部分
|
||||
2.如果某个概念在时间上更新了另一个概念,请用新概念更新就概念来整合,并输出到第一部分
|
||||
3.如果没有可整合的概念,请你输出none
|
||||
**第二部分**
|
||||
1.如果记忆中有无法整合的地方,例如概念不一致,有逻辑上的冲突,请你输出到第二部分
|
||||
2.如果两个概念在描述同一件事情,但相互之间逻辑冲突,请将它们输出到第二部分
|
||||
3.如果没有无法整合的概念,请你输出none
|
||||
|
||||
**输出格式要求**
|
||||
请你按以下格式输出:
|
||||
<part1>
|
||||
第一部分内容,整合后的概念,如果第一部分为none,请输出none
|
||||
</part1>
|
||||
<part2>
|
||||
第二部分内容,无法整合,冲突的概念,如果第二部分为none,请输出none
|
||||
</part2>
|
||||
不要输出其他内容,现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"合并记忆 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"合并记忆 prompt: {prompt}")
|
||||
|
||||
merged_memory, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
||||
|
||||
# 解析part1和part2
|
||||
part1_content, part2_content = self._parse_merged_parts(merged_memory)
|
||||
|
||||
# 处理part2:独立记录冲突内容(无论part1是否为空)
|
||||
if part2_content and part2_content.strip() != "none":
|
||||
logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
|
||||
# 记录冲突到数据库
|
||||
await global_conflict_tracker.record_memory_merge_conflict(part2_content,chat_id)
|
||||
|
||||
# 处理part1:生成标题并保存
|
||||
if part1_content and part1_content.strip() != "none":
|
||||
merged_title = await self._generate_title_for_merged_memory(part1_content)
|
||||
|
||||
# 保存part1到数据库
|
||||
MemoryChestModel.create(
|
||||
title=merged_title,
|
||||
content=part1_content,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
logger.info(f"合并记忆part1已保存: {merged_title}")
|
||||
|
||||
return merged_title, part1_content
|
||||
else:
|
||||
logger.warning("合并记忆part1为空,跳过保存")
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时出错: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _generate_title_for_merged_memory(self, merged_content: str) -> str:
|
||||
"""
|
||||
为合并后的记忆生成标题
|
||||
|
||||
Args:
|
||||
merged_content: 合并后的记忆内容
|
||||
|
||||
Returns:
|
||||
str: 生成的标题
|
||||
"""
|
||||
try:
|
||||
prompt = f"""
|
||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
||||
例如:
|
||||
<example>
|
||||
标题:达尔文的自然选择理论
|
||||
内容:达尔文的自然选择是生物进化理论的重要组成部分,它解释了生物进化过程中的自然选择机制。
|
||||
</example>
|
||||
<example>
|
||||
标题:麦麦的禁言插件和支持版本
|
||||
内容:
|
||||
麦麦的禁言插件是一款能够实现禁言的插件
|
||||
麦麦的禁言插件可能不支持0.10.2
|
||||
MutePlugin 是禁言插件的名称
|
||||
</example>
|
||||
|
||||
|
||||
需要对以下内容生成标题:
|
||||
{merged_content}
|
||||
|
||||
|
||||
标题不要分点,不要换行,不要输出其他内容,不要浮夸,以白话简洁的风格输出标题
|
||||
请只输出标题,不要输出其他内容:
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"生成合并记忆标题 prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"生成合并记忆标题 prompt: {prompt}")
|
||||
|
||||
title_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
||||
|
||||
# 清理标题,移除可能的引号或多余字符
|
||||
title = title_response.strip().strip('"').strip("'").strip()
|
||||
|
||||
if title:
|
||||
# 检查是否存在相似标题
|
||||
if check_title_exists_fuzzy(title, similarity_threshold=0.9):
|
||||
logger.warning(f"生成的标题 '{title}' 与现有标题相似,使用时间戳后缀")
|
||||
title = f"{title}_{int(time.time())}"
|
||||
|
||||
logger.info(f"生成合并记忆标题: {title}")
|
||||
return title
|
||||
else:
|
||||
logger.warning("生成合并记忆标题失败,使用默认标题")
|
||||
return f"合并记忆_{int(time.time())}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成合并记忆标题时出错: {e}")
|
||||
return f"合并记忆_{int(time.time())}"
|
||||
|
||||
|
||||
global_memory_chest = MemoryChest()
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.memory_system.Memory_chest import global_memory_chest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.config.config import global_config
|
||||
from src.memory_system.memory_utils import get_all_titles
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
|
||||
class MemoryManagementTask(AsyncTask):
|
||||
"""记忆管理定时任务
|
||||
|
||||
根据Memory_chest中的记忆数量与MAX_MEMORY_NUMBER的比例来决定执行频率:
|
||||
- 小于50%:每600秒执行一次
|
||||
- 大于等于50%:每300秒执行一次
|
||||
|
||||
每次执行时随机选择一个title,执行choose_merge_target和merge_memory,
|
||||
然后删除原始记忆
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
task_name="Memory Management Task",
|
||||
wait_before_start=10, # 启动后等待10秒再开始
|
||||
run_interval=300 # 默认300秒间隔,会根据记忆数量动态调整
|
||||
)
|
||||
self.max_memory_number = global_config.memory.max_memory_number
|
||||
|
||||
async def start_task(self, abort_flag: asyncio.Event):
|
||||
"""重写start_task方法,支持动态调整执行间隔"""
|
||||
if self.wait_before_start > 0:
|
||||
# 等待指定时间后开始任务
|
||||
await asyncio.sleep(self.wait_before_start)
|
||||
|
||||
while not abort_flag.is_set():
|
||||
await self.run()
|
||||
|
||||
# 动态调整执行间隔
|
||||
current_interval = self._calculate_interval()
|
||||
logger.info(f"[记忆管理] 下次执行间隔: {current_interval}秒")
|
||||
|
||||
if current_interval > 0:
|
||||
await asyncio.sleep(current_interval)
|
||||
else:
|
||||
break
|
||||
|
||||
def _calculate_interval(self) -> int:
|
||||
"""根据当前记忆数量计算执行间隔"""
|
||||
try:
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
|
||||
if percentage < 0.5:
|
||||
# 小于50%,每600秒执行一次
|
||||
return 3600
|
||||
elif percentage < 0.7:
|
||||
# 大于等于50%,每300秒执行一次
|
||||
return 1800
|
||||
elif percentage < 0.9:
|
||||
# 大于等于70%,每120秒执行一次
|
||||
return 300
|
||||
elif percentage < 1.2:
|
||||
return 30
|
||||
else:
|
||||
return 10
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 计算执行间隔时出错: {e}")
|
||||
return 300 # 默认300秒
|
||||
|
||||
def _get_memory_count(self) -> int:
|
||||
"""获取当前记忆数量"""
|
||||
try:
|
||||
count = MemoryChestModel.select().count()
|
||||
logger.debug(f"[记忆管理] 当前记忆数量: {count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 获取记忆数量时出错: {e}")
|
||||
return 0
|
||||
|
||||
async def run(self):
|
||||
"""执行记忆管理任务"""
|
||||
try:
|
||||
|
||||
# 获取当前记忆数量
|
||||
current_count = self._get_memory_count()
|
||||
percentage = current_count / self.max_memory_number
|
||||
logger.info(f"当前记忆数量: {current_count}/{self.max_memory_number} ({percentage:.1%})")
|
||||
|
||||
# 如果记忆数量为0,跳过执行
|
||||
if current_count < 10:
|
||||
return
|
||||
|
||||
# 随机选择一个记忆标题和chat_id
|
||||
selected_title, selected_chat_id = self._get_random_memory_title()
|
||||
if not selected_title:
|
||||
logger.warning("无法获取随机记忆标题,跳过执行")
|
||||
return
|
||||
|
||||
# 执行choose_merge_target获取相关记忆(标题与内容)
|
||||
related_titles, related_contents = await global_memory_chest.choose_merge_target(selected_title, selected_chat_id)
|
||||
if not related_titles or not related_contents:
|
||||
logger.info("无合适合并内容,跳过本次合并")
|
||||
return
|
||||
|
||||
logger.info(f"为 [{selected_title}] 找到 {len(related_contents)} 条相关记忆:{related_titles}")
|
||||
|
||||
# 执行merge_memory合并记忆
|
||||
merged_title, merged_content = await global_memory_chest.merge_memory(related_contents,selected_chat_id)
|
||||
if not merged_title or not merged_content:
|
||||
logger.warning("[记忆管理] 记忆合并失败,跳过删除")
|
||||
return
|
||||
|
||||
logger.info(f"记忆合并成功,新标题: {merged_title}")
|
||||
|
||||
# 删除原始记忆(包括选中的标题和相关的记忆标题)
|
||||
titles_to_delete = [selected_title] + related_titles
|
||||
deleted_count = self._delete_original_memories(titles_to_delete)
|
||||
logger.info(f"已删除 {deleted_count} 条原始记忆")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 执行记忆管理任务时发生错误: {e}", exc_info=True)
|
||||
|
||||
def _get_random_memory_title(self) -> tuple[str, str]:
|
||||
"""随机获取一个记忆标题和对应的chat_id"""
|
||||
try:
|
||||
# 获取所有记忆记录
|
||||
all_memories = MemoryChestModel.select()
|
||||
if not all_memories:
|
||||
return "", ""
|
||||
|
||||
# 随机选择一个记忆
|
||||
selected_memory = random.choice(list(all_memories))
|
||||
return selected_memory.title, selected_memory.chat_id or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 获取随机记忆标题时发生错误: {e}")
|
||||
return "", ""
|
||||
|
||||
def _delete_original_memories(self, related_titles: List[str]) -> int:
|
||||
"""按标题删除原始记忆"""
|
||||
try:
|
||||
deleted_count = 0
|
||||
# 删除相关记忆(通过标题匹配)
|
||||
for title in related_titles:
|
||||
try:
|
||||
# 通过标题查找并删除对应的记忆
|
||||
memories_to_delete = MemoryChestModel.select().where(MemoryChestModel.title == title)
|
||||
for memory in memories_to_delete:
|
||||
MemoryChestModel.delete().where(MemoryChestModel.id == memory.id).execute()
|
||||
deleted_count += 1
|
||||
logger.debug(f"[记忆管理] 删除相关记忆: {memory.title}")
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 删除相关记忆时出错: {e}")
|
||||
continue
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[记忆管理] 删除原始记忆时发生错误: {e}")
|
||||
return 0
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统工具函数
|
||||
包含模糊查找、相似度计算等工具函数
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
||||
from src.common.logger import get_logger
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
def get_all_titles(exclude_locked: bool = False) -> list[str]:
|
||||
"""
|
||||
获取记忆仓库中的所有标题
|
||||
|
||||
Args:
|
||||
exclude_locked: 是否排除锁定的记忆,默认为 False
|
||||
|
||||
Returns:
|
||||
list: 包含所有标题的列表
|
||||
"""
|
||||
try:
|
||||
# 查询所有记忆记录的标题
|
||||
titles = []
|
||||
for memory in MemoryChestModel.select():
|
||||
if memory.title:
|
||||
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
|
||||
if exclude_locked and memory.locked:
|
||||
continue
|
||||
titles.append(memory.title)
|
||||
return titles
|
||||
except Exception as e:
|
||||
print(f"获取记忆标题时出错: {e}")
|
||||
return []
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
reasoning_content = ""
|
||||
|
||||
# 使用正则表达式查找```json包裹的JSON内容
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, json_text, re.DOTALL)
|
||||
|
||||
# 提取JSON之前的内容作为推理文本
|
||||
if matches:
|
||||
# 找到第一个```json的位置
|
||||
first_json_pos = json_text.find("```json")
|
||||
if first_json_pos > 0:
|
||||
reasoning_content = json_text[:first_json_pos].strip()
|
||||
# 清理推理内容中的注释标记
|
||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||
reasoning_content = reasoning_content.strip()
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
# 清理可能的注释和格式问题
|
||||
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
|
||||
if json_str := json_str.strip():
|
||||
json_obj = json.loads(json_str)
|
||||
if isinstance(json_obj, dict):
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict):
|
||||
json_objects.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
|
||||
continue
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度分数 (0-1)
|
||||
"""
|
||||
try:
|
||||
# 预处理文本
|
||||
text1 = preprocess_text(text1)
|
||||
text2 = preprocess_text(text2)
|
||||
|
||||
# 使用SequenceMatcher计算相似度
|
||||
similarity = SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
# 如果其中一个文本包含另一个,提高相似度
|
||||
if text1 in text2 or text2 in text1:
|
||||
similarity = max(similarity, 0.8)
|
||||
|
||||
return similarity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""
|
||||
预处理文本,提高匹配准确性
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
str: 预处理后的文本
|
||||
"""
|
||||
try:
|
||||
# 转换为小写
|
||||
text = text.lower()
|
||||
|
||||
# 移除标点符号和特殊字符
|
||||
text = re.sub(r'[^\w\s]', '', text)
|
||||
|
||||
# 移除多余空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理文本时出错: {e}")
|
||||
return text
|
||||
|
||||
|
||||
def fuzzy_find_memory_by_title(target_title: str, similarity_threshold: float = 0.9) -> List[Tuple[str, str, float]]:
|
||||
"""
|
||||
根据标题模糊查找记忆
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值,默认0.9
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, float]]: 匹配的记忆列表,每个元素为(title, content, similarity_score)
|
||||
"""
|
||||
try:
|
||||
# 获取所有记忆
|
||||
all_memories = MemoryChestModel.select()
|
||||
|
||||
matches = []
|
||||
for memory in all_memories:
|
||||
similarity = calculate_similarity(target_title, memory.title)
|
||||
if similarity >= similarity_threshold:
|
||||
matches.append((memory.title, memory.content, similarity))
|
||||
|
||||
# 按相似度降序排序
|
||||
matches.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# logger.info(f"模糊查找标题 '{target_title}' 找到 {len(matches)} 个匹配项")
|
||||
return matches
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模糊查找记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def find_best_matching_memory(target_title: str, similarity_threshold: float = 0.9) -> Optional[Tuple[str, str, float]]:
|
||||
"""
|
||||
查找最佳匹配的记忆
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, float]]: 最佳匹配的记忆(title, content, similarity)或None
|
||||
"""
|
||||
try:
|
||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
||||
|
||||
if matches:
|
||||
best_match = matches[0] # 已经按相似度排序,第一个是最佳匹配
|
||||
# logger.info(f"找到最佳匹配: '{best_match[0]}' (相似度: {best_match[2]:.3f})")
|
||||
return best_match
|
||||
else:
|
||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找最佳匹配记忆时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_title_exists_fuzzy(target_title: str, similarity_threshold: float = 0.9) -> bool:
|
||||
"""
|
||||
检查标题是否已存在(模糊匹配)
|
||||
|
||||
Args:
|
||||
target_title: 目标标题
|
||||
similarity_threshold: 相似度阈值,默认0.9(较高阈值避免误判)
|
||||
|
||||
Returns:
|
||||
bool: 是否存在相似标题
|
||||
"""
|
||||
try:
|
||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
||||
exists = len(matches) > 0
|
||||
|
||||
if exists:
|
||||
logger.info(f"发现相似标题: '{matches[0][0]}' (相似度: {matches[0][2]:.3f})")
|
||||
else:
|
||||
logger.debug("未发现相似标题")
|
||||
|
||||
return exists
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查标题是否存在时出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_memories_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
根据chat_id进行加权抽样获取记忆列表
|
||||
|
||||
Args:
|
||||
target_chat_id: 目标聊天ID
|
||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 选中的记忆列表,每个元素为(title, content, chat_id)
|
||||
"""
|
||||
try:
|
||||
# 获取所有记忆
|
||||
all_memories = MemoryChestModel.select()
|
||||
|
||||
# 按chat_id分组
|
||||
same_chat_memories = []
|
||||
other_chat_memories = []
|
||||
|
||||
for memory in all_memories:
|
||||
if memory.title and not memory.locked: # 排除锁定的记忆
|
||||
if memory.chat_id == target_chat_id:
|
||||
same_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
||||
else:
|
||||
other_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
||||
|
||||
# 如果没有同chat_id的记忆,返回空列表
|
||||
if not same_chat_memories:
|
||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
||||
return []
|
||||
|
||||
# 计算抽样数量
|
||||
total_same = len(same_chat_memories)
|
||||
total_other = len(other_chat_memories)
|
||||
|
||||
# 根据权重计算抽样数量
|
||||
if total_other > 0:
|
||||
# 计算其他chat_id记忆的抽样数量(至少1个,最多不超过总数的10%)
|
||||
other_sample_count = max(1, min(total_other, int(total_same * other_chat_weight / same_chat_weight)))
|
||||
else:
|
||||
other_sample_count = 0
|
||||
|
||||
# 随机抽样
|
||||
selected_memories = []
|
||||
|
||||
# 选择同chat_id的记忆(全部选择,因为权重很高)
|
||||
selected_memories.extend(same_chat_memories)
|
||||
|
||||
# 随机选择其他chat_id的记忆
|
||||
if other_sample_count > 0 and total_other > 0:
|
||||
import random
|
||||
other_selected = random.sample(other_chat_memories, min(other_sample_count, total_other))
|
||||
selected_memories.extend(other_selected)
|
||||
|
||||
logger.info(f"加权抽样结果: 同chat_id记忆 {len(same_chat_memories)} 条,其他chat_id记忆 {min(other_sample_count, total_other)} 条")
|
||||
|
||||
return selected_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"按chat_id加权抽样记忆时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[str]:
|
||||
"""
|
||||
根据chat_id进行加权抽样获取记忆标题列表(用于合并选择)
|
||||
|
||||
Args:
|
||||
target_chat_id: 目标聊天ID
|
||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
||||
|
||||
Returns:
|
||||
List[str]: 选中的记忆标题列表
|
||||
"""
|
||||
try:
|
||||
memories = get_memories_by_chat_id_weighted(target_chat_id, same_chat_weight, other_chat_weight)
|
||||
titles = [memory[0] for memory in memories] # 提取标题
|
||||
return titles
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
import time
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.common.database.database_model import MemoryConflict
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
class QuestionMaker:
|
||||
def __init__(self, chat_id: str, context: str = "") -> None:
|
||||
"""问题生成器。
|
||||
|
||||
- chat_id: 会话 ID,用于筛选该会话下的冲突记录。
|
||||
- context: 额外上下文,可用于后续扩展。
|
||||
|
||||
用法示例:
|
||||
>>> qm = QuestionMaker(chat_id="some_chat")
|
||||
>>> question, chat_ctx, conflict_ctx = await qm.make_question()
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.context = context
|
||||
|
||||
def get_context(self, timestamp: float = time.time()) -> str:
|
||||
"""获取指定时间点之前的对话上下文字符串。"""
|
||||
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=timestamp,
|
||||
limit=30,
|
||||
)
|
||||
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_30_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
)
|
||||
return all_dialogue_prompt_str
|
||||
|
||||
|
||||
async def get_all_conflicts(self) -> List[MemoryConflict]:
|
||||
"""获取当前会话下的所有记忆冲突记录。"""
|
||||
conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
|
||||
return conflicts
|
||||
|
||||
async def get_un_answered_conflict(self) -> List[MemoryConflict]:
|
||||
"""获取未回答的记忆冲突记录(answer 为空)。"""
|
||||
conflicts = await self.get_all_conflicts()
|
||||
return [conflict for conflict in conflicts if not conflict.answer]
|
||||
|
||||
async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]:
|
||||
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
||||
|
||||
选择规则:
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.05)。
|
||||
- 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。
|
||||
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
||||
"""
|
||||
conflicts = await self.get_un_answered_conflict()
|
||||
if not conflicts:
|
||||
return None
|
||||
|
||||
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
|
||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||
if conflicts_with_zero:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
||||
weights = []
|
||||
for conflict in conflicts:
|
||||
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
||||
weight = 1.0 if current_raise_time == 0 else 0.01
|
||||
weights.append(weight)
|
||||
|
||||
# 按权重随机选择
|
||||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||
|
||||
# 选中后,自增 raise_time 并保存
|
||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||
chosen_conflict.save()
|
||||
|
||||
|
||||
return chosen_conflict
|
||||
|
||||
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""生成一条用于询问用户的冲突问题与上下文。
|
||||
|
||||
返回三元组 (question, chat_context, conflict_context):
|
||||
- question: 冲突文本;若本次未选中任何冲突则为 None。
|
||||
- chat_context: 该冲突创建时间点前的会话上下文字符串;若无则为 None。
|
||||
- conflict_context: 冲突在 DB 中存储的上下文;若无则为 None。
|
||||
"""
|
||||
conflict = await self.get_random_unanswered_conflict()
|
||||
if not conflict:
|
||||
return None, None, None
|
||||
question = conflict.conflict_content
|
||||
conflict_context = conflict.context
|
||||
create_time = conflict.create_time
|
||||
chat_context = self.get_context(create_time)
|
||||
|
||||
return question, chat_context, conflict_context
|
||||
|
|
@ -0,0 +1,478 @@
|
|||
import time
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import MemoryConflict
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from typing import List
|
||||
from src.memory_system.memory_utils import parse_md_json
|
||||
|
||||
logger = get_logger("conflict_tracker")
|
||||
|
||||
class QuestionTracker:
|
||||
"""
|
||||
用于跟踪一个问题在后续聊天中的解答情况
|
||||
"""
|
||||
|
||||
def __init__(self, question: str, chat_id: str, context: str = "") -> None:
|
||||
self.question = question
|
||||
self.chat_id = chat_id
|
||||
now = time.time()
|
||||
self.context = context
|
||||
self.start_time = now
|
||||
self.last_read_time = now
|
||||
self.last_judge_time = now # 上次判定的时间
|
||||
self.judge_debounce_interval = 10.0 # 判定防抖间隔:10秒
|
||||
self.consecutive_end_count = 0 # 连续END计数
|
||||
self.active = True
|
||||
# 将 LLM 实例作为类属性,使用 utils 模型
|
||||
self.llm_request = LLMRequest(model_set=model_config.model_task_config.utils, request_type="conflict.judge")
|
||||
|
||||
def stop(self) -> None:
|
||||
self.active = False
|
||||
|
||||
def should_judge_now(self) -> bool:
|
||||
"""
|
||||
检查是否应该进行判定(防抖检查)
|
||||
|
||||
Returns:
|
||||
bool: 是否可以判定
|
||||
"""
|
||||
now = time.time()
|
||||
# 检查是否已经过了10秒的防抖间隔
|
||||
return (now - self.last_judge_time) >= self.judge_debounce_interval
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""比较两个追踪器是否相等(基于问题内容和聊天ID)"""
|
||||
if not isinstance(other, QuestionTracker):
|
||||
return False
|
||||
return self.question == other.question and self.chat_id == other.chat_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""为对象提供哈希值,支持集合操作"""
|
||||
return hash((self.question, self.chat_id))
|
||||
|
||||
async def judge_answer(self, conversation_text: str,chat_len: int) -> tuple[bool, str, str]:
|
||||
"""
|
||||
使用模型判定问题是否已得到解答。
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, str]: (是否结束跟踪, 结束原因或答案, 判定类型)
|
||||
- True: 结束跟踪(已解答、话题转向等)
|
||||
- False: 继续跟踪
|
||||
判定类型: "ANSWERED", "END", "CONTINUE"
|
||||
"""
|
||||
|
||||
end_prompt = ""
|
||||
if chat_len > 20:
|
||||
end_prompt = "\n- 如果最新20条聊天记录内容与问题无关,话题已转向其他方向,请只输出:END"
|
||||
|
||||
prompt = f"""你是一个严谨的判定器。下面给出聊天记录以及一个问题。
|
||||
任务:判断在这段聊天中,该问题是否已经得到明确解答。
|
||||
**你必须严格按照聊天记录的内容,不要添加额外的信息**
|
||||
|
||||
输出规则:
|
||||
- 如果聊天记录内容的信息已解答问题,请只输出:YES: <简短答案>{end_prompt}
|
||||
- 如果问题尚未解答但聊天仍在相关话题上,请只输出:NO
|
||||
|
||||
**问题**
|
||||
{self.question}
|
||||
|
||||
|
||||
**聊天记录**
|
||||
{conversation_text}
|
||||
"""
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"判定提示词: {prompt}")
|
||||
else:
|
||||
logger.debug("已发送判定提示词")
|
||||
|
||||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.5)
|
||||
|
||||
logger.info(f"判定结果: {prompt}\n{result_text}")
|
||||
|
||||
# 更新上次判定时间
|
||||
self.last_judge_time = time.time()
|
||||
|
||||
if not result_text:
|
||||
return False, "", "CONTINUE"
|
||||
|
||||
text = result_text.strip()
|
||||
if text.upper().startswith("YES:"):
|
||||
answer = text[4:].strip()
|
||||
return True, answer, "ANSWERED"
|
||||
if text.upper().startswith("YES"):
|
||||
# 兼容仅输出 YES 或 YES <answer>
|
||||
answer = text[3:].strip().lstrip(":").strip()
|
||||
return True, answer, "ANSWERED"
|
||||
if text.upper().startswith("END"):
|
||||
# 聊天内容与问题无关,放弃该问题思考
|
||||
return True, "话题已转向其他方向,放弃该问题思考", "END"
|
||||
return False, "", "CONTINUE"
|
||||
|
||||
class ConflictTracker:
|
||||
"""
|
||||
记忆整合冲突追踪器
|
||||
|
||||
用于记录和存储记忆整合过程中的冲突内容
|
||||
"""
|
||||
def __init__(self):
|
||||
self.question_tracker_list:List[QuestionTracker] = []
|
||||
|
||||
self.LLMRequest_tracker = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="conflict_tracker",
|
||||
)
|
||||
|
||||
def get_questions_by_chat_id(self, chat_id: str) -> List[QuestionTracker]:
|
||||
return [tracker for tracker in self.question_tracker_list if tracker.chat_id == chat_id]
|
||||
|
||||
async def track_conflict(self, question: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
|
||||
"""
|
||||
跟踪冲突内容
|
||||
"""
|
||||
tracker = QuestionTracker(question.strip(), chat_id, context)
|
||||
self.question_tracker_list.append(tracker)
|
||||
asyncio.create_task(self._follow_and_record(tracker, question.strip()))
|
||||
return True
|
||||
|
||||
async def record_conflict(self, conflict_content: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
|
||||
"""
|
||||
记录冲突内容
|
||||
|
||||
Args:k
|
||||
conflict_content: 冲突内容
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
try:
|
||||
if not conflict_content or conflict_content.strip() == "":
|
||||
return False
|
||||
|
||||
# 若需要跟随后续消息以判断是否得到解答,则进入跟踪流程
|
||||
if start_following and chat_id:
|
||||
tracker = QuestionTracker(conflict_content.strip(), chat_id, context)
|
||||
self.question_tracker_list.append(tracker)
|
||||
# 后台启动跟踪任务,避免阻塞
|
||||
asyncio.create_task(self._follow_and_record(tracker, conflict_content.strip()))
|
||||
return True
|
||||
|
||||
# 默认:直接记录,不进行跟踪
|
||||
MemoryConflict.create(
|
||||
conflict_content=conflict_content,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录冲突内容时出错: {e}")
|
||||
return False
|
||||
|
||||
async def _follow_and_record(self, tracker: QuestionTracker, original_question: str) -> None:
|
||||
"""
|
||||
后台任务:跟踪问题是否被解答,并写入数据库。
|
||||
"""
|
||||
try:
|
||||
max_duration = 10 * 60 # 30 分钟
|
||||
max_messages = 50 # 最多 100 条消息
|
||||
poll_interval = 2.0 # 秒
|
||||
logger.info(f"开始跟踪问题: {original_question}")
|
||||
while tracker.active:
|
||||
now_ts = time.time()
|
||||
# 终止条件:时长达到上限
|
||||
if now_ts - tracker.start_time >= max_duration:
|
||||
logger.info("问题跟踪达到10分钟上限,判定为未解答")
|
||||
break
|
||||
|
||||
# 统计最近一段是否有新消息(不过滤机器人,过滤命令)
|
||||
recent_msgs = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=tracker.chat_id,
|
||||
timestamp_start=tracker.last_read_time,
|
||||
timestamp_end=now_ts,
|
||||
limit=30,
|
||||
limit_mode="latest",
|
||||
filter_bot=False,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
if len(recent_msgs) > 0:
|
||||
tracker.last_read_time = now_ts
|
||||
|
||||
# 统计从开始到现在的总消息数(用于触发100条上限)
|
||||
all_msgs = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=tracker.chat_id,
|
||||
timestamp_start=tracker.start_time,
|
||||
timestamp_end=now_ts,
|
||||
limit=0,
|
||||
limit_mode="latest",
|
||||
filter_bot=False,
|
||||
filter_command=True,
|
||||
)
|
||||
|
||||
# 检查是否应该进行判定(防抖检查)
|
||||
if not tracker.should_judge_now():
|
||||
logger.debug(f"判定防抖中,跳过本次判定: {tracker.question}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
continue
|
||||
|
||||
# 构建可读聊天文本
|
||||
chat_text = build_readable_messages(
|
||||
all_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=False,
|
||||
remove_emoji_stickers=True,
|
||||
)
|
||||
chat_len = len(all_msgs)
|
||||
# 让小模型判断是否有答案
|
||||
answered, answer_text, judge_type = await tracker.judge_answer(chat_text,chat_len)
|
||||
|
||||
if judge_type == "ANSWERED":
|
||||
# 问题已解答,直接结束跟踪
|
||||
logger.info("问题已得到解答,结束跟踪并写入答案")
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=tracker.question,
|
||||
create_time=tracker.start_time,
|
||||
update_time=time.time(),
|
||||
answer=answer_text or "",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
return
|
||||
elif judge_type == "END":
|
||||
# 话题转向,增加END计数
|
||||
tracker.consecutive_end_count += 1
|
||||
logger.info(f"话题已转向,连续END次数: {tracker.consecutive_end_count}")
|
||||
|
||||
if tracker.consecutive_end_count >= 2:
|
||||
# 连续两次END,结束跟踪
|
||||
logger.info("连续两次END,结束跟踪")
|
||||
break
|
||||
else:
|
||||
# 第一次END,重置计数器并继续跟踪
|
||||
logger.info("第一次END,继续跟踪")
|
||||
continue
|
||||
elif judge_type == "CONTINUE":
|
||||
# 继续跟踪,重置END计数器
|
||||
tracker.consecutive_end_count = 0
|
||||
continue
|
||||
|
||||
if len(all_msgs) >= max_messages:
|
||||
logger.info("问题跟踪达到100条消息上限,判定为未解答")
|
||||
logger.info(f"追踪结束:{tracker.question}")
|
||||
break
|
||||
|
||||
# 无新消息时稍作等待
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
# 未获取到答案,检查是否需要删除记录
|
||||
# 查找现有的冲突记录
|
||||
existing_conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == original_question,
|
||||
MemoryConflict.chat_id == tracker.chat_id
|
||||
)
|
||||
|
||||
if existing_conflict:
|
||||
# 检查raise_time是否大于3且没有答案
|
||||
current_raise_time = getattr(existing_conflict, "raise_time", 0) or 0
|
||||
if current_raise_time > 0 and not existing_conflict.answer:
|
||||
# 删除该条目
|
||||
await self.delete_conflict(original_question, tracker.chat_id)
|
||||
logger.info(f"追踪结束后删除条目(raise_time={current_raise_time}且无答案): {original_question}")
|
||||
else:
|
||||
# 更新记录但不删除
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=existing_conflict.create_time,
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
else:
|
||||
# 如果没有现有记录,创建新记录
|
||||
await self.add_or_update_conflict(
|
||||
conflict_content=original_question,
|
||||
create_time=time.time(),
|
||||
update_time=time.time(),
|
||||
answer="",
|
||||
chat_id=tracker.chat_id,
|
||||
)
|
||||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||||
|
||||
logger.info(f"问题跟踪结束:{original_question}")
|
||||
except Exception as e:
|
||||
logger.error(f"后台问题跟踪任务异常: {e}")
|
||||
finally:
|
||||
# 无论任务成功还是失败,都要从追踪列表中移除
|
||||
tracker.stop()
|
||||
self.remove_tracker(tracker)
|
||||
|
||||
def remove_tracker(self, tracker: QuestionTracker) -> None:
|
||||
"""
|
||||
从追踪列表中移除指定的追踪器
|
||||
|
||||
Args:
|
||||
tracker: 要移除的追踪器对象
|
||||
"""
|
||||
try:
|
||||
if tracker in self.question_tracker_list:
|
||||
self.question_tracker_list.remove(tracker)
|
||||
logger.info(f"已从追踪列表中移除追踪器: {tracker.question}")
|
||||
else:
|
||||
logger.warning(f"尝试移除不存在的追踪器: {tracker.question}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除追踪器时出错: {e}")
|
||||
|
||||
async def add_or_update_conflict(
|
||||
self,
|
||||
conflict_content: str,
|
||||
create_time: float,
|
||||
update_time: float,
|
||||
answer: str = "",
|
||||
context: str = "",
|
||||
chat_id: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
根据conflict_content匹配数据库内容,如果找到相同的就更新update_time和answer,
|
||||
如果没有相同的,就新建一条保存全部内容
|
||||
"""
|
||||
try:
|
||||
# 尝试根据conflict_content查找现有记录
|
||||
existing_conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == conflict_content,
|
||||
MemoryConflict.chat_id == chat_id
|
||||
)
|
||||
|
||||
if existing_conflict:
|
||||
# 如果找到相同的conflict_content,更新update_time和answer
|
||||
existing_conflict.update_time = update_time
|
||||
existing_conflict.answer = answer
|
||||
existing_conflict.save()
|
||||
return True
|
||||
else:
|
||||
# 如果没有找到相同的,创建新记录
|
||||
MemoryConflict.create(
|
||||
conflict_content=conflict_content,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
answer=answer,
|
||||
context=context,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
# 记录错误并返回False
|
||||
logger.error(f"添加或更新冲突记录时出错: {e}")
|
||||
return False
|
||||
|
||||
async def record_memory_merge_conflict(self, part2_content: str, chat_id: str = None) -> bool:
|
||||
"""
|
||||
记录记忆整合过程中的冲突内容(part2)
|
||||
|
||||
Args:
|
||||
part2_content: 冲突内容(part2)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
if not part2_content or part2_content.strip() == "":
|
||||
return False
|
||||
|
||||
prompt = f"""以下是一段有冲突的信息,请你根据这些信息总结出几个具体的提问:
|
||||
冲突信息:
|
||||
{part2_content}
|
||||
|
||||
要求:
|
||||
1.提问必须具体,明确
|
||||
2.提问最好涉及指向明确的事物,而不是代称
|
||||
3.如果缺少上下文,不要强行提问,可以忽略
|
||||
|
||||
请用json格式输出,不要输出其他内容,仅输出提问理由和具体提的提问:
|
||||
**示例**
|
||||
// 理由文本
|
||||
```json
|
||||
{{
|
||||
"question":"提问",
|
||||
}}
|
||||
```
|
||||
```json
|
||||
{{
|
||||
"question":"提问"
|
||||
}}
|
||||
```
|
||||
...提问数量在1-3个之间,不要重复,现在请输出:"""
|
||||
|
||||
question_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_tracker.generate_response_async(prompt)
|
||||
|
||||
# 解析JSON响应
|
||||
questions, reasoning_content = parse_md_json(question_response)
|
||||
|
||||
print(prompt)
|
||||
print(question_response)
|
||||
|
||||
for question in questions:
|
||||
await self.record_conflict(
|
||||
conflict_content=question["question"],
|
||||
context=reasoning_content,
|
||||
start_following=False,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_conflict_count(self) -> int:
|
||||
"""
|
||||
获取冲突记录数量
|
||||
|
||||
Returns:
|
||||
int: 记录数量
|
||||
"""
|
||||
try:
|
||||
return MemoryConflict.select().count()
|
||||
except Exception as e:
|
||||
logger.error(f"获取冲突记录数量时出错: {e}")
|
||||
return 0
|
||||
|
||||
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
|
||||
"""
|
||||
删除指定的冲突记录
|
||||
|
||||
Args:
|
||||
conflict_content: 冲突内容
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
try:
|
||||
conflict = MemoryConflict.get_or_none(
|
||||
MemoryConflict.conflict_content == conflict_content,
|
||||
MemoryConflict.chat_id == chat_id
|
||||
)
|
||||
|
||||
if conflict:
|
||||
conflict.delete_instance()
|
||||
logger.info(f"已删除冲突记录: {conflict_content}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"删除冲突记录时出错: {e}")
|
||||
return False
|
||||
|
||||
# 全局冲突追踪器实例
|
||||
global_conflict_tracker = ConflictTracker()
|
||||
|
|
@ -1,319 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from src.common.database.database_model import GraphNodes
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("migrate")
|
||||
|
||||
|
||||
async def migrate_memory_items_to_string():
|
||||
"""
|
||||
将数据库中记忆节点的memory_items从list格式迁移到string格式
|
||||
并根据原始list的项目数量设置weight值
|
||||
"""
|
||||
logger.info("开始迁移记忆节点格式...")
|
||||
|
||||
migration_stats = {
|
||||
"total_nodes": 0,
|
||||
"converted_nodes": 0,
|
||||
"already_string_nodes": 0,
|
||||
"empty_nodes": 0,
|
||||
"error_nodes": 0,
|
||||
"weight_updated_nodes": 0,
|
||||
"truncated_nodes": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# 获取所有图节点
|
||||
all_nodes = GraphNodes.select()
|
||||
migration_stats["total_nodes"] = all_nodes.count()
|
||||
|
||||
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
|
||||
|
||||
for node in all_nodes:
|
||||
try:
|
||||
concept = node.concept
|
||||
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
|
||||
original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
|
||||
|
||||
# 如果为空,跳过
|
||||
if not memory_items_raw:
|
||||
migration_stats["empty_nodes"] += 1
|
||||
logger.debug(f"跳过空节点: {concept}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析JSON
|
||||
parsed_data = json.loads(memory_items_raw)
|
||||
|
||||
if isinstance(parsed_data, list):
|
||||
# 如果是list格式,需要转换
|
||||
if parsed_data:
|
||||
# 转换为字符串格式
|
||||
new_memory_items = " | ".join(str(item) for item in parsed_data)
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
new_weight = float(len(parsed_data)) # weight = list项目数量
|
||||
|
||||
# 更新数据库
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = new_weight
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
migration_stats["weight_updated_nodes"] += 1
|
||||
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.info(
|
||||
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
|
||||
)
|
||||
else:
|
||||
# 空list,设置为空字符串
|
||||
node.memory_items = ""
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
logger.debug(f"转换空list节点: {concept}")
|
||||
|
||||
elif isinstance(parsed_data, str):
|
||||
# 已经是字符串格式,检查长度和weight
|
||||
current_content = parsed_data
|
||||
original_length = len(current_content)
|
||||
content_truncated = False
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
content_truncated = True
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
node.memory_items = current_content
|
||||
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
# 检查weight是否需要更新
|
||||
update_needed = False
|
||||
if original_weight == 1.0:
|
||||
# 如果weight还是默认值,可以根据内容复杂度估算
|
||||
content_parts = (
|
||||
current_content.split(" | ") if " | " in current_content else [current_content]
|
||||
)
|
||||
estimated_weight = max(1.0, float(len(content_parts)))
|
||||
|
||||
if estimated_weight != original_weight:
|
||||
node.weight = estimated_weight
|
||||
update_needed = True
|
||||
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
|
||||
|
||||
# 如果内容被截断或权重需要更新,保存到数据库
|
||||
if content_truncated or update_needed:
|
||||
node.save()
|
||||
if update_needed:
|
||||
migration_stats["weight_updated_nodes"] += 1
|
||||
if content_truncated:
|
||||
migration_stats["converted_nodes"] += 1 # 算作转换节点
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
|
||||
else:
|
||||
# 其他JSON类型,转换为字符串
|
||||
new_memory_items = str(parsed_data) if parsed_data else ""
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"转换其他类型节点: {concept}{length_info}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 不是JSON格式,假设已经是纯字符串
|
||||
# 检查是否是带引号的字符串
|
||||
if memory_items_raw.startswith('"') and memory_items_raw.endswith('"'):
|
||||
# 去掉引号
|
||||
clean_content = memory_items_raw[1:-1]
|
||||
original_length = len(clean_content)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(clean_content) > 100:
|
||||
clean_content = clean_content[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
node.memory_items = clean_content
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"去除引号节点: {concept}{length_info}")
|
||||
else:
|
||||
# 已经是纯字符串格式,检查长度
|
||||
current_content = memory_items_raw
|
||||
original_length = len(current_content)
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
node.memory_items = current_content
|
||||
node.save()
|
||||
|
||||
migration_stats["converted_nodes"] += 1 # 算作转换节点
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
logger.debug(f"已是字符串格式节点: {concept}")
|
||||
|
||||
except Exception as e:
|
||||
migration_stats["error_nodes"] += 1
|
||||
logger.error(f"处理节点 {concept} 时发生错误: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生严重错误: {e}")
|
||||
raise
|
||||
|
||||
# 输出迁移统计
|
||||
logger.info("=== 记忆节点迁移完成 ===")
|
||||
logger.info(f"总节点数: {migration_stats['total_nodes']}")
|
||||
logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
|
||||
logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
|
||||
logger.info(f"空节点: {migration_stats['empty_nodes']}")
|
||||
logger.info(f"错误节点: {migration_stats['error_nodes']}")
|
||||
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
|
||||
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
|
||||
|
||||
success_rate = (
|
||||
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
|
||||
/ migration_stats["total_nodes"]
|
||||
* 100
|
||||
if migration_stats["total_nodes"] > 0
|
||||
else 0
|
||||
)
|
||||
logger.info(f"迁移成功率: {success_rate:.1f}%")
|
||||
|
||||
return migration_stats
|
||||
|
||||
|
||||
async def set_all_person_known():
|
||||
"""
|
||||
将person_info库中所有记录的is_known字段设置为True
|
||||
在设置之前,先清理掉user_id或platform为空的记录
|
||||
"""
|
||||
logger.info("开始设置所有person_info记录为已认识...")
|
||||
|
||||
try:
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
# 获取所有PersonInfo记录
|
||||
all_persons = PersonInfo.select()
|
||||
total_count = all_persons.count()
|
||||
|
||||
logger.info(f"找到 {total_count} 个人员记录")
|
||||
|
||||
if total_count == 0:
|
||||
logger.info("没有找到任何人员记录")
|
||||
return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}
|
||||
|
||||
# 删除user_id或platform为空的记录
|
||||
deleted_count = 0
|
||||
invalid_records = PersonInfo.select().where(
|
||||
(PersonInfo.user_id.is_null())
|
||||
| (PersonInfo.user_id == "")
|
||||
| (PersonInfo.platform.is_null())
|
||||
| (PersonInfo.platform == "")
|
||||
)
|
||||
|
||||
# 记录要删除的记录信息
|
||||
for record in invalid_records:
|
||||
user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
|
||||
platform_info = f"'{record.platform}'" if record.platform else "NULL"
|
||||
person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
|
||||
logger.debug(
|
||||
f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
|
||||
)
|
||||
|
||||
# 执行删除操作
|
||||
deleted_count = (
|
||||
PersonInfo.delete()
|
||||
.where(
|
||||
(PersonInfo.user_id.is_null())
|
||||
| (PersonInfo.user_id == "")
|
||||
| (PersonInfo.platform.is_null())
|
||||
| (PersonInfo.platform == "")
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
|
||||
else:
|
||||
logger.info("没有发现user_id或platform为空的记录")
|
||||
|
||||
# 重新获取剩余记录数量
|
||||
remaining_count = PersonInfo.select().count()
|
||||
logger.info(f"清理后剩余 {remaining_count} 个有效记录")
|
||||
|
||||
if remaining_count == 0:
|
||||
logger.info("清理后没有剩余记录")
|
||||
return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}
|
||||
|
||||
# 批量更新剩余记录的is_known字段为True
|
||||
updated_count = PersonInfo.update(is_known=True).execute()
|
||||
|
||||
logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")
|
||||
|
||||
# 验证更新结果
|
||||
known_count = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
|
||||
result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}
|
||||
|
||||
logger.info("=== person_info更新完成 ===")
|
||||
logger.info(f"原始记录数: {result['total']}")
|
||||
logger.info(f"删除记录数: {result['deleted']}")
|
||||
logger.info(f"更新记录数: {result['updated']}")
|
||||
logger.info(f"已认识记录数: {result['known_count']}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新person_info过程中发生错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def check_and_run_migrations():
|
||||
# 获取根目录
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
data_dir = os.path.join(project_root, "data")
|
||||
temp_dir = os.path.join(data_dir, "temp")
|
||||
done_file = os.path.join(temp_dir, "done.mem")
|
||||
|
||||
# 检查done.mem是否存在
|
||||
if not os.path.exists(done_file):
|
||||
# 如果temp目录不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
# 执行迁移函数
|
||||
# 依次执行两个异步函数
|
||||
await asyncio.sleep(3)
|
||||
await migrate_memory_items_to_string()
|
||||
await set_all_person_known()
|
||||
# 创建done.mem文件
|
||||
with open(done_file, "w", encoding="utf-8") as f:
|
||||
f.write("done")
|
||||
|
|
@ -22,14 +22,15 @@ def init_prompt():
|
|||
以上是群里正在进行的聊天记录
|
||||
|
||||
{identity_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态
|
||||
你先前的情绪状态是:{mood_state}
|
||||
你的情绪特点是:{emotion_style}
|
||||
|
||||
现在,请你根据先前的情绪状态和现在的聊天内容,总结推断你现在的情绪状态
|
||||
请只输出新的情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"change_mood_prompt",
|
||||
"get_mood_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
|
|
@ -66,37 +67,16 @@ class ChatMood:
|
|||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
async def get_mood(self) -> str:
|
||||
self.regression_count = 0
|
||||
|
||||
during_last_time = message.message_info.time - self.last_change_time # type: ignore
|
||||
current_time = time.time()
|
||||
|
||||
base_probability = 0.05
|
||||
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
||||
|
||||
# 基于消息长度计算基础兴趣度
|
||||
message_length = len(message.processed_plain_text or "")
|
||||
interest_multiplier = min(2.0, 1.0 + message_length / 100)
|
||||
|
||||
logger.debug(
|
||||
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
||||
)
|
||||
update_probability = global_config.mood.mood_update_threshold * min(
|
||||
1.0, base_probability * time_multiplier * interest_multiplier
|
||||
)
|
||||
|
||||
if random.random() > update_probability:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 更新情绪状态,更新概率: {update_probability:.2f}"
|
||||
)
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
logger.info(f"{self.log_prefix} 获取情绪状态")
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
timestamp_end=current_time,
|
||||
limit=int(global_config.chat.max_context_size / 3),
|
||||
limit_mode="last",
|
||||
)
|
||||
|
|
@ -119,11 +99,11 @@ class ChatMood:
|
|||
identity_block = f"你的名字是{bot_name}{bot_nickname}"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_prompt",
|
||||
"get_mood_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
identity_block=identity_block,
|
||||
mood_state=self.mood_state,
|
||||
emotion_style=global_config.personality.emotion_style,
|
||||
emotion_style=global_config.mood.emotion_style,
|
||||
)
|
||||
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
|
|
@ -138,7 +118,9 @@ class ChatMood:
|
|||
|
||||
self.mood_state = response
|
||||
|
||||
self.last_change_time = message_time
|
||||
self.last_change_time = current_time
|
||||
|
||||
return response
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
|
|
@ -172,7 +154,7 @@ class ChatMood:
|
|||
chat_talking_prompt=chat_talking_prompt,
|
||||
identity_block=identity_block,
|
||||
mood_state=self.mood_state,
|
||||
emotion_style=global_config.personality.emotion_style,
|
||||
emotion_style=global_config.mood.emotion_style,
|
||||
)
|
||||
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
|
|
@ -222,7 +204,6 @@ class MoodManager:
|
|||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动情绪回归任务...")
|
||||
task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ from src.config.config import global_config, model_config
|
|||
|
||||
logger = get_logger("person_info")
|
||||
|
||||
relation_selection_model = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="relation_selection")
|
||||
relation_selection_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="relation_selection"
|
||||
)
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
|
|
@ -91,9 +93,10 @@ def extract_categories_from_response(response: str) -> list[str]:
|
|||
"""从response中提取所有<>包裹的内容"""
|
||||
if not isinstance(response, str):
|
||||
return []
|
||||
|
||||
|
||||
import re
|
||||
pattern = r'<([^<>]+)>'
|
||||
|
||||
pattern = r"<([^<>]+)>"
|
||||
matches = re.findall(pattern, response)
|
||||
return matches
|
||||
|
||||
|
|
@ -420,7 +423,7 @@ class Person:
|
|||
except Exception as e:
|
||||
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
|
||||
|
||||
async def build_relationship(self,chat_content:str = "",info_type = ""):
|
||||
async def build_relationship(self, chat_content: str = "", info_type=""):
|
||||
if not self.is_known:
|
||||
return ""
|
||||
# 构建points文本
|
||||
|
|
@ -433,7 +436,7 @@ class Person:
|
|||
|
||||
points_text = ""
|
||||
category_list = self.get_all_category()
|
||||
|
||||
|
||||
if chat_content:
|
||||
prompt = f"""当前聊天内容:
|
||||
{chat_content}
|
||||
|
|
@ -449,11 +452,13 @@ class Person:
|
|||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 2)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join([get_memory_content_from_memory(memory) for memory in random_memory])
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
elif info_type:
|
||||
|
|
@ -466,18 +471,19 @@ class Person:
|
|||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 3)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join([get_memory_content_from_memory(memory) for memory in random_memory])
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
else:
|
||||
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 1)[0]
|
||||
if random_memory:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from src.plugin_system.apis import (
|
|||
send_api,
|
||||
tool_api,
|
||||
frequency_api,
|
||||
mood_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
|
@ -40,4 +41,5 @@ __all__ = [
|
|||
"register_plugin",
|
||||
"tool_api",
|
||||
"frequency_api",
|
||||
"mood_api",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -309,6 +309,7 @@ async def store_action_info(
|
|||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
|
|
@ -322,7 +323,7 @@ async def store_action_info(
|
|||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
action_reasoning: 动作执行理由
|
||||
Returns:
|
||||
Dict[str, Any]: 保存的记录数据
|
||||
None: 如果保存失败
|
||||
|
|
@ -348,6 +349,7 @@ async def store_action_info(
|
|||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_reasoning": action_reasoning,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,11 +9,14 @@
|
|||
"""
|
||||
|
||||
import random
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager, EMOJI_DIR
|
||||
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
|
|
@ -245,6 +248,42 @@ def get_emotions() -> List[str]:
|
|||
return []
|
||||
|
||||
|
||||
async def get_all() -> List[Tuple[str, str, str]]:
|
||||
"""获取所有表情包
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for emoji_obj in all_emojis:
|
||||
if emoji_obj.is_deleted:
|
||||
continue
|
||||
|
||||
emoji_base64 = image_path_to_base64(emoji_obj.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
||||
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
||||
|
||||
logger.debug(f"[EmojiAPI] 成功获取 {len(results)} 个表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取所有表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
|
|
@ -264,3 +303,403 @@ def get_descriptions() -> List[str]:
|
|||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包注册API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""注册新的表情包
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
filename: 可选的文件名,如果未提供则自动生成
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 注册结果,包含以下字段:
|
||||
- success: bool, 是否成功注册
|
||||
- message: str, 结果消息
|
||||
- description: Optional[str], 表情包描述(成功时)
|
||||
- emotions: Optional[List[str]], 情感标签列表(成功时)
|
||||
- replaced: Optional[bool], 是否替换了旧表情包(成功时)
|
||||
- hash: Optional[str], 表情包哈希值(成功时)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果base64为空或无效
|
||||
TypeError: 如果参数类型不正确
|
||||
"""
|
||||
if not image_base64:
|
||||
raise ValueError("图片base64编码不能为空")
|
||||
if not isinstance(image_base64, str):
|
||||
raise TypeError("image_base64必须是字符串类型")
|
||||
if filename is not None and not isinstance(filename, str):
|
||||
raise TypeError("filename必须是字符串类型或None")
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
|
||||
|
||||
# 1. 获取emoji管理器并检查容量
|
||||
emoji_manager = get_emoji_manager()
|
||||
count_before = emoji_manager.emoji_num
|
||||
max_count = emoji_manager.emoji_num_max
|
||||
|
||||
# 2. 检查是否可以注册(未达到上限或启用替换)
|
||||
can_register = count_before < max_count or (
|
||||
count_before >= max_count and emoji_manager.emoji_num_max_reach_deletion
|
||||
)
|
||||
|
||||
if not can_register:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
# 3. 确保emoji目录存在
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
# 4. 生成文件名
|
||||
if not filename:
|
||||
# 基于时间戳、微秒和短base64生成唯一文件名
|
||||
import time
|
||||
|
||||
timestamp = int(time.time())
|
||||
microseconds = int(time.time() * 1000000) % 1000000 # 添加微秒级精度
|
||||
|
||||
# 生成12位随机标识符,使用base64编码(增加随机性)
|
||||
import random
|
||||
|
||||
random_bytes = random.getrandbits(72).to_bytes(9, "big") # 72位 = 9字节 = 12位base64
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
||||
# 确保base64编码适合文件名(替换/和-)
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
||||
|
||||
# 确保文件名有扩展名
|
||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
||||
filename = f"{filename}.png" # 默认使用png格式
|
||||
|
||||
# 检查文件名是否已存在,如果存在则重新生成短标识符
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts = 0
|
||||
max_attempts = 10
|
||||
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
||||
# 重新生成短标识符
|
||||
import random
|
||||
|
||||
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
|
||||
# 分离文件名和扩展名,重新生成文件名
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
# 去掉原来的标识符,添加新的
|
||||
base_name = name_part.rsplit("_", 1)[0] # 移除最后一个_后的部分
|
||||
filename = f"{base_name}_{short_id}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts += 1
|
||||
|
||||
# 如果还是冲突,使用UUID作为备用方案
|
||||
if os.path.exists(temp_file_path):
|
||||
uuid_short = str(uuid.uuid4())[:8]
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
base_name = name_part.rsplit("_", 1)[0]
|
||||
filename = f"{base_name}_{uuid_short}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
# 如果UUID方案也冲突,添加序号
|
||||
counter = 1
|
||||
original_filename = filename
|
||||
while os.path.exists(temp_file_path):
|
||||
name_part, ext = os.path.splitext(original_filename)
|
||||
filename = f"{name_part}_{counter}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
# 防止无限循环,最多尝试100次
|
||||
if counter > 100:
|
||||
logger.error(f"[EmojiAPI] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法生成唯一文件名,请稍后重试",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
# 5. 保存base64图片到emoji目录
|
||||
|
||||
try:
|
||||
# 解码base64并保存图片
|
||||
if not base64_to_image(image_base64, temp_file_path):
|
||||
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法保存图片文件",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
logger.debug(f"[EmojiAPI] 图片已保存到临时文件: {temp_file_path}")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[EmojiAPI] 保存图片文件失败: {save_error}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"保存图片文件失败: {str(save_error)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
# 6. 调用注册方法
|
||||
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
||||
|
||||
# 7. 清理临时文件(如果注册失败但文件还存在)
|
||||
if not register_success and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"[EmojiAPI] 已清理临时文件: {temp_file_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[EmojiAPI] 清理临时文件失败: {cleanup_error}")
|
||||
|
||||
# 8. 构建返回结果
|
||||
if register_success:
|
||||
count_after = emoji_manager.emoji_num
|
||||
replaced = count_after <= count_before # 如果数量没增加,说明是替换
|
||||
|
||||
# 尝试获取新注册的表情包信息
|
||||
new_emoji_info = None
|
||||
if count_after > count_before or replaced:
|
||||
# 获取最新的表情包信息
|
||||
try:
|
||||
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
|
||||
for emoji_obj in reversed(emoji_manager.emoji_objects):
|
||||
if not emoji_obj.is_deleted and (
|
||||
emoji_obj.filename == filename # 直接匹配
|
||||
or (hasattr(emoji_obj, "full_path") and filename in emoji_obj.full_path) # 路径包含匹配
|
||||
):
|
||||
new_emoji_info = emoji_obj
|
||||
break
|
||||
except Exception as find_error:
|
||||
logger.warning(f"[EmojiAPI] 查找新注册表情包信息失败: {find_error}")
|
||||
|
||||
description = new_emoji_info.description if new_emoji_info else None
|
||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
||||
emoji_hash = new_emoji_info.hash if new_emoji_info else None
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||
"description": description,
|
||||
"emotions": emotions,
|
||||
"replaced": replaced,
|
||||
"hash": emoji_hash,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 注册表情包时发生异常: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"注册过程中发生错误: {str(e)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包删除API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
|
||||
"""删除表情包
|
||||
|
||||
Args:
|
||||
emoji_hash: 要删除的表情包的哈希值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 删除结果,包含以下字段:
|
||||
- success: bool, 是否成功删除
|
||||
- message: str, 结果消息
|
||||
- count_before: Optional[int], 删除前的表情包数量
|
||||
- count_after: Optional[int], 删除后的表情包数量
|
||||
- description: Optional[str], 被删除的表情包描述(成功时)
|
||||
- emotions: Optional[List[str]], 被删除的表情包情感标签(成功时)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果哈希值为空
|
||||
TypeError: 如果哈希值不是字符串类型
|
||||
"""
|
||||
if not emoji_hash:
|
||||
raise ValueError("表情包哈希值不能为空")
|
||||
if not isinstance(emoji_hash, str):
|
||||
raise TypeError("emoji_hash必须是字符串类型")
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
|
||||
|
||||
# 1. 获取emoji管理器和删除前的数量
|
||||
emoji_manager = get_emoji_manager()
|
||||
count_before = emoji_manager.emoji_num
|
||||
|
||||
# 2. 获取被删除表情包的信息(用于返回结果)
|
||||
try:
|
||||
deleted_emoji = await emoji_manager.get_emoji_from_manager(emoji_hash)
|
||||
description = deleted_emoji.description if deleted_emoji else None
|
||||
emotions = deleted_emoji.emotion if deleted_emoji else None
|
||||
except Exception as info_error:
|
||||
logger.warning(f"[EmojiAPI] 获取被删除表情包信息失败: {info_error}")
|
||||
description = None
|
||||
emotions = None
|
||||
|
||||
# 3. 执行删除操作
|
||||
delete_success = await emoji_manager.delete_emoji(emoji_hash)
|
||||
|
||||
# 4. 获取删除后的数量
|
||||
count_after = emoji_manager.emoji_num
|
||||
|
||||
# 5. 构建返回结果
|
||||
if delete_success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"表情包删除成功 (哈希: {emoji_hash[:8]}...)",
|
||||
"count_before": count_before,
|
||||
"count_after": count_after,
|
||||
"description": description,
|
||||
"emotions": emotions,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "表情包删除失败,可能因为哈希值不存在或删除过程出错",
|
||||
"count_before": count_before,
|
||||
"count_after": count_after,
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 删除表情包时发生异常: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"删除过程中发生错误: {str(e)}",
|
||||
"count_before": None,
|
||||
"count_after": None,
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
}
|
||||
|
||||
|
||||
async def delete_emoji_by_description(description: str, exact_match: bool = False) -> Dict[str, Any]:
|
||||
"""根据描述删除表情包
|
||||
|
||||
Args:
|
||||
description: 表情包描述文本
|
||||
exact_match: 是否精确匹配描述,False则为模糊匹配
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 删除结果,包含以下字段:
|
||||
- success: bool, 是否成功删除
|
||||
- message: str, 结果消息
|
||||
- deleted_count: int, 删除的表情包数量
|
||||
- deleted_hashes: List[str], 被删除的表情包哈希列表
|
||||
- matched_count: int, 匹配到的表情包数量
|
||||
|
||||
Raises:
|
||||
ValueError: 如果描述为空
|
||||
TypeError: 如果描述不是字符串类型
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("description必须是字符串类型")
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
# 筛选匹配的表情包
|
||||
matching_emojis = []
|
||||
for emoji_obj in all_emojis:
|
||||
if emoji_obj.is_deleted:
|
||||
continue
|
||||
|
||||
if exact_match:
|
||||
if emoji_obj.description == description:
|
||||
matching_emojis.append(emoji_obj)
|
||||
else:
|
||||
if description.lower() in emoji_obj.description.lower():
|
||||
matching_emojis.append(emoji_obj)
|
||||
|
||||
matched_count = len(matching_emojis)
|
||||
if matched_count == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"未找到匹配描述 '{description}' 的表情包",
|
||||
"deleted_count": 0,
|
||||
"deleted_hashes": [],
|
||||
"matched_count": 0,
|
||||
}
|
||||
|
||||
# 删除匹配的表情包
|
||||
deleted_count = 0
|
||||
deleted_hashes = []
|
||||
for emoji_obj in matching_emojis:
|
||||
try:
|
||||
delete_success = await emoji_manager.delete_emoji(emoji_obj.hash)
|
||||
if delete_success:
|
||||
deleted_count += 1
|
||||
deleted_hashes.append(emoji_obj.hash)
|
||||
except Exception as delete_error:
|
||||
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.hash}): {delete_error}")
|
||||
|
||||
# 构建返回结果
|
||||
if deleted_count > 0:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"成功删除 {deleted_count} 个表情包 (匹配到 {matched_count} 个)",
|
||||
"deleted_count": deleted_count,
|
||||
"deleted_hashes": deleted_hashes,
|
||||
"matched_count": matched_count,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"匹配到 {matched_count} 个表情包,但删除全部失败",
|
||||
"deleted_count": 0,
|
||||
"deleted_hashes": [],
|
||||
"matched_count": matched_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 根据描述删除表情包时发生异常: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"删除过程中发生错误: {str(e)}",
|
||||
"deleted_count": 0,
|
||||
"deleted_hashes": [],
|
||||
"matched_count": 0,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ from src.chat.frequency_control.frequency_control import frequency_control_manag
|
|||
|
||||
logger = get_logger("frequency_api")
|
||||
|
||||
|
||||
def get_current_talk_frequency(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
|
||||
|
||||
|
||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||
frequency_control_manager.get_or_create_frequency_control(
|
||||
chat_id
|
||||
).set_talk_frequency_adjust(talk_frequency_adjust)
|
||||
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
|
||||
|
||||
|
||||
def get_talk_frequency_adjust(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ async def generate_reply(
|
|||
enable_chinese_typo: bool = True,
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
reply_time_point: Optional[float] = None,
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""生成回复
|
||||
|
||||
|
|
@ -109,6 +110,7 @@ async def generate_reply(
|
|||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
request_type: 请求类型(可选,记录LLM使用)
|
||||
from_plugin: 是否来自插件
|
||||
reply_time_point: 回复时间点
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
||||
"""
|
||||
|
|
@ -136,6 +138,7 @@ async def generate_reply(
|
|||
reply_reason=reply_reason,
|
||||
from_plugin=from_plugin,
|
||||
stream_id=chat_stream.stream_id if chat_stream else chat_id,
|
||||
reply_time_point=reply_time_point,
|
||||
)
|
||||
if not success:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue