Merge pull request #1465 from Mai-with-u/dev

Dev0.12.1
main
SengokuCola 2025-12-31 14:28:57 +08:00 committed by GitHub
commit 0d685806a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
123 changed files with 10440 additions and 5773 deletions

View File

@ -46,7 +46,7 @@
## 🔥 更新和安装
**最新版本: v0.11.6** ([更新日志](changelogs/changelog.md))
**最新版本: v0.12.0** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本

1
bot.py
View File

@ -59,6 +59,7 @@ def run_runner_process():
while True:
logger.info(f"正在启动 {script_file}...")
logger.info("正在编译着色器1/114514")
# 启动子进程 (Worker)
# 使用 sys.executable 确保使用相同的 Python 解释器

View File

@ -1,6 +1,18 @@
# Changelog
## [0.12.1] - 2025-12-31
### 🌟 主要更新
- 添加年度总结可以在webui查看
- 可选让llm判定引用回复
- 表达方式优化!现在可以进行自动和手动评估,使其更精准
- 回复和规划记录webui可以查看每一条回复和plan的详情
## [0.12.0] - 2025-12-16
### 细节功能更改
- 优化间隔过长消息的显示
- 新增 enable_jargon_detection 配置项(黑话检测开关)
- 新增 global_memory_blacklist 配置项,指定部分群聊不参与全局记忆
- 移除utils_small模型,移除弃用的lpmm模型
## [0.12.0] - 2025-12-21
### 🌟 重大更新
- 添加思考力度机制,动态控制回复时间和长度
- planner和replyer现在开启联动更好的回复逻辑
@ -9,12 +21,32 @@
- mcp插件作为内置插件加入默认不启用
- 添加全局记忆配置项,现在可以选择让记忆为全局的
### 🌟 WebUI 重大更新
- **模型预设市场功能正式完善并发布**:现在可以将模型配置完整分享,分享按钮位于模型配置界面右上角
- **全面安全加固**:为所有 WebUI API 和 WebSocket 端点添加身份认证保护Cookie 添加 Secure 和 SameSite 属性,支持环境感知动态配置
- **前端认证重构**:从 localStorage 迁移到 HttpOnly Cookie新增 WebSocket 临时 token 认证机制,解决跨域开发环境下 Cookie 无法携带的问题
- **增强插件配置管理**:支持原始 TOML 配置的加载和保存,前端支持查看和编辑插件配置文件源文件
### 细节功能更改
- 移除频率自动调整
- 移除情绪功能
- 优化记忆差许多呢超时设置
- 修复部分配置为0的bug
- 插件安装时可以主动选择克隆的分支
- 首页中反馈问卷功能,可以提交反馈信息和建议信息
- 黑话和表达不再提取包含名称的内容
- 模型界面支持编辑 extra params 额外字段
- 模型界面中的任务分配子界面支持编辑慢请求检测阈值
- 模型界面中支持对单个模型单独指定温度参数和 max tokens 参数
- 首页所有数据卡片支持自动选择单位+显示详细信息功能
- WebUI 聊天室表情包、图片、富文本消息支持
- 麦麦适配器配置界面的工作模式支持折叠
- WebUI 插件配置解析支持动态 list 表单
- WebUI 插件配置中的动态 list 支持开关、滑块和下拉框类型
- 在插件商场、插件配置详情界面增加了重启按钮
- 加强安全性和隐私保护:添加登录接口速率限制,防止暴力破解攻击,收紧 CORS 配置(限制允许的 HTTP 方法和请求头完善路径校验validate_safe_path 防止目录穿越攻击fetchWithAuth 支持 FormData 文件上传,新增 robots.txt 路由和 X-Robots-Tag 响应头防止搜索引擎索引,前端添加 meta noindex/nofollow 标签阻止爬虫收录
- 修复并优化了聊天室、模型配置、日志查看器、黑话管理、WebUI 端口占用、配置向导、首页图表、聊天室消息重复、移动端日志不可见、模型提供商删除、主程序配置换行符、HTTP 警告横幅、重启界面、LPMM 配置、人物信息、插件端点安全认证、WebSocket token 等问题,提升整体稳定性与体验。
- 完成主程序配置与模型配置界面重构、模型提供商与麦麦适配器配置重构(引入 TOML 校验、WebSocket 认证逻辑抽取为共享模块统一 WS 端点,升级 React 到 19.2.1 并更新依赖WebUI 配置与可视化全部迁移到主配置及模型配置中,优化配置更新提示、插件详情页面和路径安全校验,并增强模型与梦境等多项配置的可视化和自动检验。
## [0.11.6] - 2025-12-2
### 🌟 重大更新

View File

@ -85,7 +85,6 @@ Action采用**两层决策机制**来优化性能和决策质量:
| ----------- | ---------------------------------------- | ---------------------- |
| [`NEVER`](#never-激活) | 从不激活Action对麦麦不可见 | 临时禁用某个Action |
| [`ALWAYS`](#always-激活) | 永远激活Action总是在麦麦的候选池中 | 核心功能,如回复、不回复 |
| [`LLM_JUDGE`](#llm_judge-激活) | 通过LLM智能判断当前情境是否需要激活此Action | 需要智能判断的复杂场景 |
| `RANDOM` | 基于随机概率决定是否激活 | 增加行为随机性的功能 |
| `KEYWORD` | 当检测到特定关键词时激活 | 明确触发条件的功能 |
@ -117,30 +116,6 @@ class AlwaysActivatedAction(BaseAction):
return True, "执行了核心功能"
```
#### `LLM_JUDGE` 激活
`ActionActivationType.LLM_JUDGE`会使得这个 Action 根据 LLM 的判断来决定是否加入候选池。
而 LLM 的判断是基于代码中预设的`llm_judge_prompt`和自动提供的聊天上下文进行的。
因此使用此种方法需要实现`llm_judge_prompt`属性。
```python
class LLMJudgedAction(BaseAction):
activation_type = ActionActivationType.LLM_JUDGE # 通过LLM判断激活
# LLM判断提示词
llm_judge_prompt = (
"判定是否需要使用这个动作的条件:\n"
"1. 用户希望调用XXX这个动作\n"
"...\n"
"请回答\"是\"或\"否\"。\n"
)
async def execute(self) -> Tuple[bool, str]:
# 根据LLM判断是否执行
return True, "执行了LLM判断功能"
```
#### `RANDOM` 激活
`ActionActivationType.RANDOM`会使得这个 Action 根据随机概率决定是否加入候选池。

View File

@ -1,2 +0,0 @@
!napcat
!.env

View File

@ -1,45 +0,0 @@
stages:
  - build-image
  - package-helm-chart

# Run this pipeline only on the helm-chart branch.
workflow:
  rules:
    - if: '$CI_COMMIT_BRANCH == "helm-chart"'
    - when: never

# Build and push the preprocessor image.
build-preprocessor:
  stage: build-image
  image: reg.mikumikumi.xyz/base/kaniko-builder:latest
  variables:
    BUILD_NO_CACHE: true
  rules:
    - changes:
        - helm-chart/preprocessor/**
  script:
    - export BUILD_CONTEXT=helm-chart/preprocessor
    - export TMP_DST=reg.mikumikumi.xyz/maibot/preprocessor
    # Tag the image with the chart version taken from Chart.yaml.
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - export BUILD_DESTINATION="${TMP_DST}:${CHART_VERSION}"
    - export BUILD_ARGS="--destination ${TMP_DST}:latest"
    - build

# Package and push the helm chart.
package-helm-chart:
  stage: package-helm-chart
  image: reg.mikumikumi.xyz/mirror/helm:latest
  rules:
    - changes:
        - helm-chart/files/**
        - helm-chart/templates/**
        - helm-chart/.gitignore
        - helm-chart/.helmignore
        - helm-chart/Chart.yaml
        - helm-chart/README.md
        - helm-chart/values.yaml
  script:
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - helm registry login reg.mikumikumi.xyz --username ${CI_REGISTRY_USER} --password ${CI_REGISTRY_PASSWORD}
    - helm package helm-chart
    - helm push maibot-${CHART_VERSION}.tgz oci://reg.mikumikumi.xyz/maibot

View File

@ -1,2 +0,0 @@
preprocessor
.gitlab-ci.yml

View File

@ -1,6 +0,0 @@
# Helm chart metadata for MaiBot.
apiVersion: v2
name: maibot
description: "Maimai Bot, a cyber friend dedicated to group chats"
type: application
# Chart version; CI also uses it to tag the preprocessor image.
version: 0.12.0
# The MaiBot application version this chart deploys.
appVersion: 0.12.0

View File

@ -1,118 +0,0 @@
# MaiBot Helm Chart
这是麦麦的Helm Chart可以方便地将麦麦部署在Kubernetes集群中。
当前Helm Chart对应的麦麦版本可以在`Chart.yaml`中查看`appVersion`项。
详细部署文档:[Kubernetes 部署](https://docs.mai-mai.org/manual/deployment/mmc_deploy_kubernetes.html)
## 可用的Helm Chart版本列表
| Helm Chart版本 | 对应的MaiBot版本 | Commit SHA |
|----------------|--------------|------------------------------------------|
| 0.12.0 | 0.12.0 | baa6e90be7b20050fe25dfc74c0c70653601d00e |
| 0.11.6-beta | 0.11.6-beta | 0bfff0457e6db3f7102fb7f77c58d972634fc93c |
| 0.11.5-beta | 0.11.5-beta | ad2df627001f18996802f23c405b263e78af0d0f |
| 0.11.3-beta | 0.11.3-beta | cd6dc18f546f81e08803d3b8dba48e504dad9295 |
| 0.11.2-beta | 0.11.2-beta | d3c8cea00dbb97f545350f2c3d5bcaf252443df2 |
| 0.11.1-beta | 0.11.1-beta | 94e079a340a43dff8a2bc178706932937fc10b11 |
| 0.11.0-beta | 0.11.0-beta | 16059532d8ef87ac28e2be0838ff8b3a34a91d0f |
| 0.10.3-beta | 0.10.3-beta | 7618937cd4fd0ab1a7bd8a31ab244a8b0742fced |
| 0.10.0-alpha.0 | 0.10.0-alpha | 4efebed10aad977155d3d9e0c24bc6e14e1260ab |
## TL;DR
```shell
helm install maimai \
oci://reg.mikumikumi.xyz/maibot/maibot \
--namespace bot \
--version <MAIBOT_VERSION> \
--values maibot.yaml
```
## Values项说明
`values.yaml`分为几个大部分。
1. `EULA` & `PRIVACY`: 用户必须同意这里的协议才能成功部署麦麦。
2. `pre_processor`: 部署之前的预处理Job的配置。
3. `adapter`: 麦麦的Adapter的部署配置。
4. `core`: 麦麦本体的部署配置。
5. `statistics_dashboard`: 麦麦的运行统计看板部署配置。
麦麦每隔一段时间会自动输出html格式的运行统计报告此统计报告可以部署为看板。
出于隐私考虑,默认禁用。
6. `napcat`: Napcat的部署配置。
考虑到复用外部Napcat实例的情况Napcat部署已被解耦。用户可选是否要部署Napcat。
默认会捆绑部署Napcat。
7. `sqlite_web`: sqlite-web的部署配置。
通过sqlite-web可以在网页上操作麦麦的数据库方便调试。不部署对麦麦的运行无影响。
此服务如果暴露在公网会十分危险,默认不会部署。
8. `config`: 这里填写麦麦各部分组件的运行配置。
这里填写的配置仅会在初次部署时或用户指定时覆盖实际配置文件且需要严格遵守yaml文件的缩进格式。
- `override_*_config`: 指定本次部署/升级是否用以下配置覆盖实际配置文件。默认不覆盖。
- `adapter_config`: 对应adapter的`config.toml`。
此配置文件中对于`napcat_server`和`maibot_server`的`host`和`port`字段的配置会被上面`adapter.service`中的配置覆盖,因此不需要改动。
- `core_model_config`: 对应core的`model_config.toml`。
- `core_bot_config`: 对应core的`bot_config.toml`。
## 部署说明
使用此Helm Chart的一些注意事项。
### 麦麦的配置
要修改麦麦的配置最好的方法是通过WebUI来操作。此处的配置只会在初次部署时或者指定覆盖时注入到MaiBot中。
`0.11.6-beta`之前的版本将配置存储于k8s的ConfigMap资源中。随着版本迭代MaiBot对配置文件的操作复杂性增加k8s的适配复杂度也同步增加且WebUI可以直接修改配置文件因此自`0.11.6-beta`版本开始各组件的配置不再存储于k8s的ConfigMap中而是直接存储于存储卷的实际文件中。
从旧版本升级的用户旧的ConfigMap的配置会自动迁移到新的存储卷的配置文件中。
### 部署时自动重置的配置
adapter的配置中的`napcat_server`和`maibot_server`的`host`和`port`字段,会在每次部署/更新Helm安装实例时被自动重置。
core的配置中的`webui`和`maim_message`的部分字段也会在每次部署/更新Helm安装实例时被自动重置。
自动重置的原因:
- core的Service的DNS名称是动态的由安装实例名拼接无法在adapter的配置文件中提前确定。
- 为了使adapter监听所有地址以及保持Helm Chart中配置的端口号需要在adapter的配置文件中覆盖这些配置。
- core的WebUI启停需要由helm chart控制以便正常创建Service和Ingress资源。
- core的maim_message的api server现在可以作为k8s服务暴露出来。监听的IP和端口需要由helm chart控制以便Service正确映射。
首次部署时预处理任务会负责重置这些配置。这会需要一些时间因此部署进程可能比较慢且部分Pod可能会无法启动等待一分钟左右即可。
### 跨节点PVC挂载问题
MaiBot的一些组件会挂载同一PVC这主要是为了同步数据或修改配置。
如果k8s集群有多个节点且共享相同PVC的Pod未调度到同一节点那么就需要此PVC访问模式具备`ReadWriteMany`访问模式。
不是所有存储控制器都支持`ReadWriteMany`访问模式。
如果你的存储控制器无法支持`ReadWriteMany`访问模式,你可以通过`nodeSelector`配置将彼此之间共享相同PVC的Pod调度到同一节点来避免问题。
会共享PVC的组件列表
- `core`和`adapter`:共享`adapter-config`,用于为`core`的WebUI提供修改adapter的配置文件的能力。
- `core`和`statistics-dashboard`:共享`statistics-dashboard`用于同步统计数据的html文件。
- `core`和`sqlite-web`:共享`maibot-core`,用于为`sqlite-web`提供操作MaiBot数据库的能力。
- 部署时的预处理任务`preprocessor`和`adapter`、`core`:共享`adapter-config`和`core-config`,用于初始化`core`和`adapter`的配置文件。

View File

@ -1,4 +0,0 @@
HOST=0.0.0.0
PORT=8000
WEBUI_HOST=0.0.0.0
WEBUI_PORT=8001

View File

@ -1,36 +0,0 @@
#!/bin/sh
# Overrides the core container's default entrypoint to do one-time init for k8s.
# Because k8s mounts volumes differently from docker-compose, this script
# pre-creates symlinks for some files and directories.
# /MaiMBot/data is the actual mount path of MaiBot's data volume.
# /MaiMBot/statistics is the actual mount path of the statistics volume.
set -e
echo "[K8s Init] Preparing volume..."
# First start: make sure the key files and directories exist on the volume.
mkdir -p /MaiMBot/data/plugins
mkdir -p /MaiMBot/data/logs
if [ ! -d "/MaiMBot/statistics" ]
then
    # The statistics volume is only mounted when the dashboard is enabled.
    echo "[K8s Init] Statistics volume is disabled."
else
    touch /MaiMBot/statistics/index.html
fi
# Remove the image's default plugins dir so the user plugins dir can be linked in.
rm -rf /MaiMBot/plugins
# Symlink from the volume to the locations the application expects.
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
ln -s /MaiMBot/data/logs /MaiMBot/logs
if [ -f "/MaiMBot/statistics/index.html" ]
then
    ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
fi
echo "[K8s Init] Volume ready."
# Start MaiBot; exec replaces this shell so signals reach the Python process.
echo "[K8s Init] Waking up MaiBot..."
echo
exec python bot.py

View File

@ -1,12 +0,0 @@
# Image for the helm-chart deploy-time preprocessor Job that generates and
# adjusts the adapter/core configuration files.
FROM python:3.13-slim
WORKDIR /app
# Unbuffered stdout so Job logs show up immediately.
ENV PYTHONUNBUFFERED=1
COPY . /app
RUN pip3 install --no-cache-dir -r requirements.txt
ENTRYPOINT ["python3", "preprocessor.py"]

View File

@ -1,266 +0,0 @@
#!/bin/python3
# 此脚本会被helm chart的post-install hook触发在正式部署后通过k8s的job自动运行一次。
# 这个脚本的作用是在部署helm chart时迁移旧版ConfigMap到配置文件调整adapter的配置文件中的服务监听和服务连接字段调整core的配置文件中的maim_message_api_server和WebUI配置。
#
# - 迁移旧版ConfigMap到配置文件是因为0.11.6-beta之前版本的helm chart将各个配置文件存储在k8s的ConfigMap中
# 由于功能复杂度提升自0.11.6-beta版本开始配置文件采用文件形式存储到存储卷中。
# 从旧版升级来的用户会通过这个脚本自动执行配置的迁移。
#
# - 需要调整adapter的配置文件的原因是:
# 1. core的Service的DNS名称是动态的由安装实例名拼接无法在adapter的配置文件中提前确定。
# 用于对外连接的maibot_server.host和maibot_server.port字段会被替换为core的Service对应的DNS名称和8000端口硬编码用户无需配置
# 2. 为了使adapter监听所有地址以及保持chart中配置的端口号需要在adapter的配置文件中覆盖这些配置。
# 用于监听的napcat_server.host和napcat_server.port字段会被替换为0.0.0.0和8095端口实际映射到的Service端口会在Service中配置
#
# - 需要调整core的配置文件的原因是
# 1. core的WebUI启停需要由helm chart控制以便正常创建Service和Ingress资源。
# 配置文件中的webui.enabled、webui.allowed_ips将由此脚本覆盖为正确配置。
# 2. core的maim_message的api server现在可以作为k8s服务暴露出来。监听的IP和端口需要由helm chart控制以便Service正确映射。
# 配置文件中的maim_message.enable_api_server、maim_message.api_server_host、maim_message.api_server_port将由此脚本覆盖为正确配置。
import os
import toml
import time
import base64
from kubernetes import client, config
from kubernetes.client.exceptions import ApiException
from datetime import datetime, timezone
# Load in-cluster service-account credentials so the Kubernetes API clients
# below work when this script runs inside a Job pod.
config.load_incluster_config()
core_api = client.CoreV1Api()
apps_api = client.AppsV1Api()
# Read key deployment information injected by the helm chart.
# The namespace comes from the mounted service-account token directory.
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", 'r') as f:
    namespace = f.read().strip()
# NOTE(review): os.getenv returns None when a variable is unset, so the
# .strip()/.lower() calls below would raise AttributeError — presumably the
# Job template always sets these three; confirm against the chart.
release_name = os.getenv("RELEASE_NAME").strip()
is_webui_enabled = os.getenv("IS_WEBUI_ENABLED").lower() == "true"
is_maim_message_api_server_enabled = os.getenv("IS_MMSG_ENABLED").lower() == "true"
# Base64-encoded config payloads; each may be None when the chart did not
# inject it (only injected on first install or explicit override).
config_adapter_b64 = os.getenv("CONFIG_ADAPTER_B64")
config_core_env_b64 = os.getenv("CONFIG_CORE_ENV_B64")
config_core_bot_b64 = os.getenv("CONFIG_CORE_BOT_B64")
config_core_model_b64 = os.getenv("CONFIG_CORE_MODEL_B64")
def log(func: str, msg: str, level: str = 'INFO'):
    """Print a timestamped log line tagged with *level* and the calling function's name."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'[{timestamp}] [{level}] [{func}] {msg}')
def migrate_old_config():
    """Migrate configs stored in legacy ConfigMaps (pre-0.11.6-beta charts) to files on the PV.

    Since 0.11.6-beta the chart stores configs as real files on the volume;
    this copies each legacy ConfigMap's payload to its file and deletes the
    ConfigMap. Skipped entirely when config files already exist on the volume.
    """
    func_name = 'migrate_old_config'
    log(func_name, 'Checking whether there are old configmaps to migrate...')
    old_configmap_version = None
    # Tracks migration status of the adapter config.toml and the core
    # bot_config.toml / model_config.toml.
    status_migrating = {
        'adapter_config.toml': False,
        'core_bot_config.toml': False,
        'core_model_config.toml': False
    }
    # If any config file already exists on the volume, skip migration.
    if os.path.isfile('/app/config/core/bot_config.toml') or os.path.isfile('/app/config/core/model_config.toml') or \
            os.path.isfile('/app/config/adapter/config.toml'):
        log(func_name, 'Found existing config file(s) in PV. Migration will be ignored. Done.')
        return

    def migrate_cm_to_file(cm_name: str, key_name: str, file_path: str) -> bool:
        """Copy key `key_name` of ConfigMap `cm_name` into `file_path`, then delete the ConfigMap.

        Returns False when the ConfigMap does not exist (404), True otherwise.
        NOTE(review): an ApiException with a non-404 status also falls through
        to `return True` without anything having been written — confirm
        whether treating other API errors as "migrated" is intended.
        """
        try:
            cm = core_api.read_namespaced_config_map(
                name=cm_name,
                namespace=namespace
            )
            log(func_name, f'\tMigrating `{key_name}` of `{cm_name}`...')
            with open(file_path, 'w', encoding='utf-8') as _f:
                _f.write(cm.data[key_name])
            core_api.delete_namespaced_config_map(
                name=cm_name,
                namespace=namespace
            )
            log(func_name, f'\tSuccessfully migrated `{key_name}` of `{cm_name}`.')
        except ApiException as e:
            if e.status == 404:
                return False
        return True

    # 0.11.5-beta charts stored the three configs in three separate
    # ConfigMaps; migrate them one by one.
    if True not in status_migrating.values():
        status_migrating['adapter_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-adapter-config',
                                                                     'config.toml',
                                                                     '/app/config/adapter/config.toml')
        status_migrating['core_bot_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core-bot-config',
                                                                      'bot_config.toml',
                                                                      '/app/config/core/bot_config.toml')
        status_migrating['core_model_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core-model-config',
                                                                        'model_config.toml',
                                                                        '/app/config/core/model_config.toml')
        if True in status_migrating.values():
            old_configmap_version = '0.11.5-beta'
    # Charts older than 0.11.5-beta kept the adapter config and the core
    # configs in per-component ConfigMaps instead.
    if True not in status_migrating.values():
        status_migrating['adapter_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-adapter',
                                                                     'config.toml',
                                                                     '/app/config/adapter/config.toml')
        status_migrating['core_bot_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core',
                                                                      'bot_config.toml',
                                                                      '/app/config/core/bot_config.toml')
        status_migrating['core_model_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core',
                                                                        'model_config.toml',
                                                                        '/app/config/core/model_config.toml')
        if True in status_migrating.values():
            old_configmap_version = 'before 0.11.5-beta'
    if old_configmap_version:
        log(func_name, f'Migrating status for version `{old_configmap_version}`:')
        for k, v in status_migrating.items():
            log(func_name, f'\t{k}: {v}')
        if False in status_migrating.values():
            log(func_name, 'There is/are config(s) that not been migrated. Please check the config manually.',
                level='WARNING')
        else:
            log(func_name, 'Successfully migrated old configs. Done.')
    else:
        log(func_name, 'Old config not found. Ignoring migration. Done.')
def write_config_files():
    """Write helm-injected config payloads into the real files on the config volume.

    The chart injects each `CONFIG_*_B64` environment variable only on first
    install or when the user explicitly requests an override, so a missing
    payload simply skips that file.
    """
    func_name = 'write_config_files'
    log(func_name, 'Detecting config files...')
    if config_adapter_b64:
        log(func_name, '\tWriting `config.toml` of adapter...')
        config_str = base64.b64decode(config_adapter_b64).decode("utf-8")
        with open('/app/config/adapter/config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`config.toml` of adapter wrote.')
    # The chart injects the .env payload unconditionally, so in practice this
    # always overwrites. Guarding on the payload (instead of the original
    # `if True:`) avoids a TypeError from b64decode(None) if the env var is
    # ever missing, matching the other three stanzas.
    if config_core_env_b64:
        log(func_name, '\tWriting .env file of core...')
        config_str = base64.b64decode(config_core_env_b64).decode("utf-8")
        with open('/app/config/core/.env', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`.env` of core wrote.')
    if config_core_bot_b64:
        log(func_name, '\tWriting `bot_config.toml` of core...')
        config_str = base64.b64decode(config_core_bot_b64).decode("utf-8")
        with open('/app/config/core/bot_config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`bot_config.toml` of core wrote.')
    if config_core_model_b64:
        log(func_name, '\tWriting `model_config.toml` of core...')
        config_str = base64.b64decode(config_core_model_b64).decode("utf-8")
        with open('/app/config/core/model_config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`model_config.toml` of core wrote.')
    log(func_name, 'Detection done.')
def reconfigure_adapter():
    """Force the adapter's listen/connect endpoints to match the k8s Services.

    Rewrites `napcat_server` (listen address, fixed container port) and
    `maibot_server` (core Service DNS name derived from the release name)
    in the adapter's config.toml on the shared config volume.
    """
    func_name = 'reconfigure_adapter'
    log(func_name, 'Reconfiguring `config.toml` of adapter...')
    adapter_config_path = '/app/config/adapter/config.toml'
    with open(adapter_config_path, 'r', encoding='utf-8') as handle:
        adapter_cfg = toml.load(handle)
    # Listen on all interfaces at the fixed container port; the Service maps it.
    napcat = adapter_cfg.setdefault('napcat_server', {})
    napcat['host'] = '0.0.0.0'
    napcat['port'] = 8095
    # Connect to the core Service, whose DNS name is built from the release name.
    maibot = adapter_cfg.setdefault('maibot_server', {})
    maibot['host'] = f'{release_name}-maibot-core'
    maibot['port'] = 8000
    with open(adapter_config_path, 'w', encoding='utf-8') as handle:
        handle.write(toml.dumps(adapter_cfg))
    log(func_name, 'Reconfiguration done.')
def reconfigure_core():
    """Force core's `webui` and `maim_message` config to match the chart's Services.

    Overwrites the WebUI enable flag / allowed IPs and the maim_message API
    server bind settings in bot_config.toml so the corresponding Services map
    correctly.
    """
    func_name = 'reconfigure_core'
    log(func_name, 'Reconfiguring `bot_config.toml` of core...')
    core_config_path = '/app/config/core/bot_config.toml'
    with open(core_config_path, 'r', encoding='utf-8') as handle:
        core_cfg = toml.load(handle)
    # WebUI on/off is owned by the chart; inside the cluster network a
    # permissive allowlist is acceptable.
    webui = core_cfg.setdefault('webui', {})
    webui['enabled'] = is_webui_enabled
    webui['allowed_ips'] = '0.0.0.0/0'
    # Bind the maim_message API server on all interfaces at the fixed port
    # so the Service can target it.
    maim_message = core_cfg.setdefault('maim_message', {})
    maim_message['enable_api_server'] = is_maim_message_api_server_enabled
    maim_message['api_server_host'] = '0.0.0.0'
    maim_message['api_server_port'] = 8090
    with open(core_config_path, 'w', encoding='utf-8') as handle:
        handle.write(toml.dumps(core_cfg))
    log(func_name, 'Reconfiguration done.')
def _scale_statefulsets(statefulsets: list[str], replicas: int, wait: bool = False, timeout: int = 300):
    """Scale the given StatefulSets to `replicas`.

    When `wait` is True, block (polling every 5s, up to `timeout` seconds)
    until no Pods owned by these StatefulSets remain.
    NOTE(review): the wait condition checks for Pod *deletion*, so it is only
    meaningful for scale-to-zero; with replicas > 0 and wait=True it would
    just time out. Callers in this script only combine wait=True with 0.

    Raises:
        TimeoutError: if owned Pods still exist after `timeout` seconds.
    """
    statefulsets = set(statefulsets)  # dedupe; also fast membership test below
    for name in statefulsets:
        apps_api.patch_namespaced_stateful_set_scale(
            name=name,
            namespace=namespace,
            body={"spec": {"replicas": replicas}}
        )
    if not wait:
        return
    start_time = time.time()
    while True:
        # Collect Pods still owned by one of the target StatefulSets.
        remaining_pods = []
        pods = core_api.list_namespaced_pod(namespace).items
        for pod in pods:
            owners = pod.metadata.owner_references or []
            for owner in owners:
                if owner.kind == "StatefulSet" and owner.name in statefulsets:
                    remaining_pods.append(pod.metadata.name)
        if not remaining_pods:
            return
        elapsed = time.time() - start_time
        if elapsed > timeout:
            raise TimeoutError(
                f"Timeout waiting for Pods to be deleted. "
                f"Remaining Pods: {remaining_pods}"
            )
        time.sleep(5)
def _restart_statefulset(name: str, ignore_error: bool = False):
    """Trigger a rolling restart of a StatefulSet.

    Mimics `kubectl rollout restart`: patching the pod template's
    `restartedAt` annotation makes the controller re-create the Pods.
    When `ignore_error` is True, Kubernetes API errors are swallowed.
    """
    restarted_at = datetime.now(timezone.utc).isoformat()
    patch = {
        "spec": {
            "template": {
                "metadata": {
                    "annotations": {"kubectl.kubernetes.io/restartedAt": restarted_at}
                }
            }
        }
    }
    try:
        apps_api.patch_namespaced_stateful_set(
            name=name,
            namespace=namespace,
            body=patch
        )
    except ApiException as e:
        if not ignore_error:
            raise e
if __name__ == '__main__':
    # Entry point of the post-install/post-upgrade Job: quiesce the workloads,
    # fix up the config files, then bring the workloads back up.
    log('main', 'Start to process data before install/upgrade...')
    log('main', 'Scaling adapter and core to 0...')
    # Scale to zero and wait so no Pod touches the configs while we mutate them.
    _scale_statefulsets([f'{release_name}-maibot-adapter', f'{release_name}-maibot-core'], 0, wait=True)
    migrate_old_config()
    write_config_files()
    reconfigure_adapter()
    reconfigure_core()
    log('main', 'Scaling adapter and core to 1...')
    # No wait here: the Pods can come up in the background.
    _scale_statefulsets([f'{release_name}-maibot-adapter', f'{release_name}-maibot-core'], 1)
    log('main', 'Process done.')

View File

@ -1,2 +0,0 @@
toml~=0.10.2
kubernetes~=34.1.0

View File

@ -1,3 +0,0 @@
MaiBot has been successfully deployed.
MaiBot on GitHub: https://github.com/Mai-with-u/MaiBot

View File

@ -1,33 +0,0 @@
# PVCs for the adapter: one for runtime data and one for its config file.
# The config PVC is also mounted by core so the WebUI can edit the adapter config.
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.adapter.persistence.data.accessModes }}
  accessModes:
    {{ toYaml .Values.adapter.persistence.data.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.adapter.persistence.data.size }}
  {{- if .Values.adapter.persistence.data.storageClass }}
  storageClassName: {{ .Values.adapter.persistence.data.storageClass | default nil }}
  {{- end }}
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-adapter-config
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.adapter.persistence.config.accessModes }}
  accessModes:
    {{ toYaml .Values.adapter.persistence.config.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.adapter.persistence.config.size }}
  {{- if .Values.adapter.persistence.config.storageClass }}
  storageClassName: {{ .Values.adapter.persistence.config.storageClass | default nil }}
  {{- end }}

View File

@ -1,19 +0,0 @@
# Service exposing the adapter's napcat websocket port.
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-adapter
spec:
  ports:
    - name: napcat-ws
      port: {{ .Values.adapter.service.port }}
      protocol: TCP
      # Fixed container port; the preprocessor forces the adapter to listen here.
      targetPort: 8095
      {{- if eq .Values.adapter.service.type "NodePort" }}
      nodePort: {{ .Values.adapter.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-adapter
  type: {{ .Values.adapter.service.type }}

View File

@ -1,58 +0,0 @@
# StatefulSet for the MaiBot adapter.
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-adapter
spec:
  serviceName: {{ .Release.Name }}-maibot-adapter
  replicas: 0 # scaled up to 1 by the post-install preprocessor Job after init
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-adapter
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-adapter
    spec:
      containers:
        - name: adapter
          env:
            - name: TZ
              value: Asia/Shanghai
          image: {{ .Values.adapter.image.repository | default "unclas/maimbot-adapter" }}:{{ .Values.adapter.image.tag | default "main-20251211074617" }}
          imagePullPolicy: {{ .Values.adapter.image.pullPolicy }}
          ports:
            - containerPort: 8095
              name: napcat-ws
              protocol: TCP
          {{- if .Values.adapter.resources }}
          resources:
            {{ toYaml .Values.adapter.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /adapters/data
              name: data
            # config.toml lives on the shared config PVC (editable from core's WebUI)
            - mountPath: /adapters/config.toml
              name: config
              subPath: config.toml
      {{- if .Values.adapter.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.adapter.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.adapter.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.adapter.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.adapter.tolerations }}
      tolerations:
        {{ toYaml .Values.adapter.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter
        - name: config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter-config

View File

@ -1,26 +0,0 @@
# Ingress for core's WebUI; rendered only when both the WebUI and its ingress
# are enabled in values.
{{- if and .Values.core.webui.enabled .Values.core.webui.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-webui
  namespace: {{ .Release.Namespace }}
  {{- if .Values.core.webui.ingress.annotations }}
  annotations:
    {{ toYaml .Values.core.webui.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  ingressClassName: {{ .Values.core.webui.ingress.className }}
  rules:
    - host: {{ .Values.core.webui.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-core
                port:
                  number: {{ .Values.core.webui.service.port }}
            path: {{ .Values.core.webui.ingress.path }}
            pathType: {{ .Values.core.webui.ingress.pathType }}
{{- end }}

View File

@ -1,33 +0,0 @@
# PVCs for core: one for runtime data (DB, logs, plugins) and one for config
# files (.env, bot_config.toml, model_config.toml).
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.core.persistence.data.accessModes }}
  accessModes:
    {{ toYaml .Values.core.persistence.data.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.core.persistence.data.size }}
  {{- if .Values.core.persistence.data.storageClass }}
  storageClassName: {{ .Values.core.persistence.data.storageClass | default nil }}
  {{- end }}
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-core-config
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.core.persistence.config.accessModes }}
  accessModes:
    {{ toYaml .Values.core.persistence.config.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.core.persistence.config.size }}
  {{- if .Values.core.persistence.config.storageClass }}
  storageClassName: {{ .Values.core.persistence.config.storageClass | default nil }}
  {{- end }}

View File

@ -1,34 +0,0 @@
# Service for core: always exposes the adapter websocket; optionally exposes
# the WebUI and the maim_message API server when enabled in values.
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  ports:
    - name: adapter-ws
      port: 8000
      protocol: TCP
      targetPort: 8000
    {{- if .Values.core.webui.enabled }}
    - name: webui
      port: {{ .Values.core.webui.service.port }}
      protocol: TCP
      targetPort: 8001
      {{- if eq .Values.core.webui.service.type "NodePort" }}
      nodePort: {{ .Values.core.webui.service.nodePort | default nil }}
      {{- end }}
    {{- end }}
    {{- if .Values.core.maim_message_api_server.enabled }}
    - name: maim-message
      port: {{ .Values.core.maim_message_api_server.service.port }}
      protocol: TCP
      # Fixed bind port; the preprocessor forces core's api_server_port to 8090.
      targetPort: 8090
      {{- if eq .Values.core.maim_message_api_server.service.type "NodePort" }}
      nodePort: {{ .Values.core.maim_message_api_server.service.nodePort | default nil }}
      {{- end }}
    {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-core
  type: ClusterIP

View File

@ -1,103 +0,0 @@
# StatefulSet for the MaiBot core.
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  serviceName: {{ .Release.Name }}-maibot-core
  replicas: 0 # scaled up to 1 by the post-install preprocessor Job after init
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-core
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-core
    spec:
      containers:
        - name: core
          command: # replace the default entrypoint with the k8s init script
            - sh
          args:
            - /MaiMBot/k8s-init.sh
          env:
            - name: TZ
              value: "Asia/Shanghai"
            - name: EULA_AGREE
              value: "1b662741904d7155d1ce1c00b3530d0d"
            - name: PRIVACY_AGREE
              value: "9943b855e72199d0f5016ea39052f1b6"
          image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.12.0" }}
          imagePullPolicy: {{ .Values.core.image.pullPolicy }}
          ports:
            - containerPort: 8000
              name: adapter-ws
              protocol: TCP
            {{- if .Values.core.webui.enabled }}
            - containerPort: 8001
              name: webui
              protocol: TCP
            {{- end }}
          {{- if .Values.core.resources }}
          resources:
            {{ toYaml .Values.core.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /MaiMBot/data
              name: data
            - mountPath: /MaiMBot/k8s-init.sh
              name: scripts
              readOnly: true
              subPath: k8s-init.sh
            - mountPath: /MaiMBot/.env
              name: config
              subPath: .env
            - mountPath: /MaiMBot/config/model_config.toml
              name: config
              subPath: model_config.toml
            - mountPath: /MaiMBot/config/bot_config.toml
              name: config
              subPath: bot_config.toml
            - mountPath: /MaiMBot/adapters-config/config.toml # lets the WebUI edit the adapter config
              name: adapter-config
              subPath: config.toml
            {{- if .Values.statistics_dashboard.enabled }}
            - mountPath: /MaiMBot/statistics
              name: statistics
            {{- end }}
      serviceAccountName: {{ .Release.Name }}-maibot-sa
      {{- if .Values.core.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.core.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.core.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.core.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.core.tolerations }}
      tolerations:
        {{ toYaml .Values.core.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core
        - configMap:
            items:
              - key: k8s-init.sh
                path: k8s-init.sh
            name: {{ .Release.Name }}-maibot-scripts
          name: scripts
        - name: config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core-config
        - name: adapter-config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter-config
        {{- if .Values.statistics_dashboard.enabled }}
        - name: statistics
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-statistics-dashboard
        {{- end }}

View File

@ -1,26 +0,0 @@
# Ingress for the bundled Napcat WebUI; rendered only when Napcat and its
# ingress are both enabled in values.
{{- if and .Values.napcat.enabled .Values.napcat.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  {{- if .Values.napcat.ingress.annotations }}
  annotations:
    {{ toYaml .Values.napcat.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  ingressClassName: {{ .Values.napcat.ingress.className }}
  rules:
    - host: {{ .Values.napcat.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-napcat
                port:
                  number: {{ .Values.napcat.service.port }}
            path: {{ .Values.napcat.ingress.path }}
            pathType: {{ .Values.napcat.ingress.pathType }}
{{- end }}

View File

@ -1,18 +0,0 @@
# PVC for the bundled Napcat instance (config + QQ data).
{{- if .Values.napcat.enabled }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.napcat.persistence.accessModes }}
  accessModes:
    {{ toYaml .Values.napcat.persistence.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.napcat.persistence.size }}
  {{- if .Values.napcat.persistence.storageClass }}
  storageClassName: {{ .Values.napcat.persistence.storageClass | default nil }}
  {{- end }}
{{- end }}

View File

@ -1,21 +0,0 @@
# Service exposing the Napcat WebUI.
{{- if .Values.napcat.enabled }}
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  ports:
    - name: webui
      port: {{ .Values.napcat.service.port }}
      protocol: TCP
      targetPort: 6099
      {{- if eq .Values.napcat.service.type "NodePort" }}
      nodePort: {{ .Values.napcat.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-napcat
  type: {{ .Values.napcat.service.type }}
{{- end }}

View File

@ -1,72 +0,0 @@
# StatefulSet for the bundled Napcat instance (optional; controlled by
# `napcat.enabled` so an external Napcat can be reused instead).
{{- if .Values.napcat.enabled }}
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  serviceName: {{ .Release.Name }}-maibot-napcat
  replicas: 1
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-napcat
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-napcat
    spec:
      containers:
        - name: napcat
          env:
            - name: NAPCAT_GID
              value: "{{ .Values.napcat.permission.gid }}"
            - name: NAPCAT_UID
              value: "{{ .Values.napcat.permission.uid }}"
            - name: TZ
              value: Asia/Shanghai
          image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.80" }}
          imagePullPolicy: {{ .Values.napcat.image.pullPolicy }}
          # Probe the Napcat WebUI so a hung instance gets restarted.
          livenessProbe:
            failureThreshold: 3
            httpGet:
              path: /
              port: 6099
              scheme: HTTP
            initialDelaySeconds: 60
            periodSeconds: 60
            successThreshold: 1
            timeoutSeconds: 10
          ports:
            - containerPort: 6099
              name: webui
              protocol: TCP
          {{- if .Values.napcat.resources }}
          resources:
            {{ toYaml .Values.napcat.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /app/napcat/config
              name: napcat
              subPath: config
            - mountPath: /app/.config/QQ
              name: napcat
              subPath: data
      {{- if .Values.napcat.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.napcat.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.napcat.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.napcat.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.napcat.tolerations }}
      tolerations:
        {{ toYaml .Values.napcat.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: napcat
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-napcat
{{- end }}

View File

@ -1,8 +0,0 @@
# Fail the install/upgrade early unless the user accepted the EULA and the
# privacy policy in values.
{{- if not .Values.EULA_AGREE }}
{{ fail "You must accept the EULA by setting 'EULA_AGREE: true'. EULA: https://github.com/Mai-with-u/MaiBot/blob/main/EULA.md" }}
{{- end }}
{{- if not .Values.PRIVACY_AGREE }}
{{ fail "You must accept the Privacy Policy by setting 'PRIVACY_AGREE: true'. Privacy Policy: https://github.com/Mai-with-u/MaiBot/blob/main/PRIVACY.md" }}
{{- end }}

View File

@ -1,9 +0,0 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ .Release.Name }}-maibot-scripts
namespace: {{ .Release.Namespace }}
data:
# core
k8s-init.sh: |
{{ .Files.Get "files/k8s-init.sh" | nindent 4 }}

View File

@ -1,62 +0,0 @@
# 预处理Job在每次安装/升级完成后运行一次post-install / post-upgrade hook
apiVersion: batch/v1
kind: Job
metadata:
name: {{ .Release.Name }}-maibot-preprocessor
namespace: {{ .Release.Namespace }}
annotations:
"helm.sh/hook": post-install,post-upgrade
"helm.sh/hook-delete-policy": before-hook-creation,hook-succeeded
spec:
backoffLimit: 2
template:
spec:
serviceAccountName: {{ .Release.Name }}-maibot-sa
restartPolicy: Never
containers:
- name: preprocessor
image: {{ .Values.pre_processor.image.repository | default "reg.mikumikumi.xyz/maibot/preprocessor" }}:{{ .Values.pre_processor.image.tag | default "0.12.0" }}
imagePullPolicy: {{ .Values.pre_processor.image.pullPolicy }}
env:
- name: RELEASE_NAME
value: {{ .Release.Name }}
- name: IS_WEBUI_ENABLED
value: {{ .Values.core.webui.enabled | quote }}
- name: IS_MMSG_ENABLED
value: {{ .Values.core.maim_message_api_server.enabled | quote }}
{{- if or .Values.config.override_adapter_config .Release.IsInstall }}
- name: CONFIG_ADAPTER_B64
value: {{ .Values.config.adapter_config | b64enc | quote }}
{{- end }}
- name: CONFIG_CORE_ENV_B64
value: {{ tpl (.Files.Get "files/.env") . | b64enc | quote }}
{{- if or .Values.config.override_core_bot_config .Release.IsInstall }}
- name: CONFIG_CORE_BOT_B64
value: {{ .Values.config.core_bot_config | b64enc | quote }}
{{- end }}
{{- if or .Values.config.override_core_model_config .Release.IsInstall }}
- name: CONFIG_CORE_MODEL_B64
value: {{ .Values.config.core_model_config | b64enc | quote }}
{{- end }}
volumeMounts:
- mountPath: /app/config/adapter
name: adapter-config
- mountPath: /app/config/core
name: core-config
imagePullSecrets:
{{ toYaml .Values.pre_processor.image.pullSecrets | nindent 8 }}
{{- if .Values.pre_processor.nodeSelector }}
nodeSelector:
{{ toYaml .Values.pre_processor.nodeSelector | nindent 8 }}
{{- end }}
{{- if .Values.pre_processor.tolerations }}
tolerations:
{{ toYaml .Values.pre_processor.tolerations | nindent 8 }}
{{- end }}
volumes:
- name: adapter-config
persistentVolumeClaim:
claimName: {{ .Release.Name }}-maibot-adapter-config
- name: core-config
persistentVolumeClaim:
claimName: {{ .Release.Name }}-maibot-core-config

View File

@ -1,36 +0,0 @@
# 初始化及反向修改ConfigMap所需要的rbac授权
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ .Release.Name }}-maibot-sa
namespace: {{ .Release.Namespace }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: {{ .Release.Name }}-maibot-role
namespace: {{ .Release.Namespace }}
rules:
- apiGroups: [""]
resources: ["configmaps", "pods"]
verbs: ["get", "list", "delete"]
- apiGroups: ["apps"]
resources: ["statefulsets"]
verbs: ["get", "list", "update", "patch"]
- apiGroups: ["apps"]
resources: ["statefulsets/scale"]
verbs: ["get", "patch", "update"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: {{ .Release.Name }}-maibot-rolebinding
namespace: {{ .Release.Namespace }}
subjects:
- kind: ServiceAccount
name: {{ .Release.Name }}-maibot-sa
namespace: {{ .Release.Namespace }}
roleRef:
kind: Role
name: {{ .Release.Name }}-maibot-role
apiGroup: rbac.authorization.k8s.io

View File

@ -1,26 +0,0 @@
{{- if and .Values.sqlite_web.enabled .Values.sqlite_web.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ .Release.Name }}-maibot-sqlite-web
namespace: {{ .Release.Namespace }}
{{- if .Values.sqlite_web.ingress.annotations }}
annotations:
{{ toYaml .Values.sqlite_web.ingress.annotations | nindent 4 }}
{{- end }}
labels:
app: {{ .Release.Name }}-maibot-sqlite-web
spec:
ingressClassName: {{ .Values.sqlite_web.ingress.className }}
rules:
- host: {{ .Values.sqlite_web.ingress.host }}
http:
paths:
- backend:
service:
name: {{ .Release.Name }}-maibot-sqlite-web
port:
number: {{ .Values.sqlite_web.service.port }}
path: {{ .Values.sqlite_web.ingress.path }}
pathType: {{ .Values.sqlite_web.ingress.pathType }}
{{- end }}

View File

@ -1,21 +0,0 @@
{{- if .Values.sqlite_web.enabled }}
apiVersion: v1
kind: Service
metadata:
name: {{ .Release.Name }}-maibot-sqlite-web
namespace: {{ .Release.Namespace }}
labels:
app: {{ .Release.Name }}-maibot-sqlite-web
spec:
ports:
- name: webui
port: {{ .Values.sqlite_web.service.port }}
protocol: TCP
targetPort: 8080
{{- if eq .Values.sqlite_web.service.type "NodePort" }}
nodePort: {{ .Values.sqlite_web.service.nodePort | default nil }}
{{- end }}
selector:
app: {{ .Release.Name }}-maibot-sqlite-web
type: {{ .Values.sqlite_web.service.type }}
{{- end }}

View File

@ -1,64 +0,0 @@
{{- if .Values.sqlite_web.enabled }}
apiVersion: apps/v1
kind: StatefulSet
metadata:
name: {{ .Release.Name }}-maibot-sqlite-web
namespace: {{ .Release.Namespace }}
labels:
app: {{ .Release.Name }}-maibot-sqlite-web
spec:
serviceName: {{ .Release.Name }}-maibot-sqlite-web
replicas: 1
selector:
matchLabels:
app: {{ .Release.Name }}-maibot-sqlite-web
template:
metadata:
labels:
app: {{ .Release.Name }}-maibot-sqlite-web
spec:
containers:
- name: sqlite-web
env:
- name: SQLITE_DATABASE
value: /data/MaiMBot/MaiBot.db
image: {{ .Values.sqlite_web.image.repository | default "coleifer/sqlite-web" }}:{{ .Values.sqlite_web.image.tag | default "latest" }}
imagePullPolicy: {{ .Values.sqlite_web.image.pullPolicy }}
livenessProbe:
failureThreshold: 3
httpGet:
path: /
port: 8080
scheme: HTTP
initialDelaySeconds: 60
periodSeconds: 60
successThreshold: 1
timeoutSeconds: 10
ports:
- containerPort: 8080
name: webui
protocol: TCP
{{- if .Values.sqlite_web.resources }}
resources:
{{ toYaml .Values.sqlite_web.resources | nindent 12 }}
{{- end }}
volumeMounts:
- mountPath: /data/MaiMBot
name: data
{{- if .Values.sqlite_web.image.pullSecrets }}
imagePullSecrets:
{{ toYaml .Values.sqlite_web.image.pullSecrets | nindent 8 }}
{{- end }}
{{- if .Values.sqlite_web.nodeSelector }}
nodeSelector:
{{ toYaml .Values.sqlite_web.nodeSelector | nindent 8 }}
{{- end }}
{{- if .Values.sqlite_web.tolerations }}
tolerations:
{{ toYaml .Values.sqlite_web.tolerations | nindent 8 }}
{{- end }}
volumes:
- name: data
persistentVolumeClaim:
claimName: {{ .Release.Name }}-maibot-core
{{- end }}

View File

@ -1,61 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ .Release.Name }}-maibot-statistics-dashboard
namespace: {{ .Release.Namespace }}
labels:
app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
replicas: {{ .Values.statistics_dashboard.replicaCount }}
selector:
matchLabels:
app: {{ .Release.Name }}-maibot-statistics-dashboard
template:
metadata:
labels:
app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
containers:
- name: nginx
image: {{ .Values.statistics_dashboard.image.repository | default "nginx" }}:{{ .Values.statistics_dashboard.image.tag | default "latest" }}
imagePullPolicy: {{ .Values.statistics_dashboard.image.pullPolicy }}
livenessProbe:
failureThreshold: 3
httpGet:
path: /
port: 80
scheme: HTTP
initialDelaySeconds: 60
periodSeconds: 10
successThreshold: 1
timeoutSeconds: 1
ports:
- containerPort: 80
name: dashboard
protocol: TCP
{{- if .Values.statistics_dashboard.resources }}
resources:
{{ toYaml .Values.statistics_dashboard.resources | nindent 12 }}
{{- end }}
volumeMounts:
- mountPath: /usr/share/nginx/html
name: statistics
readOnly: true
{{- if .Values.statistics_dashboard.image.pullSecrets }}
imagePullSecrets:
{{ toYaml .Values.statistics_dashboard.image.pullSecrets | nindent 8 }}
{{- end }}
{{- if .Values.statistics_dashboard.nodeSelector }}
nodeSelector:
{{ toYaml .Values.statistics_dashboard.nodeSelector | nindent 8 }}
{{- end }}
{{- if .Values.statistics_dashboard.tolerations }}
tolerations:
{{ toYaml .Values.statistics_dashboard.tolerations | nindent 8 }}
{{- end }}
volumes:
- name: statistics
persistentVolumeClaim:
claimName: {{ .Release.Name }}-maibot-statistics-dashboard
{{- end }}

View File

@ -1,26 +0,0 @@
{{- if and .Values.statistics_dashboard.enabled .Values.statistics_dashboard.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ .Release.Name }}-maibot-statistics-dashboard
namespace: {{ .Release.Namespace }}
{{- if .Values.statistics_dashboard.ingress.annotations }}
annotations:
{{ toYaml .Values.statistics_dashboard.ingress.annotations | nindent 4 }}
{{- end }}
labels:
app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
ingressClassName: {{ .Values.statistics_dashboard.ingress.className }}
rules:
- host: {{ .Values.statistics_dashboard.ingress.host }}
http:
paths:
- backend:
service:
name: {{ .Release.Name }}-maibot-statistics-dashboard
port:
number: {{ .Values.statistics_dashboard.service.port }}
path: {{ .Values.statistics_dashboard.ingress.path }}
pathType: {{ .Values.statistics_dashboard.ingress.pathType }}
{{- end }}

View File

@ -1,18 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: {{ .Release.Name }}-maibot-statistics-dashboard
namespace: {{ .Release.Namespace }}
spec:
{{- if .Values.statistics_dashboard.persistence.accessModes }}
accessModes:
{{ toYaml .Values.statistics_dashboard.persistence.accessModes | nindent 4 }}
{{- end }}
resources:
requests:
storage: {{ .Values.statistics_dashboard.persistence.size }}
{{- if .Values.statistics_dashboard.persistence.storageClass }}
storageClassName: {{ .Values.statistics_dashboard.persistence.storageClass | default nil }}
{{- end }}
{{- end }}

View File

@ -1,21 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: v1
kind: Service
metadata:
name: {{ .Release.Name }}-maibot-statistics-dashboard
namespace: {{ .Release.Namespace }}
labels:
app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
ports:
- name: dashboard
port: {{ .Values.statistics_dashboard.service.port }}
protocol: TCP
targetPort: 80
{{- if eq .Values.statistics_dashboard.service.type "NodePort" }}
nodePort: {{ .Values.statistics_dashboard.service.nodePort | default nil }}
{{- end }}
selector:
app: {{ .Release.Name }}-maibot-statistics-dashboard
type: {{ .Values.statistics_dashboard.service.type }}
{{- end }}

View File

@ -1,772 +0,0 @@
# 只有同意了EULA和PRIVACY协议才可以部署麦麦
# 配置以下的选项为true表示你同意了EULA和PRIVACY条款
# https://github.com/MaiM-with-u/MaiBot/blob/main/EULA.md
# https://github.com/MaiM-with-u/MaiBot/blob/main/PRIVACY.md
EULA_AGREE: false
PRIVACY_AGREE: false
# 预处理Job的配置
pre_processor:
image:
repository: # 默认 reg.mikumikumi.xyz/maibot/preprocessor
tag: # 默认 0.12.0
pullPolicy: IfNotPresent
pullSecrets: [ ]
nodeSelector: { }
tolerations: [ ]
# 麦麦Adapter的部署配置
adapter:
image:
repository: # 默认 unclas/maimbot-adapter
tag: # 默认 main-20251211074617
pullPolicy: IfNotPresent
pullSecrets: [ ]
resources: { }
nodeSelector: { }
tolerations: [ ]
# 配置adapter的napcat websocket service
# adapter会启动一个websocket服务端用于与napcat通信
# 这里的选项可以帮助你自定义服务端口
# 默认不使用NodePort。如果通过NodePort将服务端口映射到公网可能会被恶意客户端连接请自行使用中间件鉴权
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的websocket端口映射到物理节点的端口
port: 8095 # websocket监听端口ClusterIP的端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
persistence:
config: # 配置文件的存储卷
storageClass:
accessModes:
- ReadWriteOnce
size: 10Mi
data: # 数据的存储卷
storageClass:
accessModes:
- ReadWriteOnce
size: 1Gi
# 麦麦本体的部署配置
core:
image:
repository: # 默认 sengokucola/maibot
tag: # 默认 0.12.0
pullPolicy: IfNotPresent
pullSecrets: [ ]
resources: { }
nodeSelector: { }
tolerations: [ ]
persistence:
config: # 配置文件的存储卷
storageClass:
accessModes:
- ReadWriteOnce
size: 10Mi
data: # 数据的存储卷
storageClass:
accessModes:
- ReadWriteOnce
size: 10Gi
webui: # WebUI相关配置
enabled: true # 默认启用
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 8001 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
ingress:
enabled: false
className: nginx
annotations: { }
host: maim.example.com # 访问麦麦WebUI的域名
path: /
pathType: Prefix
maim_message_api_server:
enabled: false
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 8090 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
# 麦麦的运行统计看板配置
# 麦麦每隔一段时间会自动输出html格式的运行统计报告此统计报告可以作为静态网页访问
# 此功能默认禁用。如果你认为报告可以被公开访问报告包含联系人/群组名称、模型token花费信息等),则可以启用此功能
# 如果启用此功能,你也可以考虑使用中间件进行鉴权,保护隐私信息
statistics_dashboard:
enabled: false # 是否启用运行统计看板
replicaCount: 1
image:
repository: # 默认 nginx
tag: # 默认 latest
pullPolicy: IfNotPresent
pullSecrets: [ ]
resources: { }
nodeSelector: { }
tolerations: [ ]
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 80 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
ingress:
enabled: false
className: nginx
annotations: { }
host: maim-statistics.example.com # 访问运行统计看板的域名
path: /
pathType: Prefix
persistence:
storageClass:
# 如果你希望运行统计看板服务与麦麦本体运行在不同的节点多活部署那么需要ReadWriteMany访问模式
# 注意ReadWriteMany特性需要存储类底层支持
accessModes:
- ReadWriteOnce
size: 100Mi
# napcat的部署配置
# napcat部署完毕后务必修改默认密码
napcat:
# 考虑到复用外部napcat实例的情况napcat部署已被解耦
# 如果你有外部部署的napcat则可以修改下面的enabled为false本次不会重复部署napcat
# 如果没有外部部署的napcat默认会捆绑部署napcat不需要修改此项
enabled: true
image:
repository: # 默认 mlikiowa/napcat-docker
tag: # 默认 v4.9.91
pullPolicy: IfNotPresent
pullSecrets: [ ]
resources: { }
nodeSelector: { }
tolerations: [ ]
# napcat进程的权限默认不是特权用户
permission:
uid: 1000
gid: 1000
# 配置napcat web面板的service
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 6099 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
# 配置napcat web面板的ingress
ingress:
enabled: false # 是否启用
className: nginx
annotations: { }
host: napcat.example.com # 暴露napcat web面板使用的域名
path: /
pathType: Prefix
persistence:
storageClass:
accessModes:
- ReadWriteOnce
size: 5Gi
# sqlite-web的部署配置
sqlite_web:
# 通过sqlite-web可以在网页上操作麦麦的数据库方便调试。不部署对麦麦的运行无影响
# 默认不会捆绑部署sqlite-web如果你需要部署请修改下面的enabled为true
# sqlite-web服务无鉴权暴露在公网上十分危险推荐使用集群ClusterIP内网访问
# !!!如果一定要暴露在公网,请自行使用中间件鉴权!!!
enabled: false
image:
repository: # 默认 coleifer/sqlite-web
tag: # 默认 latest
pullPolicy: IfNotPresent
pullSecrets: [ ]
resources: { }
nodeSelector: { }
tolerations: [ ]
# 配置sqlite-web面板的service
# 默认不使用NodePort。如果使用NodePort暴露到公网请自行使用中间件鉴权
service:
type: ClusterIP # ClusterIP / NodePort 指定NodePort可以将内网的服务端口映射到物理节点的端口
port: 8080 # 服务端口
nodePort: # 仅在设置NodePort类型时有效不指定则会随机分配端口
# 配置sqlite-web面板的ingress
# 默认不使用ingress。如果使用ingress暴露到公网请自行使用中间件鉴权
ingress:
enabled: false # 是否启用
className: nginx
annotations: { }
    host: maim-sqlite.example.com # 暴露sqlite-web面板使用的域名
path: /
pathType: Prefix
# 设置麦麦各部分组件的初始运行配置
# 考虑到配置文件的操作复杂性增加k8s的适配复杂度也同步增加且WebUI可以直接修改配置文件
# 自0.11.6-beta版本开始各组件的配置不再存储于k8s的configmap中而是直接存储于存储卷的实际文件中
# 从旧版本升级的用户旧的configmap的配置会自动迁移到新的存储卷的配置文件中
# 此处的配置只在初次部署时或者指定覆盖时注入到MaiBot中
config:
# 指定是否用下面的配置覆盖MaiBot现有的配置文件
override_adapter_config: false
override_core_bot_config: false
override_core_model_config: false
# adapter的config.toml
adapter_config: |
[inner]
version = "0.1.2" # 版本号
# 请勿修改版本号,除非你知道自己在做什么
[nickname] # 现在没用
nickname = ""
[napcat_server] # Napcat连接的ws服务设置
token = "" # Napcat设定的访问令牌若无则留空
heartbeat_interval = 30 # 与Napcat设置的心跳相同按秒计
[chat] # 黑白名单功能
group_list_type = "whitelist" # 群组名单类型可选为whitelist, blacklist
group_list = [] # 群组名单
# 当group_list_type为whitelist时只有群组名单中的群组可以聊天
# 当group_list_type为blacklist时群组名单中的任何群组无法聊天
private_list_type = "whitelist" # 私聊名单类型可选为whitelist, blacklist
private_list = [] # 私聊名单
# 当private_list_type为whitelist时只有私聊名单中的用户可以聊天
# 当private_list_type为blacklist时私聊名单中的任何用户无法聊天
ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
ban_qq_bot = false # 是否屏蔽QQ官方机器人
enable_poke = true # 是否启用戳一戳功能
[voice] # 发送语音设置
use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter
[debug]
level = "INFO" # 日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL
# core的model_config.toml
core_model_config: |
[inner]
version = "1.9.1"
# 配置文件版本号迭代规则同bot_config.toml
[[api_providers]] # API服务提供商可以配置多个
name = "DeepSeek" # API服务商名称可随意命名在models的api-provider中需使用这个命名
base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
api_key = "your-api-key-here" # API密钥请替换为实际的API密钥
    client_type = "openai" # 请求客户端(可选,默认值为"openai"使用Gemini等Google系模型时请配置为"gemini"
max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数
timeout = 120 # API请求超时时间单位
retry_interval = 10 # 重试间隔时间(单位:秒)
[[api_providers]] # 阿里 百炼 API服务商配置
name = "BaiLian"
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
api_key = "your-bailian-key"
client_type = "openai"
max_retry = 2
timeout = 120
retry_interval = 5
    [[api_providers]] # 特殊Google的Gemini使用特殊API与OpenAI格式不兼容需要配置client为"gemini"
name = "Google"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key-1"
client_type = "gemini"
max_retry = 2
timeout = 120
retry_interval = 10
[[api_providers]] # SiliconFlow的API服务商配置
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
max_retry = 3
timeout = 120
retry_interval = 5
[[models]] # 模型(可以配置多个)
model_identifier = "deepseek-chat" # 模型标识符API服务商提供的模型标识符
name = "deepseek-v3" # 模型名称(可随意命名,在后面中需使用这个命名)
api_provider = "DeepSeek" # API服务商名称对应在api_providers中配置的服务商名称
price_in = 2.0 # 输入价格用于API调用统计单位元/ M token可选若无该字段默认值为0
price_out = 8.0 # 输出价格用于API调用统计单位元/ M token可选若无该字段默认值为0
# force_stream_mode = true # 强制流式输出模式若模型不支持非流式输出请取消该注释启用强制流式输出若无该字段默认值为false
[[models]]
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
name = "siliconflow-deepseek-v3.2"
api_provider = "SiliconFlow"
price_in = 2.0
price_out = 3.0
# temperature = 0.5 # 可选:为该模型单独指定温度,会覆盖任务配置中的温度
# max_tokens = 4096 # 可选为该模型单独指定最大token数会覆盖任务配置中的max_tokens
[models.extra_params] # 可选的额外参数配置
enable_thinking = false # 不启用思考
[[models]]
model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
name = "siliconflow-deepseek-v3.2-think"
api_provider = "SiliconFlow"
price_in = 2.0
price_out = 3.0
# temperature = 0.7 # 可选:为该模型单独指定温度,会覆盖任务配置中的温度
# max_tokens = 4096 # 可选为该模型单独指定最大token数会覆盖任务配置中的max_tokens
[models.extra_params] # 可选的额外参数配置
enable_thinking = true # 启用思考
[[models]]
model_identifier = "Qwen/Qwen3-Next-80B-A3B-Instruct"
name = "qwen3-next-80b"
api_provider = "SiliconFlow"
price_in = 1.0
price_out = 4.0
[[models]]
model_identifier = "zai-org/GLM-4.6"
name = "siliconflow-glm-4.6"
api_provider = "SiliconFlow"
price_in = 3.5
price_out = 14.0
[models.extra_params] # 可选的额外参数配置
enable_thinking = false # 不启用思考
[[models]]
model_identifier = "zai-org/GLM-4.6"
name = "siliconflow-glm-4.6-think"
api_provider = "SiliconFlow"
price_in = 3.5
price_out = 14.0
[models.extra_params] # 可选的额外参数配置
enable_thinking = true # 启用思考
[[models]]
model_identifier = "deepseek-ai/DeepSeek-R1"
name = "siliconflow-deepseek-r1"
api_provider = "SiliconFlow"
price_in = 4.0
price_out = 16.0
[[models]]
model_identifier = "Qwen/Qwen3-30B-A3B-Instruct-2507"
name = "qwen3-30b"
api_provider = "SiliconFlow"
price_in = 0.7
price_out = 2.8
[[models]]
model_identifier = "Qwen/Qwen3-VL-30B-A3B-Instruct"
name = "qwen3-vl-30"
api_provider = "SiliconFlow"
price_in = 4.13
price_out = 4.13
[[models]]
model_identifier = "FunAudioLLM/SenseVoiceSmall"
name = "sensevoice-small"
api_provider = "SiliconFlow"
price_in = 0
price_out = 0
[[models]]
model_identifier = "BAAI/bge-m3"
name = "bge-m3"
api_provider = "SiliconFlow"
price_in = 0
price_out = 0
[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,麦麦的情绪变化等,是麦麦必须的模型
model_list = ["siliconflow-deepseek-v3.2"] # 使用的模型列表,每个子项对应上面的模型名称(name)
temperature = 0.2 # 模型温度新V3建议0.1-0.3
max_tokens = 4096 # 最大输出token数
slow_threshold = 15.0 # 慢请求阈值(秒),模型等待回复时间超过此值会输出警告日志
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
model_list = ["qwen3-30b","qwen3-next-80b"]
temperature = 0.7
max_tokens = 2048
slow_threshold = 10.0
[model_task_config.tool_use] #工具调用模型,需要使用支持工具调用的模型
model_list = ["qwen3-30b","qwen3-next-80b"]
temperature = 0.7
max_tokens = 800
slow_threshold = 10.0
[model_task_config.replyer] # 首要回复模型,还用于表达器和表达方式学习
model_list = ["siliconflow-deepseek-v3.2","siliconflow-deepseek-v3.2-think","siliconflow-glm-4.6","siliconflow-glm-4.6-think"]
temperature = 0.3 # 模型温度新V3建议0.1-0.3
max_tokens = 2048
slow_threshold = 25.0
[model_task_config.planner] #决策:负责决定麦麦该什么时候回复的模型
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.3
max_tokens = 800
slow_threshold = 12.0
[model_task_config.vlm] # 图像识别模型
model_list = ["qwen3-vl-30"]
max_tokens = 256
slow_threshold = 15.0
[model_task_config.voice] # 语音识别模型
model_list = ["sensevoice-small"]
slow_threshold = 12.0
# 嵌入模型
[model_task_config.embedding]
model_list = ["bge-m3"]
slow_threshold = 5.0
# ------------LPMM知识库模型------------
[model_task_config.lpmm_entity_extract] # 实体提取模型
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.2
max_tokens = 800
slow_threshold = 20.0
[model_task_config.lpmm_rdf_build] # RDF构建模型
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.2
max_tokens = 800
slow_threshold = 20.0
[model_task_config.lpmm_qa] # 问答模型
model_list = ["siliconflow-deepseek-v3.2"]
temperature = 0.7
max_tokens = 800
slow_threshold = 20.0
# core的bot_config.toml
core_bot_config: |
[inner]
version = "7.2.5"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
# 如果你想要修改配置文件请递增version的值
# 如果新增项目请阅读src/config/official_configs.py中的说明
#
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号MMC版本更新
# 次版本号:配置文件内容大更新
# 修订号:配置文件内容小更新
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
[bot]
platform = "qq"
qq_account = "1145141919810" # 麦麦的QQ账号
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
[personality]
# 建议120字以内描述人格特质 和 身份特征
personality = "是一个大二在读女大学生,现在正在上网和群友聊天,有时有点攻击性,有时比较温柔"
# アイデンティティがない 生まれないらららら
# 描述麦麦说话的表达风格表达习惯如要修改可以酌情新增内容建议1-2行
reply_style = "请不要刻意突出自身学科背景。可以参考贴吧,知乎和微博的回复风格。"
    # 多种回复风格列表,可选配置:当列表非空且 multiple_probability>0 时,会按概率随机从中选择一个替换 reply_style
multiple_reply_style = [
# "你的风格平淡但不失讽刺,很简短,很白话。可以参考贴吧,微博的回复风格。",
# "用1-2个字进行回复",
# "用1-2个符号进行回复",
# "言辭凝練古雅,穿插《論語》經句卻不晦澀,以文言短句為基,輔以淺白語意,持長者溫和風範,全用繁體字表達,具先秦儒者談吐韻致。",
# "带点翻译腔,但不要太长",
]
# 替换概率:每次构建回复时,以该概率从 multiple_reply_style 中随机选择一个替换 reply_style0.0-1.0
multiple_probability = 0.3
# 麦麦的说话规则,行为风格:
plan_style = """
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
2.如果相同的内容已经被执行,请不要重复执行
3.你对技术相关话题,游戏和动漫相关话题感兴趣,也对日常话题感兴趣,不喜欢太过沉重严肃的话题
4.请控制你的发言频率,不要太过频繁的发言
5.如果有人对你感到厌烦,请减少回复
6.如果有人在追问你,或者话题没有说完,请你继续回复"""
# 麦麦识图规则,不建议修改
visual_style = "请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本"
# 麦麦私聊的说话规则,行为风格:
private_plan_style = """
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
2.如果相同的内容已经被执行,请不要重复执行
3.某句话如果已经被回复过,不要重复回复"""
# 状态,可以理解为人格多样性,会随机替换人格
states = [
"是一个女大学生,喜欢上网聊天,会刷小红书。" ,
"是一个大二心理学生,会刷贴吧和中国知网。" ,
"是一个赛博网友,最近很想吐槽人。"
]
# 替换概率每次构建人格时替换personality的概率0.0-1.0
state_probability = 0.3
[expression]
# 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置
["", "enable", "enable", "enable"], # 全局配置使用表达启用学习启用jargon学习
["qq:1919810:group", "enable", "enable", "enable"], # 特定群聊配置使用表达启用学习启用jargon学习
["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置使用表达禁用学习禁用jargon学习
# 格式说明:
# 第一位: chat_stream_id空字符串表示全局配置
# 第二位: 是否使用学到的表达 ("enable"/"disable")
# 第三位: 是否学习表达 ("enable"/"disable")
# 第四位: 是否启用jargon学习 ("enable"/"disable")
]
expression_groups = [
# ["*"], # 全局共享组所有chat_id共享学习到的表达方式取消注释以启用全局共享
["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 特定互通组相同组的chat_id会共享学习到的表达方式
# 格式说明:
# ["*"] - 启用全局共享,所有聊天流共享表达方式
# ["qq:123456:private","qq:654321:group"] - 特定互通组组内chat_id共享表达方式
# 注意如果为群聊则需要设置为group如果设置为私聊则需要设置为private
]
reflect = false # 是否启用表达反思Bot主动向管理员询问表达方式是否合适
reflect_operator_id = "" # 表达反思操作员ID格式platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group")
allow_reflect = [] # 允许进行表达反思的聊天流ID列表格式["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true
all_global_jargon = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
enable_jargon_explanation = true # 是否在回复前尝试对上下文中的黑话进行解释关闭可减少一次LLM调用仅影响回复前的黑话匹配与解释不影响黑话学习
jargon_mode = "planner" # 黑话解释来源模式,可选: "context"(使用上下文自动匹配黑话) 或 "planner"仅使用Planner在reply动作中给出的unknown_words列表
[chat] # 麦麦的聊天设置
talk_value = 1 # 聊天频率越小越沉默范围0-1
mentioned_bot_reply = true # 是否启用提及必回复
max_context_size = 30 # 上下文长度
planner_smooth = 3 # 规划器平滑增大数值会减小planner负荷略微降低反应速度推荐1-50为关闭必须大于等于0
think_mode = "dynamic" # 思考模式可选classic默认浅度思考和回复、deep会进行比较长的深度回复、dynamic动态选择两种模式
enable_talk_value_rules = true # 是否启用动态发言频率规则
# 动态发言频率规则:按时段/按chat_id调整 talk_value优先匹配具体chat再匹配全局
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
# 说明:
# - target 为空字符串表示全局type 为 group/private例如"qq:1919810:group" 或 "qq:114514:private"
# - 支持跨夜区间,例如 "23:00-02:00";数值范围建议 0-1如果 value 设置为0会自动转换为0.0001以避免除以零错误。
talk_value_rules = [
{ target = "", time = "00:00-08:59", value = 0.8 },
{ target = "", time = "09:00-22:59", value = 1.0 },
{ target = "qq:1919810:group", time = "20:00-23:59", value = 0.6 },
{ target = "qq:114514:private", time = "00:00-23:59", value = 0.3 },
]
[memory]
max_agent_iterations = 3 # 记忆思考深度最低为1
agent_timeout_seconds = 45.0 # 最长回忆时间(秒)
enable_jargon_detection = true # 记忆检索过程中是否启用黑话识别
global_memory = false # 是否允许记忆检索进行全局查询
[dream]
    interval_minutes = 60 # 做梦时间间隔此处设置为60分钟
max_iterations = 20 # 做梦最大轮次默认20轮
    first_delay_seconds = 1800 # 程序启动后首次做梦前的延迟时间此处设置为1800秒30分钟
# 做梦结果推送目标,格式为 "platform:user_id"
# 例如: "qq:123456" 表示在做梦结束后将梦境文本额外发送给该QQ私聊用户。
# 为空字符串时不推送。
dream_send = ""
# 做梦时间段配置,格式:["HH:MM-HH:MM", ...]
# 如果列表为空,则表示全天允许做梦。
# 如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。
# 时间段外,调度器仍会按间隔检查,但不会进入做梦流程。
# 支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。
# 示例:
dream_time_ranges = [
# "09:00-22:00", # 白天允许做梦
"23:00-10:00", # 跨夜时间段23:00到次日10:00
]
# dream_time_ranges = []
[tool]
enable_tool = true # 是否启用工具
[emoji]
emoji_chance = 0.4 # 麦麦激活表情包动作的概率
max_reg_num = 100 # 表情包最大注册数量
do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
steal_emoji = true # 是否偷取表情包,让麦麦可以将一些表情包据为己有
content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存
[voice]
enable_asr = false # 是否启用语音识别,启用后麦麦可以识别语音消息,启用该功能需要配置语音识别模型[model_task_config.voice]
[message_receive]
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
ban_words = [
# "403","张三"
]
ban_msgs_regex = [
# 需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤,若不了解正则表达式请勿修改
# "https?://[^\\s]+", # 匹配https链接
# "\\d{4}-\\d{2}-\\d{2}", # 匹配日期
]
[lpmm_knowledge] # lpmm知识库配置
enable = false # 是否启用lpmm知识库
lpmm_mode = "agent"
# 可选择classic传统模式/agent 模式,结合新的记忆一同使用
rag_synonym_search_top_k = 10 # 同义检索TopK
rag_synonym_threshold = 0.8 # 同义阈值,相似度高于该值的关系会被当作同义词
info_extraction_workers = 3 # 实体抽取同时执行线程数非Pro模型不要设置超过5
qa_relation_search_top_k = 10 # 关系检索TopK
qa_relation_threshold = 0.5 # 关系阈值,相似度高于该值的关系会被认为是相关关系
qa_paragraph_search_top_k = 1000 # 段落检索TopK不能过小可能影响搜索结果
qa_paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重当搜索仅使用DPR时此参数不起作用
qa_ent_filter_top_k = 10 # 实体过滤TopK
qa_ppr_damping = 0.8 # PPR阻尼系数
qa_res_top_k = 3 # 最终提供段落TopK
embedding_dimension = 1024 # 嵌入向量维度,输出维度一致
# 性能与降级参数(低配机器可下调)
# 低配机器参考:单/双核或内存≤4GB如轻量云主机/云函数/开发板建议先关闭PPR并降低并发
max_embedding_workers = 3 # 嵌入/抽取并发线程数
embedding_chunk_size = 4 # 每批嵌入的条数
max_synonym_entities = 2000 # 同义边参与的实体数上限,超限则跳过
enable_ppr = true # 是否启用PPR低配机器可关闭
[keyword_reaction]
keyword_rules = [
{ keywords = ["人机", "bot", "机器", "入机", "robot", "机器人", "ai", "AI"], reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" },
{ keywords = ["测试关键词回复", "test"], reaction = "回答测试成功" },
# { keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
# 在此处添加更多规则,格式同上
]
regex_rules = [
{ regex = ["^(?P<n>\\S{1,20})是这样的$"], reaction = "请按照以下模板造句:[n]是这样的xx只要xx就可以可是[n]要考虑的事情就很多了比如什么时候xx什么时候xx什么时候xx。请自由发挥替换xx部分只需保持句式结构同时表达一种将[n]过度重视的反讽意味)" },
{ regex = ["(?:[^,。.\\s]+\\s*)?我(?:也)?[没沒]要求你\\s*(?P<action>.+?)[.。,]?$"],reaction = "进行一次“八岐大蛇定型文”补完。暂时忽略回复风格,进入“粉絲吐槽”模式。请严格按照以下“填空模板”,围绕“[action]”进行创作,使其充满滑稽的夸张感和天马行空的联想:我没要求你[action],我不是(与action关联的、宽容的角色)。可是,(与action相关的具体负面行为)是什么意思?你的(与action相关的某种观念)怎么了?你才(与action相关的某个状态或时间点)吧?再这样下去,你(一个中期的、等比级数式的滑稽推演)(一个后期的、等比级数式的滑稽推演),最后就变成(一个与主题相关的、夸张的最终形态)了。作为(与最终形态相关的、克星或权威身份),我可能得(对你执行一个天罚般的行动)。真的。"}
]
[response_post_process]
enable_response_post_process = true # 是否启用回复后处理,包括错别字生成器,回复分割器
[chinese_typo]
enable = true # 是否启用中文错别字生成器
error_rate=0.01 # 单字替换概率
min_freq=9 # 最小字频阈值
tone_error_rate=0.1 # 声调错误概率
word_replace_rate=0.006 # 整词替换概率
[response_splitter]
enable = true # 是否启用回复分割器
max_length = 512 # 回复允许的最大长度
max_sentence_num = 8 # 回复允许的最大句子数
enable_kaomoji_protection = false # 是否启用颜文字保护
enable_overflow_return_all = false # 是否在句子数量超出回复允许的最大句子数时一次性返回全部内容
[log]
date_style = "m-d H:i:s" # 日期格式
log_level_style = "lite" # 日志级别样式,可选FULLcompactlite
color_text = "full" # 日志文本颜色可选nonetitlefull
log_level = "INFO" # 全局日志级别(向下兼容,优先级低于下面的分别设置)
console_log_level = "INFO" # 控制台日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL
file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL
# 第三方库日志控制
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库
library_log_levels = { aiohttp = "WARNING"} # 设置特定库的日志级别
[debug]
show_prompt = false # 是否显示prompt
show_replyer_prompt = false # 是否显示回复器prompt
show_replyer_reasoning = false # 是否显示回复器推理
show_jargon_prompt = false # 是否显示jargon相关提示词
show_memory_prompt = false # 是否显示记忆检索相关提示词
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志
[maim_message]
auth_token = [] # 认证令牌用于旧版API验证为空则不启用验证
# 新版API Server配置额外监听端口
enable_api_server = false # 是否启用额外的新版API Server
api_server_host = "0.0.0.0" # 新版API Server主机地址
api_server_port = 8090 # 新版API Server端口号
api_server_use_wss = false # 新版API Server是否启用WSS
api_server_cert_file = "" # 新版API Server SSL证书文件路径
api_server_key_file = "" # 新版API Server SSL密钥文件路径
api_server_allowed_api_keys = [] # 新版API Server允许的API Key列表为空则允许所有连接
[telemetry] #发送统计信息,主要是看全球有多少只麦麦
enable = true
[webui] # WebUI 独立服务器配置
# 注意: WebUI 的监听地址(host)和端口(port)已移至 .env 文件中的 WEBUI_HOST 和 WEBUI_PORT
enabled = true # 是否启用WebUI
mode = "production" # 模式: development(开发) 或 production(生产)
# 防爬虫配置
anti_crawler_mode = "basic" # 防爬虫模式: false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)
allowed_ips = "127.0.0.1" # IP白名单逗号分隔支持精确IP、CIDR格式和通配符
# 示例: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
trusted_proxies = "" # 信任的代理IP列表逗号分隔只有来自这些IP的X-Forwarded-For才被信任
# 示例: 127.0.0.1,192.168.1.1,172.17.0.1
trust_xff = false # 是否启用X-Forwarded-For代理解析默认false
# 启用后仍要求直连IP在trusted_proxies中才会信任XFF头
secure_cookie = false # 是否启用安全Cookie仅通过HTTPS传输默认false
[experimental] #实验性功能
# 为指定聊天添加额外的prompt配置
# 格式: ["platform:id:type:prompt内容", ...]
# 示例:
# chat_prompts = [
# "qq:114514:group:这是一个摄影群,你精通摄影知识",
# "qq:19198:group:这是一个二次元交流群",
# "qq:114514:private:这是你与好朋友的私聊"
# ]
chat_prompts = []
# 此系统暂时移除,无效配置
[relationship]
enable_relationship = true # 是否启用关系系统

View File

@ -113,8 +113,6 @@ from .core.claude_config import (
)
from .tool_chain import (
ToolChainDefinition,
ToolChainStep,
ChainExecutionResult,
tool_chain_manager,
)
@ -1651,7 +1649,7 @@ class MCPStatusCommand(BaseCommand):
tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
if component_registry.get_component_info(tool_name, ComponentType.TOOL):
registered += 1
lines = [f"✅ 已重新加载工具链配置"]
lines = ["✅ 已重新加载工具链配置"]
lines.append(f"📋 配置数: {len(chains)}")
lines.append(f"🔧 已注册: {registered} 个(可被 LLM 调用)")
if chains:
@ -1698,7 +1696,7 @@ class MCPStatusCommand(BaseCommand):
output_preview += "..."
lines.append(output_preview)
else:
lines.append(f"❌ 工具链执行失败")
lines.append("❌ 工具链执行失败")
lines.append(f"错误: {result.error}")
if result.step_results:
lines.append("")
@ -1777,9 +1775,9 @@ class MCPStatusCommand(BaseCommand):
cb = info.get("circuit_breaker", {})
cb_state = cb.get("state", "closed")
if cb_state == "open":
lines.append(f" ⚡ 断路器熔断中")
lines.append(" ⚡ 断路器熔断中")
elif cb_state == "half_open":
lines.append(f" ⚡ 断路器试探中")
lines.append(" ⚡ 断路器试探中")
if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']}")

View File

@ -14,12 +14,11 @@ MCP Workflow 模块 v1.9.0
- ReAct 软流程互补用户可选择合适的执行方式
"""
import asyncio
import json
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple
try:
from src.common.logger import get_logger

View File

@ -1,5 +1,5 @@
import random
from typing import List, Tuple, Type, Any
from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
@ -17,6 +17,9 @@ from src.plugin_system import (
emoji_api,
)
from src.config.config import global_config
from src.common.logger import get_logger
logger = get_logger("hello_world_plugin")
class CompareNumbersTool(BaseTool):
@ -217,6 +220,39 @@ class RandomEmojis(BaseCommand):
return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)
class TestCommand(BaseCommand):
    """Respond to the /test command by generating a canned test reply."""

    # Command registration metadata consumed by the plugin system.
    command_name = "test"
    command_description = "测试命令"
    command_pattern = r"^/test$"

    async def execute(self) -> Tuple[bool, Optional[str], int]:
        """Run the test command: ask the generator for a reply and send it.

        Returns:
            (success, message, count) — always (True, "", 1); failures are
            only logged so the command never surfaces an error to the chat.
        """
        try:
            from src.plugin_system.apis import generator_api
            reply_reason = "这是一条测试消息。"
            logger.info(f"测试命令:{reply_reason}")
            # Ask the reply generator to produce the fixed "test OK" message
            # in the bot's configured persona (no typo simulation).
            result_status, data = await generator_api.generate_reply(
                chat_stream=self.message.chat_stream,
                reply_reason=reply_reason,
                enable_chinese_typo=False,
                extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
            )
            if result_status:
                # Send each generated reply segment in order.
                if data and data.reply_set and data.reply_set.reply_data:
                    for reply_seg in data.reply_set.reply_data:
                        send_data = reply_seg.content
                        await self.send_text(send_data, storage_message=True)
                        logger.info(f"已回复: {send_data}")
                return True, "", 1
        except Exception as e:
            # NOTE(review): the log text says "表达器生成失败" but this also
            # swallows send errors — confirm this catch-all is intended.
            logger.error(f"表达器生成失败:{e}")
        return True, "", 1
# ===== 插件注册 =====
@ -259,6 +295,7 @@ class HelloWorldPlugin(BasePlugin):
(PrintMessage.get_handler_info(), PrintMessage),
(ForwardMessages.get_handler_info(), ForwardMessages),
(RandomEmojis.get_command_info(), RandomEmojis),
(TestCommand.get_command_info(), TestCommand),
]

View File

@ -0,0 +1,322 @@
"""
评估结果统计脚本
功能
1. 扫描temp目录下所有JSON文件
2. 分析每个文件的统计信息
3. 输出详细的统计报告
"""
import json
import os
import sys
import glob
from collections import Counter
from datetime import datetime
from typing import Dict, List, Set, Tuple
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.logger import get_logger
logger = get_logger("evaluation_stats_analyzer")
# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def parse_datetime(dt_str: str) -> datetime | None:
"""解析ISO格式的日期时间字符串"""
try:
return datetime.fromisoformat(dt_str)
except Exception:
return None
def analyze_single_file(file_path: str) -> Dict:
    """Collect statistics for one evaluation-result JSON file.

    Args:
        file_path: path of the JSON file to inspect.

    Returns:
        A statistics dict; the "error" key is set when the file cannot be
        read or parsed, in which case the remaining fields keep defaults.
    """
    file_name = os.path.basename(file_path)
    stats = {
        "file_name": file_name,
        "file_path": file_path,
        "file_size": os.path.getsize(file_path),
        "error": None,
        "last_updated": None,
        "total_count": 0,
        "actual_count": 0,
        "suitable_count": 0,
        "unsuitable_count": 0,
        "suitable_rate": 0.0,
        "unique_pairs": 0,
        "evaluators": Counter(),
        "evaluation_dates": [],
        "date_range": None,
        "has_expression_id": False,
        "has_reason": False,
        "reason_count": 0,
    }

    try:
        with open(file_path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)

        stats["last_updated"] = payload.get("last_updated")
        stats["total_count"] = payload.get("total_count", 0)

        records = payload.get("manual_results", [])
        stats["actual_count"] = len(records)
        if not records:
            return stats

        # A single pass over the records gathers every per-record statistic.
        passed = failed = reasons = 0
        pairs: Set[Tuple[str, str]] = set()
        dates = []
        has_expr_id = False
        for rec in records:
            verdict = rec.get("suitable")
            if verdict is True:
                passed += 1
            elif verdict is False:
                failed += 1
            if "situation" in rec and "style" in rec:
                pairs.add((rec["situation"], rec["style"]))
            stats["evaluators"][rec.get("evaluator", "unknown")] += 1
            stamp = rec.get("evaluated_at")
            if stamp:
                parsed = parse_datetime(stamp)
                if parsed:
                    dates.append(parsed)
            if "expression_id" in rec:
                has_expr_id = True
            if rec.get("reason"):
                reasons += 1

        stats["suitable_count"] = passed
        stats["unsuitable_count"] = failed
        stats["suitable_rate"] = passed / len(records) * 100
        stats["unique_pairs"] = len(pairs)
        stats["evaluation_dates"] = dates
        if dates:
            earliest, latest = min(dates), max(dates)
            stats["date_range"] = {
                "start": earliest.isoformat(),
                "end": latest.isoformat(),
                "duration_days": (latest - earliest).days + 1
            }
        stats["has_expression_id"] = has_expr_id
        stats["has_reason"] = reasons > 0
        stats["reason_count"] = reasons

    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"分析文件 {file_name} 时出错: {e}")

    return stats
def print_file_stats(stats: Dict, index: int = None):
"""打印单个文件的统计信息"""
prefix = f"[{index}] " if index is not None else ""
print(f"\n{'=' * 80}")
print(f"{prefix}文件: {stats['file_name']}")
print(f"{'=' * 80}")
if stats["error"]:
print(f"✗ 错误: {stats['error']}")
return
print(f"文件路径: {stats['file_path']}")
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
if stats["last_updated"]:
print(f"最后更新: {stats['last_updated']}")
print("\n【记录统计】")
print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']:
diff = stats['total_count'] - stats['actual_count']
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】")
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs']
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】")
if stats['evaluators']:
for evaluator, count in stats['evaluators'].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
print("\n【时间统计】")
if stats['date_range']:
print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
else:
print(" 无时间信息")
print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
def print_summary(all_stats: List[Dict]):
    """Print aggregate statistics across every analyzed result file.

    Args:
        all_stats: per-file statistics dicts from analyze_single_file().
    """
    print(f"\n{'=' * 80}")
    print("汇总统计")
    print(f"{'=' * 80}")
    total_files = len(all_stats)
    # Files whose "error" key is set failed to parse and are reported apart.
    valid_files = [s for s in all_stats if not s.get("error")]
    error_files = [s for s in all_stats if s.get("error")]
    print("\n【文件统计】")
    print(f" 总文件数: {total_files}")
    print(f" 成功解析: {len(valid_files)}")
    print(f" 解析失败: {len(error_files)}")
    if error_files:
        print("\n 失败文件列表:")
        for stats in error_files:
            print(f" - {stats['file_name']}: {stats['error']}")
    if not valid_files:
        print("\n没有成功解析的文件")
        return
    # Aggregate record counts over all successfully parsed files.
    total_records = sum(s['actual_count'] for s in valid_files)
    total_suitable = sum(s['suitable_count'] for s in valid_files)
    total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
    total_unique_pairs = set()
    # Re-read each file to collect the globally unique (situation, style)
    # pairs — the per-file stats only carry per-file counts, not the pairs.
    for stats in valid_files:
        try:
            with open(stats['file_path'], "r", encoding="utf-8") as f:
                data = json.load(f)
                results = data.get("manual_results", [])
                for r in results:
                    if "situation" in r and "style" in r:
                        total_unique_pairs.add((r["situation"], r["style"]))
        except Exception:
            # Best-effort: files that fail on the second read are skipped.
            pass
    print("\n【记录汇总】")
    print(f" 总记录数: {total_records:,}")
    print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
    print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
    print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
    if total_records > 0:
        duplicate_count = total_records - len(total_unique_pairs)
        duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
        print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
    # Aggregate evaluator tallies across files.
    all_evaluators = Counter()
    for stats in valid_files:
        all_evaluators.update(stats['evaluators'])
    print("\n【评估者汇总】")
    if all_evaluators:
        for evaluator, count in all_evaluators.most_common():
            rate = (count / total_records * 100) if total_records > 0 else 0
            print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
    else:
        print(" 无评估者信息")
    # Overall evaluation-time span across every file.
    all_dates = []
    for stats in valid_files:
        all_dates.extend(stats['evaluation_dates'])
    if all_dates:
        min_date = min(all_dates)
        max_date = max(all_dates)
        print("\n【时间汇总】")
        print(f" 最早评估时间: {min_date.isoformat()}")
        print(f" 最晚评估时间: {max_date.isoformat()}")
        print(f" 总时间跨度: {(max_date - min_date).days + 1}")
    # File-size totals for the successfully parsed files.
    total_size = sum(s['file_size'] for s in valid_files)
    avg_size = total_size / len(valid_files) if valid_files else 0
    print("\n【文件大小汇总】")
    print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
    print(f" 平均大小: {avg_size:,.0f} 字节 ({avg_size / 1024:.2f} KB)")
def main():
    """Entry point: scan TEMP_DIR for result files, report each, then summarize."""
    for banner in ("=" * 80, "开始分析评估结果统计信息", "=" * 80):
        logger.info(banner)

    if not os.path.exists(TEMP_DIR):
        print(f"\n✗ 错误未找到temp目录: {TEMP_DIR}")
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        return

    found = glob.glob(os.path.join(TEMP_DIR, "*.json"))
    if not found:
        print(f"\n✗ 错误temp目录下未找到JSON文件: {TEMP_DIR}")
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        return

    found.sort()  # deterministic, filename-ordered report
    print(f"\n找到 {len(found)} 个JSON文件")
    print("=" * 80)

    collected = []
    for ordinal, path in enumerate(found, start=1):
        file_stats = analyze_single_file(path)
        collected.append(file_stats)
        print_file_stats(file_stats, index=ordinal)

    print_summary(collected)
    print(f"\n{'=' * 80}")
    print("分析完成")
    print(f"{'=' * 80}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,507 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from typing import Dict, Any
from datetime import datetime
# 强制使用 utf-8避免控制台编码报错
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
logger = get_logger("compare_finish_search_token")
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """获取从指定时间开始的token使用情况

    Aggregates memory-related LLMUsage rows from ``start_time`` onward, both
    in total and broken down per model.

    Args:
        start_time: 开始时间戳 (Unix timestamp)

    Returns:
        Dict with total_prompt_tokens, total_completion_tokens, total_tokens,
        total_cost, request_count and a per-model "model_usage" breakdown.
        All-zero values are returned when the query fails.
    """
    # Zeroed result, also used as the failure fallback.
    empty: Dict[str, Any] = {
        "total_prompt_tokens": 0,
        "total_completion_tokens": 0,
        "total_tokens": 0,
        "total_cost": 0.0,
        "request_count": 0,
        "model_usage": {},
    }
    try:
        start_datetime = datetime.fromtimestamp(start_time)
        # NOTE: the explicit equality checks previously OR-ed here
        # ("memory.question", "memory.react", "memory.react.final") were
        # redundant — each literal already matches LIKE '%memory%'.
        records = (
            LLMUsage.select()
            .where(
                (LLMUsage.timestamp >= start_datetime)
                & (LLMUsage.request_type.like("%memory%"))
            )
            .order_by(LLMUsage.timestamp.asc())
        )

        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0
        total_cost = 0.0
        request_count = 0
        model_usage: Dict[str, Dict[str, Any]] = {}  # per-model breakdown

        for record in records:
            # Columns may be NULL; coalesce to zero before accumulating.
            prompt = record.prompt_tokens or 0
            completion = record.completion_tokens or 0
            tokens = record.total_tokens or 0
            cost = record.cost or 0.0

            total_prompt_tokens += prompt
            total_completion_tokens += completion
            total_tokens += tokens
            total_cost += cost
            request_count += 1

            usage = model_usage.setdefault(
                record.model_name or "unknown",
                {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                },
            )
            usage["prompt_tokens"] += prompt
            usage["completion_tokens"] += completion
            usage["total_tokens"] += tokens
            usage["cost"] += cost
            usage["request_count"] += 1

        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "request_count": request_count,
            "model_usage": model_usage,
        }
    except Exception as e:
        logger.error(f"获取token使用情况失败: {e}")
        return empty
def _import_memory_retrieval():
    """Dynamically import memory_retrieval via importlib to dodge circular imports.

    Returns:
        (init_memory_retrieval_prompt, _react_agent_solve_question) taken
        from the (re)loaded module.

    Raises:
        ImportError / AttributeError when the module cannot be imported cleanly.
    """
    try:
        # Import prompt_builder first so we can check whether the
        # memory_retrieval prompt has already been registered.
        from src.chat.utils.prompt_builder import global_prompt_manager
        # If the prompt is registered, the module was probably already
        # initialized through some other import path.
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
        module_name = "src.memory_system.memory_retrieval"
        # Fast path: reuse the already-loaded, fully-initialized module.
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, 'init_memory_retrieval_prompt'):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                )
        # If the module is cached but only partially initialized (a symptom
        # of an interrupted circular import), evict it before re-importing.
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
                # Partially initialized — drop it and retry the import.
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Also evict any partially initialized sibling modules.
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith('src.memory_system.') and key != 'src.memory_system':
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass
        # Pre-load every module that may participate in the import cycle so
        # each is fully initialized before memory_retrieval itself loads.
        try:
            import src.config.config
            import src.chat.utils.prompt_builder
            # These repliers may import memory_retrieval at module level;
            # if importable, force them to finish initializing first.
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # ignore and continue
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # ignore and continue
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")
        # Now import memory_retrieval; a circular-import failure at this
        # point means some other module still imports it at module level.
        memory_retrieval_module = importlib.import_module(module_name)
        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
def _init_tools_without_finish_search():
    """Register every retrieval tool except finish_search."""
    from src.config.config import global_config
    from src.memory_system.retrieval_tools import (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    )
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry

    # Start from an empty registry so earlier registrations cannot leak in.
    registry = get_tool_registry()
    registry.tools.clear()

    for register in (register_query_chat_history, register_query_person_info, register_query_words):
        register()

    # The LPMM knowledge tool is only relevant in agent mode.
    if global_config.lpmm_knowledge.lpmm_mode == "agent":
        from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge
        register_lpmm_knowledge()

    logger.info("已初始化工具(不包含 finish_search")
def _init_tools_with_finish_search():
    """Register the full retrieval tool set, finish_search included."""
    from src.memory_system.retrieval_tools import init_all_tools
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry

    # Drop any previously registered tools before re-registering everything.
    registry = get_tool_registry()
    registry.tools.clear()

    init_all_tools()
    logger.info("已初始化工具(包含 finish_search")
async def get_prompt_tokens_for_tools(
    question: str,
    chat_id: str,
    use_finish_search: bool,
) -> Dict[str, Any]:
    """Measure prompt-token cost of the first LLM call under a tool configuration.

    Mirrors the first iteration of _react_agent_solve_question: build the
    head prompt, attach the configured tool definitions, and issue exactly
    one LLM request so the response usage reflects the input-token cost.

    Args:
        question: 要查询的问题
        chat_id: 聊天ID  # NOTE(review): currently unused in this body — confirm
        use_finish_search: whether the finish_search tool is registered.

    Returns:
        Dict with prompt/completion/total token counts, tool_count and
        tool_names for the configuration tested.
    """
    # Initialize the prompt first if needed.  init_memory_retrieval_prompt
    # calls init_all_tools, so the tool set must be (re)configured after it.
    from src.chat.utils.prompt_builder import global_prompt_manager
    if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
        init_memory_retrieval_prompt, _ = _import_memory_retrieval()
        init_memory_retrieval_prompt()
    # Configure tools with or without finish_search; must run after the
    # prompt initialization above for the reason noted there.
    if use_finish_search:
        _init_tools_with_finish_search()
    else:
        _init_tools_without_finish_search()
    # Snapshot the registered tool definitions.
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    tool_registry = get_tool_registry()
    tool_definitions = tool_registry.get_tool_definitions()
    # Sanity-check the tool list (debug aid) and repair mismatches.
    tool_names = [tool["name"] for tool in tool_definitions]
    if use_finish_search:
        if "finish_search" not in tool_names:
            logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
    else:
        if "finish_search" in tool_names:
            logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
            # Drop the stray finish_search registration and re-snapshot.
            tool_registry.tools.pop("finish_search", None)
            tool_definitions = tool_registry.get_tool_definitions()
            tool_names = [tool["name"] for tool in tool_definitions]
    # Build the first-call prompt, mirroring _react_agent_solve_question's
    # very first LLM invocation.
    from src.config.config import global_config
    bot_name = global_config.bot.nickname
    time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    head_prompt = await global_prompt_manager.format_prompt(
        "memory_retrieval_react_prompt_head",
        bot_name=bot_name,
        time_now=time_now,
        question=question,
        collected_info="",
        current_iteration=1,
        remaining_iterations=global_config.memory.max_agent_iterations - 1,
        max_iterations=global_config.memory.max_agent_iterations,
    )
    # Only a single system message is sent, matching the first call.
    from src.llm_models.payload_content.message import MessageBuilder, RoleType
    messages = []
    system_builder = MessageBuilder()
    system_builder.set_role(RoleType.System)
    system_builder.add_text_content(head_prompt)
    messages.append(system_builder.build())
    # Issue one LLM call just to obtain token usage — no iteration loop.
    from src.llm_models.utils_model import LLMRequest, RequestType
    from src.config.config import model_config
    llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")
    # Build the provider-specific tool options payload.
    tool_built = llm_request._build_tool_options(tool_definitions)
    # Call the private _execute_request directly so the full response object
    # (including usage) is available rather than just the text.
    response, model_info = await llm_request._execute_request(
        request_type=RequestType.RESPONSE,
        message_factory=lambda _client, *, _messages=messages: _messages,
        temperature=None,
        max_tokens=None,
        tool_options=tool_built,
    )
    # Pull token counts out of the response usage, defaulting to zero.
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0
    if response and hasattr(response, 'usage') and response.usage:
        prompt_tokens = response.usage.prompt_tokens or 0
        completion_tokens = response.usage.completion_tokens or 0
        total_tokens = response.usage.total_tokens or 0
    return {
        "use_finish_search": use_finish_search,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "tool_count": len(tool_definitions),
        "tool_names": [tool["name"] for tool in tool_definitions],
    }
async def compare_prompt_tokens(
    question: str,
    chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
    """Compare first-call input-token cost with vs without finish_search.

    Runs a single LLM call per configuration; everything except the tool
    definitions is held constant, so the prompt-token delta isolates the
    cost of the extra tool schema.

    Args:
        question: 要查询的问题
        chat_id: 聊天ID (a _with/_without suffix is appended per run)

    Returns:
        Dict with both raw per-run results and a "comparison" summary.
    """
    print("\n" + "=" * 80)
    print("finish_search 工具 输入 Token 消耗对比测试")
    print("=" * 80)
    print(f"\n[测试问题] {question}")
    print(f"[聊天ID] {chat_id}")
    print("\n注意: 只对比第一次LLM调用的输入token差异不运行完整迭代流程")
    # Run 1/2: without finish_search.
    print("\n" + "-" * 80)
    print("[测试 1/2] 不使用 finish_search 工具")
    print("-" * 80)
    result_without = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_without",
        use_finish_search=False,
    )
    # f-prefixes removed from placeholder-less strings below (ruff F541),
    # matching the cleanup applied elsewhere in this release.
    print("\n[结果]")
    print(f" 工具数量: {result_without['tool_count']}")
    print(f" 工具列表: {', '.join(result_without['tool_names'])}")
    print(f" 输入 Prompt Tokens: {result_without['prompt_tokens']:,}")
    # Give the database a moment to flush the usage record.
    await asyncio.sleep(1)
    # Run 2/2: with finish_search.
    print("\n" + "-" * 80)
    print("[测试 2/2] 使用 finish_search 工具")
    print("-" * 80)
    result_with = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_with",
        use_finish_search=True,
    )
    print("\n[结果]")
    print(f" 工具数量: {result_with['tool_count']}")
    print(f" 工具列表: {', '.join(result_with['tool_names'])}")
    print(f" 输入 Prompt Tokens: {result_with['prompt_tokens']:,}")
    # Comparison report.
    print("\n" + "=" * 80)
    print("[对比结果]")
    print("=" * 80)
    prompt_token_diff = result_with['prompt_tokens'] - result_without['prompt_tokens']
    prompt_token_diff_percent = (prompt_token_diff / result_without['prompt_tokens'] * 100) if result_without['prompt_tokens'] > 0 else 0
    tool_count_diff = result_with['tool_count'] - result_without['tool_count']
    print("\n[输入 Prompt Token 对比]")
    print(f" 不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
    print(f" 使用 finish_search: {result_with['prompt_tokens']:,} tokens")
    print(f" 差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")
    print("\n[工具数量对比]")
    print(f" 不使用 finish_search: {result_without['tool_count']} 个工具")
    print(f" 使用 finish_search: {result_with['tool_count']} 个工具")
    print(f" 差异: {tool_count_diff:+d} 个工具")
    print("\n[工具列表对比]")
    without_tools = set(result_without['tool_names'])
    with_tools = set(result_with['tool_names'])
    only_with = with_tools - without_tools
    only_without = without_tools - with_tools
    if only_with:
        print(f" 仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
    if only_without:
        print(f" 仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
    if not only_with and not only_without:
        print(" 工具列表相同(除了 finish_search")
    # Secondary token figures for completeness.
    print("\n[其他 Token 信息]")
    print(f" Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
    print(f" Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
    print(f" 总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
    print(f" 总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")
    print("\n" + "=" * 80)
    return {
        "question": question,
        "without_finish_search": result_without,
        "with_finish_search": result_with,
        "comparison": {
            "prompt_token_diff": prompt_token_diff,
            "prompt_token_diff_percent": prompt_token_diff_percent,
            "tool_count_diff": tool_count_diff,
        },
    }
def main() -> None:
    """CLI entry point: parse args, prompt for a question, run the comparison."""
    parser = argparse.ArgumentParser(
        description="对比使用 finish_search 工具与否的 token 消耗差异"
    )
    parser.add_argument(
        "--chat-id",
        default="compare_finish_search",
        help="测试用的聊天ID默认: compare_finish_search",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件可选",
    )
    args = parser.parse_args()
    # Keep logging quiet so the console report stays readable.
    initialize_logging(verbose=False)
    # The question is collected interactively rather than via an argument.
    print("\n" + "=" * 80)
    print("finish_search 工具 Token 消耗对比测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return
    # Open the database connection (usage records are read from it).
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return
    # Run the async comparison to completion.
    try:
        result = asyncio.run(
            compare_prompt_tokens(
                question=question,
                chat_id=args.chat_id,
            )
        )
        # Optionally persist the result as JSON.
        if args.output:
            # Shallow copy before dumping so the in-memory result stays intact.
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Always release the database connection, even on failure.
        try:
            db.close()
        except Exception:
            pass
if __name__ == "__main__":
main()

View File

@ -0,0 +1,556 @@
"""
表达方式按count分组的LLM评估和统计分析脚本
功能
1. 随机选择50条表达至少要有20条count>1的项目然后进行LLM评估
2. 比较不同count之间的LLM评估合格率是否有显著差异
- 首先每个count分开比较
- 然后比较count为1和count大于1的两种
"""
import asyncio
import random
import json
import sys
import os
import re
from typing import List, Dict, Set, Tuple
from datetime import datetime
from collections import defaultdict
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
logger = get_logger("expression_evaluator_count_analysis_llm")
# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load previously saved evaluation results.

    Returns:
        (stored result dicts, set of already-evaluated (situation, style) pairs);
        both empty when the file is missing or unreadable.
    """
    if not os.path.exists(COUNT_ANALYSIS_FILE):
        return [], set()
    try:
        with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as fh:
            stored = json.load(fh)
        entries = stored.get("evaluation_results", [])
        # (situation, style) is the uniqueness key for an evaluation.
        seen: Set[Tuple[str, str]] = set()
        for entry in entries:
            if "situation" in entry and "style" in entry:
                seen.add((entry["situation"], entry["style"]))
        logger.info(f"已加载 {len(entries)} 条已有评估结果")
        return entries, seen
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
def save_results(evaluation_results: List[Dict]):
    """Persist the evaluation results to COUNT_ANALYSIS_FILE as pretty JSON.

    Args:
        evaluation_results: 评估结果列表
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        payload = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(evaluation_results),
            "evaluation_results": evaluation_results,
        }
        with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as fh:
            json.dump(payload, fh, ensure_ascii=False, indent=2)
        logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation(
    evaluated_pairs: Set[Tuple[str, str]] | None = None
) -> List[Expression]:
    """选择用于评估的表达方式

    Selects every not-yet-evaluated expression with count > 1, plus twice
    that many randomly sampled count == 1 expressions, then shuffles.

    Args:
        evaluated_pairs: already-evaluated (situation, style) pairs to skip;
            None is treated as the empty set.  (Annotation fixed: the default
            is None, so the type is Set[...] | None.)

    Returns:
        选中的表达方式列表 (shuffled; empty on error or nothing left).
    """
    if evaluated_pairs is None:
        evaluated_pairs = set()
    try:
        # Load every expression once; filtering happens in memory.
        all_expressions = list(Expression.select())
        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []
        # Keep only items that have not been evaluated yet.
        unevaluated = [
            expr for expr in all_expressions
            if (expr.situation, expr.style) not in evaluated_pairs
        ]
        if not unevaluated:
            logger.warning("所有项目都已评估完成")
            return []
        # Split by usage count.
        count_eq1 = [expr for expr in unevaluated if expr.count == 1]
        count_gt1 = [expr for expr in unevaluated if expr.count > 1]
        logger.info(f"未评估项目中count=1的有{len(count_eq1)}count>1的有{len(count_gt1)}")
        # Take every count>1 item ...
        selected_count_gt1 = count_gt1.copy()
        # ... and aim for twice as many count==1 items.
        count_gt1_count = len(selected_count_gt1)
        count_eq1_needed = count_gt1_count * 2
        if len(count_eq1) < count_eq1_needed:
            logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}")
            count_eq1_needed = len(count_eq1)
        # Random sample of count==1 items (all of them when fewer available).
        selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []
        selected = selected_count_gt1 + selected_count_eq1
        random.shuffle(selected)  # randomize evaluation order
        logger.info(f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍")
        return selected
    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM evaluation prompt for one (situation, style) pair.

    Args:
        situation: 使用条件或使用情景
        style: 表达方式或言语风格

    Returns:
        The full prompt text requesting a strict-JSON verdict.
    """
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景{situation}
表达方式或言语风格{style}
请从以下方面进行评估
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指需要具有泛用性
4. 一般不涉及具体的人名或名称
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of a (situation, style) pair.

    Args:
        situation: 情境
        style: 风格
        llm: LLM请求实例

    Returns:
        (suitable, reason, error): ``error`` is None on success; on any
        failure ``suitable`` is False and ``error`` carries the message.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict; if the model wrapped the JSON in extra
        # text, fall back to extracting the first {...} containing "suitable".
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        # Any failure (request, parsing, extraction) is reported as a
        # non-pass with the error attached.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
    """Evaluate a single expression with the LLM and package the verdict.

    Args:
        expression: the Expression row to judge.
        llm: the shared LLMRequest instance.

    Returns:
        A result dict ready to be appended to the evaluation log.
    """
    situation, style, count = expression.situation, expression.style, expression.count
    logger.info(f"开始评估表达方式: situation={situation}, style={style}, count={count}")

    verdict, reason, error = await _single_llm_evaluation(situation, style, llm)
    if error:
        verdict = False  # an evaluation error always counts as a failure

    logger.info(f"评估完成: {'通过' if verdict else '不通过'}")
    return {
        "situation": situation,
        "style": style,
        "count": count,
        "suitable": verdict,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
        "evaluated_at": datetime.now().isoformat()
    }
def perform_statistical_analysis(evaluation_results: List[Dict]):
    """Summarize results and test whether pass rates differ by usage count.

    Prints per-count pass rates, a count==1 vs count>1 comparison, a Pearson
    chi-square test on the resulting 2x2 contingency table, and saves the
    aggregates to a JSON file under TEMP_DIR.

    Args:
        evaluation_results: 评估结果列表
    """
    if not evaluation_results:
        print("\n没有评估结果可供分析")
        return
    print("\n" + "=" * 60)
    print("统计分析结果")
    print("=" * 60)
    # Tally pass/fail per distinct count value.
    count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        count_groups[count]["total"] += 1
        if suitable:
            count_groups[count]["suitable"] += 1
        else:
            count_groups[count]["unsuitable"] += 1
    # Per-count breakdown.
    print("\n【按count分组统计】")
    print("-" * 60)
    for count in sorted(count_groups.keys()):
        group = count_groups[count]
        total = group["total"]
        suitable = group["suitable"]
        unsuitable = group["unsuitable"]
        pass_rate = (suitable / total * 100) if total > 0 else 0
        print(f"Count = {count}:")
        print(f" 总数: {total}")
        print(f" 通过: {suitable} ({pass_rate:.2f}%)")
        print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
        print()
    # Collapse into two groups: count==1 vs count>1.
    count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        if count == 1:
            count_eq1_group["total"] += 1
            if suitable:
                count_eq1_group["suitable"] += 1
            else:
                count_eq1_group["unsuitable"] += 1
        else:
            count_gt1_group["total"] += 1
            if suitable:
                count_gt1_group["suitable"] += 1
            else:
                count_gt1_group["unsuitable"] += 1
    print("\n【Count=1 vs Count>1 对比】")
    print("-" * 60)
    eq1_total = count_eq1_group["total"]
    eq1_suitable = count_eq1_group["suitable"]
    eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
    gt1_total = count_gt1_group["total"]
    gt1_suitable = count_gt1_group["suitable"]
    gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
    print("Count = 1:")
    print(f" 总数: {eq1_total}")
    print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
    print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
    print()
    print("Count > 1:")
    print(f" 总数: {gt1_total}")
    print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
    print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
    print()
    # Pearson chi-square test (df=1) on the 2x2 contingency table.
    if eq1_total > 0 and gt1_total > 0:
        print("【统计显著性检验】")
        print("-" * 60)
        # Contingency table layout:
        #            pass   fail
        # count=1      a      b
        # count>1      c      d
        a = eq1_suitable
        b = eq1_total - eq1_suitable
        c = gt1_suitable
        d = gt1_total - gt1_suitable
        n = eq1_total + gt1_total
        if n > 0:
            # Expected cell frequencies under the independence hypothesis.
            e_a = (eq1_total * (a + c)) / n
            e_b = (eq1_total * (b + d)) / n
            e_c = (gt1_total * (a + c)) / n
            e_d = (gt1_total * (b + d)) / n
            # Chi-square validity requires every expected frequency >= 5.
            min_expected = min(e_a, e_b, e_c, e_d)
            if min_expected < 5:
                print("警告期望频数小于5卡方检验可能不准确")
                print("建议使用Fisher精确检验")
            # Accumulate the chi-square statistic, skipping zero expectations.
            chi_square = 0
            if e_a > 0:
                chi_square += ((a - e_a) ** 2) / e_a
            if e_b > 0:
                chi_square += ((b - e_b) ** 2) / e_b
            if e_c > 0:
                chi_square += ((c - e_c) ** 2) / e_c
            if e_d > 0:
                chi_square += ((d - e_d) ** 2) / e_d
            # Degrees of freedom = (rows-1) * (cols-1) = 1.
            df = 1
            # Critical values of the chi-square distribution for df=1.
            chi_square_critical_005 = 3.841
            chi_square_critical_001 = 6.635
            print(f"卡方统计量: {chi_square:.4f}")
            print(f"自由度: {df}")
            print(f"临界值 (α=0.05): {chi_square_critical_005}")
            print(f"临界值 (α=0.01): {chi_square_critical_001}")
            if chi_square >= chi_square_critical_001:
                print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01")
            elif chi_square >= chi_square_critical_005:
                print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05")
            else:
                print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05")
            # Effect size: absolute pass-rate gap between the two groups.
            diff = abs(eq1_pass_rate - gt1_pass_rate)
            print(f"\n合格率差异: {diff:.2f}%")
            if diff > 10:
                print("差异较大(>10%")
            elif diff > 5:
                print("差异中等5-10%")
            else:
                print("差异较小(<5%")
        else:
            print("数据不足,无法进行统计检验")
    else:
        print("数据不足无法进行count=1和count>1的对比分析")
    # Persist the aggregate statistics next to the evaluation results.
    analysis_result = {
        "analysis_time": datetime.now().isoformat(),
        "count_groups": {str(k): v for k, v in count_groups.items()},
        "count_eq1": count_eq1_group,
        "count_gt1": count_gt1_group,
        "total_evaluated": len(evaluation_results)
    }
    try:
        analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
        with open(analysis_file, "w", encoding="utf-8") as f:
            json.dump(analysis_result, f, ensure_ascii=False, indent=2)
        print(f"\n✓ 统计分析结果已保存到: {analysis_file}")
    except Exception as e:
        logger.error(f"保存统计分析结果失败: {e}")
async def main():
    """Entry point: LLM-evaluate expressions grouped by usage count, then analyze.

    Workflow:
      1. Connect to the database.
      2. Load previously saved evaluation results so evaluation can resume.
      3. If unevaluated count>1 expressions remain, build an LLM client and
         evaluate a sampled batch (all count>1 items plus count=1 controls).
      4. Run the chi-square statistical analysis over all results so far.
      5. Close the database connection.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式按count分组的LLM评估和统计分析")
    logger.info("=" * 60)
    # Initialize the database connection
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return
    # Load any previously saved evaluation results (resume support)
    existing_results, evaluated_pairs = load_existing_results()
    evaluation_results = existing_results.copy()
    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")
    # Decide whether evaluation is needed: are there unevaluated count>1 items?
    # Query the number of unevaluated count>1 items first
    try:
        all_expressions = list(Expression.select())
        unevaluated_count_gt1 = [
            expr for expr in all_expressions
            if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
        ]
        has_unevaluated = len(unevaluated_count_gt1) > 0
    except Exception as e:
        logger.error(f"查询未评估项目失败: {e}")
        has_unevaluated = False
    if has_unevaluated:
        print("\n" + "=" * 60)
        print("开始LLM评估")
        print("=" * 60)
        print("评估结果会自动保存到文件\n")
        # Create the LLM client
        print("创建LLM实例...")
        try:
            llm = LLMRequest(
                model_set=model_config.model_task_config.tool_use,
                request_type="expression_evaluator_count_analysis_llm"
            )
            print("✓ LLM实例创建成功\n")
        except Exception as e:
            logger.error(f"创建LLM实例失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            print(f"\n✗ 创建LLM实例失败: {e}")
            db.close()
            return
        # Select expressions to evaluate: every count>1 item plus twice as many count=1 items
        expressions = select_expressions_for_evaluation(
            evaluated_pairs=evaluated_pairs
        )
        if not expressions:
            print("\n没有可评估的项目")
        else:
            print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
            print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
            print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")
            batch_results = []
            for i, expression in enumerate(expressions, 1):
                print(f"LLM评估进度: {i}/{len(expressions)}")
                print(f" Situation: {expression.situation}")
                print(f" Style: {expression.style}")
                print(f" Count: {expression.count}")
                llm_result = await llm_evaluate_expression(expression, llm)
                print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
                if llm_result.get('error'):
                    print(f" 错误: {llm_result['error']}")
                print()
                batch_results.append(llm_result)
                # Use (situation, style) as the unique key
                evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
                # Small delay to avoid API rate limiting
                await asyncio.sleep(0.3)
            # Merge this batch into the accumulated results
            evaluation_results.extend(batch_results)
            # Persist all results
            save_results(evaluation_results)
    else:
        print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果")
    # Run the statistical analysis over everything evaluated so far
    if len(evaluation_results) > 0:
        perform_statistical_analysis(evaluation_results)
    else:
        print("\n没有评估结果可供分析")
    # Close the database connection
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
    # Script entry point
    asyncio.run(main())

View File

@ -0,0 +1,523 @@
"""
表达方式LLM评估脚本
功能
1. 读取已保存的人工评估结果作为效标
2. 使用LLM对相同项目进行评估
3. 对比人工评估和LLM评估的结果输出分析报告
"""
import asyncio
import argparse
import json
import random
import sys
import os
import glob
from typing import List, Dict, Set, Tuple
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.logger import get_logger
logger = get_logger("expression_evaluator_llm")
# Directory where evaluation-result JSON files are stored
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def load_manual_results() -> List[Dict]:
    """
    Load manual evaluation results by reading and merging every JSON file
    found under the temp directory.

    Returns:
        Deduplicated list of manual evaluation result dicts.
    """
    if not os.path.exists(TEMP_DIR):
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        print("\n✗ 错误未找到temp目录")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []
    # Collect every JSON file under the temp directory
    candidate_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
    if not candidate_files:
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        print("\n✗ 错误temp目录下未找到JSON文件")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []
    logger.info(f"找到 {len(candidate_files)} 个JSON文件")
    print(f"\n找到 {len(candidate_files)} 个JSON文件:")
    for path in candidate_files:
        print(f" - {os.path.basename(path)}")
    # Merge all files; a (situation, style) pair identifies a unique item
    merged: List[Dict] = []
    seen: Set[Tuple[str, str]] = set()
    for path in candidate_files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                payload = json.load(f)
            entries = payload.get("manual_results", [])
            for entry in entries:
                if "situation" not in entry or "style" not in entry:
                    logger.warning(f"跳过无效数据(缺少必要字段): {entry}")
                    continue
                key = (entry["situation"], entry["style"])
                if key in seen:
                    continue
                seen.add(key)
                merged.append(entry)
            logger.info(f"从 {os.path.basename(path)} 加载了 {len(entries)} 条结果")
        except Exception as e:
            logger.error(f"加载文件 {path} 失败: {e}")
            print(f" 警告:加载文件 {os.path.basename(path)} 失败: {e}")
            continue
    logger.info(f"成功合并 {len(merged)} 条人工评估结果(去重后)")
    print(f"\n✓ 成功合并 {len(merged)} 条人工评估结果(已去重)")
    return merged
def create_evaluation_prompt(situation: str, style: str) -> str:
    """
    Build the LLM evaluation prompt for one expression.

    Args:
        situation: usage condition / scenario text
        style: expression / speaking-style text

    Returns:
        The fully formatted prompt string.
    """
    # The prompt demands strict-JSON output so the caller can json.loads() it.
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景{situation}
表达方式或言语风格{style}
请从以下方面进行评估
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指需要具有泛用性
4. 一般不涉及具体的人名或名称
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """
    Run one LLM evaluation round for a (situation, style) pair.

    Args:
        situation: usage condition / scenario text
        style: expression / style text
        llm: LLM request client

    Returns:
        (suitable, reason, error); on any failure suitable is False and error
        carries the exception message.
    """
    try:
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (_reasoning, _model, _) = await llm.generate_response_async(
            prompt=create_evaluation_prompt(situation, style),
            temperature=0.6,
            max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON payload; fall back to a regex extraction when the
        # model wrapped the JSON in extra prose.
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError as decode_err:
            import re
            match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if not match:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from decode_err
            parsed = json.loads(match.group())
        verdict = parsed.get("suitable", False)
        why = parsed.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if verdict else '不通过'}")
        return verdict, why, None
    except Exception as e:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """
    Evaluate a single expression with the LLM and package the outcome.

    Args:
        situation: usage condition / scenario text
        style: expression / style text
        llm: LLM request client

    Returns:
        Result dict with the verdict, reason, optional error and evaluator tag.
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")
    raw_suitable, reason, error = await _single_llm_evaluation(situation, style, llm)
    # Any error forces a failing verdict regardless of what was parsed.
    suitable = False if error else raw_suitable
    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
    return {
        "situation": situation,
        "style": style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """
    Compare LLM verdicts against manual verdicts (the gold standard).

    Args:
        manual_results: manual evaluation result dicts
        llm_results: LLM evaluation result dicts
        method_name: label identifying the evaluation method

    Returns:
        Dict of agreement counts and derived metrics (accuracy, precision,
        recall, F1, specificity, plus the unsuitable-rate analysis).
    """
    # Index LLM results by their (situation, style) pair
    by_pair = {(item["situation"], item["style"]): item for item in llm_results}
    total = len(manual_results)
    tp = tn = fp = fn = 0
    for gold in manual_results:
        llm_item = by_pair.get((gold["situation"], gold["style"]))
        if llm_item is None:
            # No matching LLM verdict for this manual item
            continue
        gold_ok = gold["suitable"]
        llm_ok = llm_item["suitable"]
        if gold_ok and llm_ok:
            tp += 1
        elif gold_ok:
            fn += 1
        elif llm_ok:
            fp += 1
        else:
            tn += 1
    # Agreement count = diagonal of the confusion matrix
    matched = tp + tn

    def pct(numerator, denominator):
        # Percentage with a zero-denominator guard (matches original behavior).
        return (numerator / denominator * 100) if denominator > 0 else 0

    accuracy = pct(matched, total)
    precision = pct(tp, tp + fp)
    recall = pct(tp, tp + fn)
    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
    specificity = pct(tn, tn + fp)
    # Unsuitable rate in the manual gold standard: (TN + FP) / total
    manual_unsuitable_rate = pct(tn + fp, total)
    # Unsuitable rate among items the LLM kept (judged suitable): FP / (TP + FP)
    llm_kept_unsuitable_rate = pct(fp, tp + fp)
    # Positive difference means LLM filtering lowered the unsuitable rate
    rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
    random_baseline = 50.0
    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy - random_baseline,
        "accuracy_improvement_ratio": accuracy / random_baseline,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity,
        "manual_unsuitable_rate": manual_unsuitable_rate,
        "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
        "rate_difference": rate_difference,
    }
async def main(count: int | None = None):
    """
    Entry point: run the LLM over manually evaluated items and report agreement.

    Args:
        count: number of randomly sampled items to evaluate; None means use all.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式LLM评估")
    logger.info("=" * 60)
    # 1. Load the manual evaluation results (the gold standard)
    print("\n步骤1: 加载人工评估结果")
    manual_results = load_manual_results()
    if not manual_results:
        return
    print(f"成功加载 {len(manual_results)} 条人工评估结果")
    # If a sample size was given, randomly pick that many items
    if count is not None:
        if count <= 0:
            print(f"\n✗ 错误指定的数量必须大于0当前值: {count}")
            return
        if count > len(manual_results):
            print(f"\n⚠ 警告:指定的数量 ({count}) 大于可用数据量 ({len(manual_results)}),将使用全部数据")
        else:
            random.seed()  # seed from the system time
            manual_results = random.sample(manual_results, count)
            print(f"随机选取 {len(manual_results)} 条数据进行评估")
    # Validate data integrity
    valid_manual_results = []
    for r in manual_results:
        if "situation" in r and "style" in r:
            valid_manual_results.append(r)
        else:
            logger.warning(f"跳过无效数据: {r}")
    if len(valid_manual_results) != len(manual_results):
        print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
    print(f"有效数据: {len(valid_manual_results)} 条")
    # 2. Create the LLM client and evaluate
    print("\n步骤2: 创建LLM实例")
    try:
        llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="expression_evaluator_llm"
        )
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return
    print("\n步骤3: 开始LLM评估")
    llm_results = []
    for i, manual_result in enumerate(valid_manual_results, 1):
        print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
        llm_results.append(await evaluate_expression_llm(
            manual_result["situation"],
            manual_result["style"],
            llm
        ))
        await asyncio.sleep(0.3)
    # 5. Report FP and FN items (before the metric summary)
    llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
    # 5.1 FP items: manual verdict is fail, but the LLM passed them
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
    print("=" * 60)
    fp_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue
        # Manual verdict: fail; LLM verdict: pass (an FP)
        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append({
                "situation": manual_result["situation"],
                "style": manual_result["style"],
                "manual_suitable": manual_result["suitable"],
                "llm_suitable": llm_result["suitable"],
                "llm_reason": llm_result.get("reason", "未提供理由"),
                "llm_error": llm_result.get("error")
            })
    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get('llm_error'):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
    # 5.2 FN items: manual verdict is pass, but the LLM rejected them
    print("\n" + "=" * 60)
    print("人工评估通过但LLM误判为不通过的项目FN - False Negative")
    print("=" * 60)
    fn_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue
        # Manual verdict: pass; LLM verdict: fail (an FN)
        if manual_result["suitable"] and not llm_result["suitable"]:
            fn_items.append({
                "situation": manual_result["situation"],
                "style": manual_result["style"],
                "manual_suitable": manual_result["suitable"],
                "llm_suitable": llm_result["suitable"],
                "llm_reason": llm_result.get("reason", "未提供理由"),
                "llm_error": llm_result.get("error")
            })
    if fn_items:
        print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
        for idx, item in enumerate(fn_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 通过 ✅")
            print("LLM评估: 不通过 ❌ (误删)")
            if item.get('llm_error'):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过")
    # 6. Compare, analyze, and print the metric summary
    comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)
    # Detailed results (core metrics first)
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']}")
    print()
    # print(" 【核心能力指标】")
    print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
    print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
    print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【不合适率分析】")
    print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
    print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
    print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
    print()
    print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
    print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
    print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
    print()
    # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
    # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
    # print(f" - 含义: {'LLM删除后剩余项目中的不合适率降低了' if comparison['rate_difference'] > 0 else 'LLM删除后剩余项目中的不合适率反而升高了' if comparison['rate_difference'] < 0 else '两者相等'} ({'✓ LLM删除有效' if comparison['rate_difference'] > 0 else '✗ LLM删除效果不佳' if comparison['rate_difference'] < 0 else '效果相同'})")
    # print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
    # 7. Save results to a JSON file
    output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump({
                "manual_results": valid_manual_results,
                "llm_results": llm_results,
                "comparison": comparison
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        logger.warning(f"保存结果到文件失败: {e}")
    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)
if __name__ == "__main__":
    # CLI entry point: optional -n/--count limits how many samples are drawn.
    parser = argparse.ArgumentParser(
        description="表达方式LLM评估脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
    python evaluate_expressions_llm_v6.py # 使用全部数据
    python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
    python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
"""
    )
    parser.add_argument(
        "-n", "--count",
        type=int,
        default=None,
        help="随机选取的数据条数(默认:使用全部数据)"
    )
    args = parser.parse_args()
    asyncio.run(main(count=args.count))

View File

@ -0,0 +1,278 @@
"""
表达方式人工评估脚本
功能
1. 不停随机抽取项目不重复进行人工评估
2. 将结果保存到 temp 文件夹下的 JSON 文件作为效标标准答案
3. 支持继续评估从已有文件中读取已评估的项目避免重复
"""
import random
import json
import sys
import os
from typing import List, Dict, Set, Tuple
from datetime import datetime
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger
logger = get_logger("expression_evaluator_manual")
# Paths of the persisted evaluation results
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """
    Load previously saved manual evaluation results, if any.

    Returns:
        (list of saved result dicts, set of already-evaluated
        (situation, style) pairs)
    """
    if not os.path.exists(MANUAL_EVAL_FILE):
        return [], set()
    try:
        with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
            payload = json.load(f)
        results = payload.get("manual_results", [])
        # A (situation, style) pair uniquely identifies an evaluated item.
        done = {
            (entry["situation"], entry["style"])
            for entry in results
            if "situation" in entry and "style" in entry
        }
        logger.info(f"已加载 {len(results)} 条已有评估结果")
        return results, done
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
def save_results(manual_results: List[Dict]):
    """
    Persist the accumulated manual evaluation results as JSON.

    Args:
        manual_results: every evaluation result collected so far
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        payload = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(manual_results),
            "manual_results": manual_results,
        }
        with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """
    Fetch a random batch of expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: (situation, style) pairs already evaluated
        batch_size: maximum number of items to return

    Returns:
        Up to batch_size unevaluated Expression rows (random sample), or []
        when nothing remains or the query fails.
    """
    try:
        candidates = list(Expression.select())
        if not candidates:
            logger.warning("数据库中没有表达方式记录")
            return []
        # An item counts as evaluated only when both situation and style match.
        pending = [row for row in candidates if (row.situation, row.style) not in evaluated_pairs]
        if not pending:
            logger.info("所有项目都已评估完成")
            return []
        # Fewer pending items than requested: return them all
        if len(pending) <= batch_size:
            logger.info(f"剩余 {len(pending)} 条未评估项目,全部返回")
            return pending
        picked = random.sample(pending, batch_size)
        logger.info(f"从 {len(pending)} 条未评估项目中随机选择了 {len(picked)} 条")
        return picked
    except Exception as e:
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
    """
    Interactively collect a manual verdict for one expression.

    Args:
        expression: the Expression row to judge
        index: 1-based position within the current batch
        total: batch size

    Returns:
        A result dict, the string "skip" to skip this item, or None when the
        user quits.
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y''yes''1' 表示合适(通过)")
    print(" 输入 'n''no''0' 表示不合适(不通过)")
    print(" 输入 'q''quit' 退出评估")
    print(" 输入 's''skip' 跳过当前项目")
    suitable = None
    while suitable is None:
        answer = input("\n您的评估 (y/n/q/s): ").strip().lower()
        if answer in ['q', 'quit']:
            print("退出评估")
            return None
        if answer in ['s', 'skip']:
            print("跳过当前项目")
            return "skip"
        if answer in ['y', 'yes', '1', '', '通过']:
            suitable = True
        elif answer in ['n', 'no', '0', '', '不通过']:
            suitable = False
        else:
            print("输入无效,请重新输入 (y/n/q/s)")
    verdict = {
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": None,
        "evaluator": "manual",
        "evaluated_at": datetime.now().isoformat(),
    }
    print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
    return verdict
def main():
    """Entry point: interactively evaluate expressions batch by batch.

    Loads prior results so evaluation can resume, then repeatedly samples
    unevaluated expressions, collects manual verdicts, and saves after each
    batch until the user quits or everything is evaluated.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式人工评估")
    logger.info("=" * 60)
    # Initialize the database connection
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return
    # Load previously saved evaluation results (resume support)
    existing_results, evaluated_pairs = load_existing_results()
    manual_results = existing_results.copy()
    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")
    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
    print("评估结果会自动保存到文件\n")
    batch_size = 10
    batch_count = 0
    while True:
        # Fetch the next batch of unevaluated items
        expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
        if not expressions:
            print("\n" + "=" * 60)
            print("所有项目都已评估完成!")
            print("=" * 60)
            break
        batch_count += 1
        print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
        batch_results = []
        for i, expression in enumerate(expressions, 1):
            manual_result = manual_evaluate_expression(expression, i, len(expressions))
            if manual_result is None:
                # User chose to quit
                print("\n评估已中断")
                if batch_results:
                    # Save whatever was collected in this partial batch
                    manual_results.extend(batch_results)
                    save_results(manual_results)
                return
            if manual_result == "skip":
                # Item skipped by the user
                continue
            batch_results.append(manual_result)
            # Use (situation, style) as the unique key
            evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
        # Merge this batch into the accumulated results
        manual_results.extend(batch_results)
        # Persist results
        save_results(manual_results)
        print(f"\n当前批次完成,已评估总数: {len(manual_results)}")
        # Ask whether to continue with the next batch
        while True:
            continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
            if continue_input in ['y', 'yes', '1', '', '继续']:
                break
            elif continue_input in ['n', 'no', '0', '', '退出']:
                print("\n评估结束")
                return
            else:
                print("输入无效,请重新输入 (y/n)")
    # Close the database connection
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,476 @@
"""
表达方式评估脚本
功能
1. 随机读取指定数量的表达方式获取其situation和style
2. 先进行人工评估逐条手动评估
3. 然后使用LLM进行评估
4. 对比人工评估和LLM评估的正确率精确率召回率F1分数等指标以人工评估为标准
5. 不真正修改数据库只是做评估
"""
import asyncio
import random
import json
import sys
import os
from typing import List, Dict
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression
from src.common.database.database import db
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.logger import get_logger
logger = get_logger("expression_evaluator_comparison")
def get_random_expressions(count: int = 10) -> List[Expression]:
    """
    Randomly pick up to `count` expressions from the database.

    Args:
        count: number of rows to sample (default 10)

    Returns:
        The sampled Expression rows, every row when fewer than `count`
        exist, or [] on failure.
    """
    try:
        rows = list(Expression.select())
        if not rows:
            logger.warning("数据库中没有表达方式记录")
            return []
        # Fewer rows than requested: return them all
        if len(rows) <= count:
            logger.info(f"数据库中共有 {len(rows)} 条表达方式,全部返回")
            return rows
        sampled = random.sample(rows, count)
        logger.info(f"从 {len(rows)} 条表达方式中随机选择了 {len(sampled)} 条")
        return sampled
    except Exception as e:
        logger.error(f"随机读取表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
    """
    Interactively collect a manual verdict for one expression.

    Args:
        expression: the Expression row to judge
        index: 1-based position within the batch
        total: batch size

    Returns:
        A result dict with the verdict, or None when the user quits.
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y''yes''1' 表示合适(通过)")
    print(" 输入 'n''no''0' 表示不合适(不通过)")
    print(" 输入 'q''quit' 退出评估")
    verdict = None
    while verdict is None:
        answer = input("\n您的评估 (y/n/q): ").strip().lower()
        if answer in ['q', 'quit']:
            print("退出评估")
            return None
        if answer in ['y', 'yes', '1', '', '通过']:
            verdict = True
        elif answer in ['n', 'no', '0', '', '不通过']:
            verdict = False
        else:
            print("输入无效,请重新输入 (y/n/q)")
    record = {
        "expression_id": expression.id,
        "situation": expression.situation,
        "style": expression.style,
        "suitable": verdict,
        "reason": None,
        "evaluator": "manual",
    }
    print(f"\n✓ 已记录:{'通过' if verdict else '不通过'}")
    return record
def create_evaluation_prompt(situation: str, style: str) -> str:
    """
    Build the LLM evaluation prompt for one expression.

    Args:
        situation: situation text
        style: style text

    Returns:
        The fully formatted prompt string.
    """
    # The prompt demands strict-JSON output so the caller can json.loads() it.
    return f"""请评估以下表达方式是否合适:
情境situation{situation}
风格style{style}
请从以下方面进行评估
1. 情境描述是否清晰准确
2. 风格表达是否合理自然
3. 情境和风格是否匹配
4. 允许部分语法错误出现
5. 允许口头化或缺省表达
6. 允许部分上下文缺失
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
async def _single_llm_evaluation(expression: Expression, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """
    Run one LLM evaluation round for an expression.

    Args:
        expression: the Expression row to evaluate
        llm: LLM request client

    Returns:
        (suitable, reason, error); on any failure suitable is False and error
        carries the exception message.
    """
    try:
        prompt = create_evaluation_prompt(expression.situation, expression.style)
        logger.debug(f"正在评估表达方式 ID: {expression.id}")
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON payload; fall back to a regex extraction when the
        # model wrapped the JSON in extra prose.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as decode_err:
            import re
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                # Chain the original decode error (PEP 3134) so the root cause
                # is preserved; matches the sibling LLM-evaluation script.
                raise ValueError("无法从响应中提取JSON格式的评估结果") from decode_err
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        logger.error(f"评估表达方式 ID: {expression.id} 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
async def evaluate_expression_llm(expression: Expression, llm: LLMRequest) -> Dict:
    """
    Evaluate one Expression row with the LLM and package the outcome.

    Args:
        expression: the Expression row to evaluate
        llm: LLM request client

    Returns:
        Result dict with the verdict, reason, optional error and evaluator tag.
    """
    logger.info(f"开始评估表达方式 ID: {expression.id}")
    raw_suitable, reason, error = await _single_llm_evaluation(expression, llm)
    # Any error forces a failing verdict regardless of what was parsed.
    suitable = False if error else raw_suitable
    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
    return {
        "expression_id": expression.id,
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """
    Compare LLM verdicts against manual verdicts (the gold standard).

    Args:
        manual_results: manual evaluation result dicts
        llm_results: LLM evaluation result dicts
        method_name: label identifying the evaluation method

    Returns:
        Dict of agreement counts and derived metrics (accuracy, precision,
        recall, F1, specificity).
    """
    # Index LLM results by expression id
    by_id = {entry["expression_id"]: entry for entry in llm_results}
    total = len(manual_results)
    tp = tn = fp = fn = 0
    for gold in manual_results:
        llm_entry = by_id.get(gold["expression_id"])
        if llm_entry is None:
            # No matching LLM verdict for this manual item
            continue
        gold_ok = gold["suitable"]
        llm_ok = llm_entry["suitable"]
        if gold_ok and llm_ok:
            tp += 1
        elif gold_ok:
            fn += 1
        elif llm_ok:
            fp += 1
        else:
            tn += 1
    # Agreement count = diagonal of the confusion matrix
    matched = tp + tn

    def pct(numerator, denominator):
        # Percentage with a zero-denominator guard (matches original behavior).
        return (numerator / denominator * 100) if denominator > 0 else 0

    accuracy = pct(matched, total)
    precision = pct(tp, tp + fp)
    recall = pct(tp, tp + fn)
    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
    specificity = pct(tn, tn + fp)
    random_baseline = 50.0
    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy - random_baseline,
        "accuracy_improvement_ratio": accuracy / random_baseline,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity,
    }
async def main():
    """Entry point: manually review a random sample of expressions, have the LLM
    review the same sample, then compare the LLM against the manual ground truth."""
    logger.info("=" * 60)
    logger.info("开始表达方式评估")
    logger.info("=" * 60)
    # Initialise the database connection (reuse an already-open one if present)
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return
    # Step 1: randomly sample expressions to evaluate
    logger.info("\n步骤1: 随机读取表达方式")
    expressions = get_random_expressions(10)
    if not expressions:
        logger.error("没有可用的表达方式,退出")
        return
    logger.info(f"成功读取 {len(expressions)} 条表达方式")
    # Step 2: interactive manual evaluation — this is the ground truth
    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print(f"共需要评估 {len(expressions)} 条表达方式")
    print("请逐条进行评估...\n")
    manual_results = []
    for i, expression in enumerate(expressions, 1):
        manual_result = manual_evaluate_expression(expression, i, len(expressions))
        if manual_result is None:
            # Reviewer aborted the interactive session
            print("\n评估已中断")
            return
        manual_results.append(manual_result)
    print("\n" + "=" * 60)
    print("人工评估完成")
    print("=" * 60)
    # Step 3: create the LLM client used for the automatic evaluation
    logger.info("\n步骤3: 创建LLM实例")
    try:
        llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="expression_evaluator_comparison"
        )
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return
    logger.info("\n步骤4: 开始LLM评估")
    llm_results = []
    for i, expression in enumerate(expressions, 1):
        logger.info(f"LLM评估进度: {i}/{len(expressions)}")
        llm_results.append(await evaluate_expression_llm(expression, llm))
        # Brief pause between requests to avoid hammering the API
        await asyncio.sleep(0.3)
    # Step 4 (analysis): compare LLM results against the manual ground truth
    comparison = compare_evaluations(manual_results, llm_results, "LLM评估")
    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)
    print("\n评估目标:")
    print(" 1. 核心能力:将不合适的项目正确提取出来(特定负类召回率)")
    print(" 2. 次要能力:尽可能少的误删合适的项目(召回率)")
    # Detailed comparison — core metrics first
    print("\n【详细对比】")
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']}")
    print()
    print(" 【核心能力指标】")
    print(f" ⭐ 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
    print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" ⭐ 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
    print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
    # Step 5: list the items rejected manually but accepted by the LLM (false positives)
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
    print("=" * 60)
    # Index LLM results by expression_id so each manual result can be paired up
    llm_dict = {r["expression_id"]: r for r in llm_results}
    fp_items = []
    for manual_result in manual_results:
        llm_result = llm_dict.get(manual_result["expression_id"])
        if llm_result is None:
            continue
        # FP case: rejected by the reviewer but accepted by the LLM
        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append({
                "expression_id": manual_result["expression_id"],
                "situation": manual_result["situation"],
                "style": manual_result["style"],
                "manual_suitable": manual_result["suitable"],
                "llm_suitable": llm_result["suitable"],
                "llm_reason": llm_result.get("reason", "未提供理由"),
                "llm_error": llm_result.get("error")
            })
    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] 项目 ID: {item['expression_id']} ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get('llm_error'):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
    # Step 6: persist all results as JSON for later inspection
    output_file = os.path.join(project_root, "data", "expression_evaluation_comparison.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump({
                "manual_results": manual_results,
                "llm_results": llm_results,
                "comparison": comparison
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        logger.warning(f"保存结果到文件失败: {e}")
    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)
    # Close the database connection opened at the start
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,567 +0,0 @@
"""
模拟 Expression 合并过程
用法:
python scripts/expression_merge_simulation.py
或指定 chat_id:
python scripts/expression_merge_simulation.py --chat-id <chat_id>
或指定相似度阈值:
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""
import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable chat name for a chat_id via the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"
        # Prefer the group name; fall back to the private-chat nickname
        if stream.group_name:
            return f"{stream.group_name}"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊"
        return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        # Any lookup failure degrades to a truncated-id placeholder
        return f"查询失败 ({chat_id[:8]}...)"
def parse_content_list(stored_list: Optional[str]) -> List[str]:
    """Decode a JSON-encoded content_list column into a list of strings.

    Returns an empty list for None/empty input, invalid JSON, non-list JSON,
    and silently drops non-string entries.
    """
    if not stored_list:
        return []
    try:
        decoded = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    if not isinstance(decoded, list):
        return []
    return [str(entry) for entry in decoded if isinstance(entry, str)]
def parse_style_list(stored_list: Optional[str]) -> List[str]:
    """Decode a JSON-encoded style_list column into a list of strings.

    Invalid JSON, non-list payloads and empty input all yield an empty list;
    non-string list entries are skipped.
    """
    if not stored_list:
        return []
    try:
        payload = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    if isinstance(payload, list):
        return [str(item) for item in payload if isinstance(item, str)]
    return []
def find_exact_style_match(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    exclude_ids: set
) -> Optional[Expression]:
    """Return the first Expression in this chat whose style exactly matches.

    A record matches when either its `style` field or any entry of its
    `style_list` equals target_style. Records from other chats and records
    whose id is in exclude_ids are skipped.
    """
    for candidate in expressions:
        if candidate.chat_id != chat_id:
            continue
        if candidate.id in exclude_ids:
            continue
        # Check the primary style field first, then the accumulated style_list
        if candidate.style == target_style:
            return candidate
        if target_style in parse_style_list(candidate.style_list):
            return candidate
    return None
def find_similar_style_expression(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    similarity_threshold: float,
    exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
    """Return the Expression in this chat most similar in style to target_style.

    Both the `style` field and every entry of `style_list` are scored with
    calculate_style_similarity; only scores >= similarity_threshold count.

    Returns:
        (Expression, best similarity) or None when nothing clears the threshold.
    """
    best: Optional[Expression] = None
    best_score = 0.0
    for candidate in expressions:
        if candidate.chat_id != chat_id or candidate.id in exclude_ids:
            continue
        # Score the primary style field plus every historical style entry
        for existing_style in [candidate.style] + parse_style_list(candidate.style_list):
            score = calculate_style_similarity(target_style, existing_style)
            if score >= similarity_threshold and score > best_score:
                best_score = score
                best = candidate
    if best is None:
        return None
    return (best, best_score)
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
    """Collapse multiple situation snippets into one line, via LLM when possible.

    Empty input yields ""; a single snippet is returned as-is. Otherwise the
    last (up to) 10 snippets are summarised by the LLM; on failure they are
    joined with "/" as a fallback.
    """
    cleaned = [entry.strip() for entry in content_list if entry.strip()]
    if not cleaned:
        return ""
    if len(cleaned) == 1:
        return cleaned[0]
    # Ask the LLM to compress the recent snippets into a <=20-char summary
    prompt = (
        "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
        "长度不超过20个字保留共同特点\n"
        f"{chr(10).join(f'- {s}' for s in cleaned[-10:])}\n只输出概括内容。"
    )
    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 situation 失败: {e}")
    # Fallback when summarisation fails or returns nothing
    return "/".join(cleaned)
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
    """Collapse multiple style descriptions into one line, via LLM when possible.

    Empty input yields ""; a single style is returned as-is. Otherwise the
    last (up to) 10 styles are summarised by the LLM; on failure the first
    style is returned as a fallback.
    """
    cleaned = [entry.strip() for entry in style_list if entry.strip()]
    if not cleaned:
        return ""
    if len(cleaned) == 1:
        return cleaned[0]
    # Ask the LLM to compress the recent styles into a <=20-char summary
    prompt = (
        "请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
        "长度不超过20个字保留共同特点\n"
        f"{chr(10).join(f'- {s}' for s in cleaned[-10:])}\n只输出概括内容。"
    )
    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        print(f"Prompt:{prompt} Summary:{summary}")
        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 style 失败: {e}")
    # Fallback: keep the first style when summarisation fails
    return cleaned[0]
async def simulate_merge(
    expressions: List[Expression],
    similarity_threshold: float = 0.75,
    use_llm: bool = False,
    max_samples: int = 10,
) -> Dict:
    """
    Simulate the Expression merge process without writing to the database.

    Args:
        expressions: raw Expression records loaded from the database
        similarity_threshold: style similarity threshold for "similar" merges
        use_llm: whether to actually call the LLM for text summarisation
        max_samples: max number of Expressions to randomly sample
            (0 or None means no limit)

    Returns:
        Dict with overall counters, per-chat stats and per-merge details.
    """
    # Randomly sub-sample when there are too many records, to bound runtime
    if max_samples and len(expressions) > max_samples:
        expressions = random.sample(expressions, max_samples)
    # Group by chat_id — merging only ever happens within one chat
    expressions_by_chat = defaultdict(list)
    for expr in expressions:
        expressions_by_chat[expr.chat_id].append(expr)
    # Initialise the summarisation model only when LLM mode was requested
    summary_model = None
    if use_llm:
        try:
            summary_model = LLMRequest(
                model_set=model_config.model_task_config.utils_small,
                request_type="expression.summary"
            )
            print("✅ LLM 模型已初始化,将进行实际总结")
        except Exception as e:
            print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
            use_llm = False
    merge_stats = {
        "total_expressions": len(expressions),
        "total_chats": len(expressions_by_chat),
        "exact_matches": 0,
        "similar_matches": 0,
        "new_records": 0,
        "merge_details": [],
        "chat_stats": {},
        "use_llm": use_llm
    }
    # Simulate merging within each chat independently
    for chat_id, chat_expressions in expressions_by_chat.items():
        chat_name = get_chat_name(chat_id)
        chat_stat = {
            "chat_id": chat_id,
            "chat_name": chat_name,
            "total": len(chat_expressions),
            "exact_matches": 0,
            "similar_matches": 0,
            "new_records": 0,
            "merges": []
        }
        processed_ids = set()
        for expr in chat_expressions:
            if expr.id in processed_ids:
                continue
            target_style = expr.style
            target_situation = expr.situation
            # Tier 1: look for an exact style match
            exact_match = find_exact_style_match(
                chat_expressions,
                target_style,
                chat_id,
                {expr.id}
            )
            if exact_match:
                # Exact match — merged without LLM summarisation.
                # Simulate the post-merge content_list and style_list.
                target_content_list = parse_content_list(exact_match.content_list)
                target_content_list.append(target_situation)
                target_style_list = parse_style_list(exact_match.style_list)
                if exact_match.style and exact_match.style not in target_style_list:
                    target_style_list.append(exact_match.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)
                merge_info = {
                    "type": "exact",
                    "source_id": expr.id,
                    "target_id": exact_match.id,
                    "source_style": target_style,
                    "target_style": exact_match.style,
                    "source_situation": target_situation,
                    "target_situation": exact_match.situation,
                    "similarity": 1.0,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": exact_match.situation,  # exact match keeps the original situation
                    "merged_style": exact_match.style  # exact match keeps the original style
                }
                chat_stat["exact_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["exact_matches"] += 1
                processed_ids.add(expr.id)
                continue
            # Tier 2: look for a similar style match above the threshold
            similar_match = find_similar_style_expression(
                chat_expressions,
                target_style,
                chat_id,
                similarity_threshold,
                {expr.id}
            )
            if similar_match:
                match_expr, similarity = similar_match
                # Similar match — merged texts may be produced by the LLM.
                # Simulate the post-merge content_list and style_list.
                target_content_list = parse_content_list(match_expr.content_list)
                target_content_list.append(target_situation)
                target_style_list = parse_style_list(match_expr.style_list)
                if match_expr.style and match_expr.style not in target_style_list:
                    target_style_list.append(match_expr.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)
                # Summarise via LLM when enabled, otherwise keep the originals
                merged_situation = match_expr.situation
                merged_style = match_expr.style or target_style
                if use_llm and summary_model:
                    try:
                        merged_situation = await compose_situation_text(target_content_list, summary_model)
                        merged_style = await compose_style_text(target_style_list, summary_model)
                    except Exception as e:
                        print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
                        # Summarisation failed: fall back to simple joining
                        merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                        merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
                else:
                    # LLM disabled: simple concatenation instead of summarisation
                    merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                    merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
                merge_info = {
                    "type": "similar",
                    "source_id": expr.id,
                    "target_id": match_expr.id,
                    "source_style": target_style,
                    "target_style": match_expr.style,
                    "source_situation": target_situation,
                    "target_situation": match_expr.situation,
                    "similarity": similarity,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": merged_situation,
                    "merged_style": merged_style,
                    "llm_used": use_llm and summary_model is not None
                }
                chat_stat["similar_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["similar_matches"] += 1
                processed_ids.add(expr.id)
                continue
            # No match at all — this record would stay as a new record
            chat_stat["new_records"] += 1
            merge_stats["new_records"] += 1
            processed_ids.add(expr.id)
        merge_stats["chat_stats"][chat_id] = chat_stat
        merge_stats["merge_details"].extend(chat_stat["merges"])
    return merge_stats
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
    """Pretty-print the merge-simulation statistics produced by simulate_merge.

    Args:
        stats: result dict from simulate_merge
        show_details: whether to print per-merge details
        max_details: cap on the number of merge details printed
    """
    print("\n" + "=" * 80)
    print("Expression 合并模拟结果")
    print("=" * 80)
    # Overall counters
    print("\n📊 总体统计:")
    print(f" 总 Expression 数: {stats['total_expressions']}")
    print(f" 总聊天数: {stats['total_chats']}")
    print(f" 完全匹配合并: {stats['exact_matches']}")
    print(f" 相似匹配合并: {stats['similar_matches']}")
    print(f" 新记录(无匹配): {stats['new_records']}")
    if stats.get('use_llm'):
        print(" LLM 总结: 已启用")
    else:
        print(" LLM 总结: 未启用(仅模拟)")
    total_merges = stats['exact_matches'] + stats['similar_matches']
    if stats['total_expressions'] > 0:
        merge_ratio = (total_merges / stats['total_expressions']) * 100
        print(f" 合并比例: {merge_ratio:.1f}%")
    # Per-chat breakdown
    print("\n📋 按聊天分组统计:")
    for chat_id, chat_stat in stats['chat_stats'].items():
        print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
        print(f" 总数: {chat_stat['total']}")
        print(f" 完全匹配: {chat_stat['exact_matches']}")
        print(f" 相似匹配: {chat_stat['similar_matches']}")
        print(f" 新记录: {chat_stat['new_records']}")
    # Individual merge details, capped at max_details
    if show_details and stats['merge_details']:
        print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
        print()
        for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
            merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
            print(f" {idx}. {merge_type}")
            print(f" 源记录 ID: {merge['source_id']}")
            print(f" 目标记录 ID: {merge['target_id']}")
            print(f" 源 Style: {merge['source_style'][:50]}")
            print(f" 目标 Style: {merge['target_style'][:50]}")
            print(f" 源 Situation: {merge['source_situation'][:50]}")
            print(f" 目标 Situation: {merge['target_situation'][:50]}")
            # Show what the merged record would look like
            if 'merged_situation' in merge:
                print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
            if 'merged_style' in merge:
                print(f" → 合并后 Style: {merge['merged_style'][:50]}")
            if merge.get('llm_used'):
                print(" → LLM 总结: 已使用")
            elif merge['type'] == 'similar':
                print(" → LLM 总结: 未使用(模拟模式)")
            # Show the merged lists (truncated to three entries each)
            if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
                print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
                if len(merge['merged_content_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_content_list']) - 3} 项")
            if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
                print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
                if len(merge['merged_style_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_style_list']) - 3} 项")
            print()
        if len(stats['merge_details']) > max_details:
            print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
def main():
    """CLI entry point for the Expression merge simulation."""
    parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
    parser.add_argument(
        "--chat-id",
        type=str,
        default=None,
        help="指定要分析的 chat_id不指定则分析所有"
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=0.75,
        help="相似度阈值 (0-1, 默认: 0.75)"
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计"
    )
    parser.add_argument(
        "--max-details",
        type=int,
        default=50,
        help="最多显示的合并详情数 (默认: 50)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)"
    )
    parser.add_argument(
        "--use-llm",
        action="store_true",
        help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=10,
        help="最多随机抽取的 Expression 数量 (默认: 10设置为 0 表示不限制)"
    )
    args = parser.parse_args()
    # Validate the similarity threshold range
    if not 0 <= args.similarity_threshold <= 1:
        print("错误: similarity-threshold 必须在 0-1 之间")
        return
    # Determine the output file path (timestamped default under data/temp)
    if args.output:
        output_file = args.output
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
    # Load Expression records, optionally filtered to a single chat
    print("正在从数据库加载Expression数据...")
    try:
        if args.chat_id:
            expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
        else:
            expressions = list(Expression.select())
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return
    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return
    # Run the (async) merge simulation
    print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples}...")
    if args.use_llm:
        print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
    else:
        print(" 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
    stats = asyncio.run(
        simulate_merge(
            expressions,
            similarity_threshold=args.similarity_threshold,
            use_llm=args.use_llm,
            max_samples=args.max_samples,
        )
    )
    # Write the report to the file first (by redirecting stdout), then echo
    # the same report to the console
    original_stdout = sys.stdout
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            sys.stdout = f
            print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
        sys.stdout = original_stdout
        print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
    except Exception as e:
        # Restore stdout before reporting the failure
        sys.stdout = original_stdout
        print(f"❌ 写入文件失败: {e}")
        return
    print(f"\n✅ 模拟结果已保存到: {output_file}")
if __name__ == "__main__":
    main()

View File

@ -1,342 +0,0 @@
import sys
import os
from datetime import datetime
from typing import List, Tuple

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np

# Add project root to Python path BEFORE importing any src.* module.
# The original file imported src.common.database.database_model first,
# which raises ImportError whenever the script is launched from outside
# the repository root.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression, ChatStreams  # noqa: E402

# Use fonts with CJK coverage so Chinese axis labels render correctly
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
def get_chat_name(chat_id: str) -> str:
    """Look up a display name for chat_id directly from the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id})"
        # Group name takes precedence, then the private-chat nickname
        if stream.group_name:
            return f"{stream.group_name} ({chat_id})"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊 ({chat_id})"
        return f"未知聊天 ({chat_id})"
    except Exception:
        # Any lookup failure degrades to a placeholder containing the raw id
        return f"查询失败 ({chat_id})"
def get_expression_data() -> List[Tuple[float, float, str, str]]:
    """Load (create_date, count, chat_id, type) tuples from the Expression table.

    Rows whose create_date is None are skipped, since they cannot be plotted.
    """
    rows: List[Tuple[float, float, str, str]] = []
    for record in Expression.select():
        if record.create_date is None:
            continue
        rows.append((record.create_date, record.count, record.chat_id, record.type))
    return rows
def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
    """Draw a scatter plot of expression usage count over creation time.

    Args:
        data: (create_date epoch, count, chat_id, type) tuples
        save_path: optional PNG path; when given, the figure is also saved
    """
    if not data:
        print("没有找到有效的表达式数据")
        return
    # Split the tuples into parallel columns (chat_ids/types unused here)
    create_dates = [item[0] for item in data]
    counts = [item[1] for item in data]
    _chat_ids = [item[2] for item in data]
    _expression_types = [item[3] for item in data]
    # Convert epoch timestamps to datetime objects
    dates = [datetime.fromtimestamp(ts) for ts in create_dates]
    # Pick axis tick granularity based on the overall time span
    time_span = max(dates) - min(dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)
    # Build the figure; points are colour-coded by insertion order
    fig, ax = plt.subplots(figsize=(12, 8))
    scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")
    # Apply the span-dependent date formatting chosen above
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)
    ax.grid(True, alpha=0.3)
    cbar = plt.colorbar(scatter)
    cbar.set_label("数据点顺序", fontsize=10)
    plt.tight_layout()
    # Print summary statistics to the console
    print("\n=== 数据统计 ===")
    print(f"总数据点数量: {len(data)}")
    print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
    print(f"平均使用次数: {np.mean(counts):.2f}")
    print(f"中位数使用次数: {np.median(counts):.2f}")
    # Save to disk when a path was provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n散点图已保存到: {save_path}")
    plt.show()
def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
    """Draw the usage-count scatter plot with one colour per chat.

    Args:
        data: (create_date epoch, count, chat_id, type) tuples
        save_path: optional PNG path; when given, the figure is also saved
    """
    if not data:
        print("没有找到有效的表达式数据")
        return
    # Bucket the tuples by chat_id
    chat_groups = {}
    for item in data:
        chat_id = item[2]
        if chat_id not in chat_groups:
            chat_groups[chat_id] = []
        chat_groups[chat_id].append(item)
    # Pick axis tick granularity based on the overall time span
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)
    fig, ax = plt.subplots(figsize=(14, 10))
    # Assign each chat a distinct colour from the Set3 palette
    colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))
    for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
        create_dates = [item[0] for item in chat_data]
        counts = [item[1] for item in chat_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]
        chat_name = get_chat_name(chat_id)
        # Truncate overly long chat names for the legend
        display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name
        ax.scatter(
            dates,
            counts,
            alpha=0.7,
            s=40,
            c=[colors[i]],
            label=f"{display_name} ({len(chat_data)}个)",
            edgecolors="black",
            linewidth=0.5,
        )
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")
    # Apply the span-dependent date formatting chosen above
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # Print per-chat summary statistics to the console
    print("\n=== 分组统计 ===")
    print(f"总聊天数量: {len(chat_groups)}")
    for chat_id, chat_data in chat_groups.items():
        chat_name = get_chat_name(chat_id)
        counts = [item[1] for item in chat_data]
        print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
    # Save to disk when a path was provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n分组散点图已保存到: {save_path}")
    plt.show()
def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
    """Draw the usage-count scatter plot with one colour per expression type.

    Args:
        data: (create_date epoch, count, chat_id, type) tuples
        save_path: optional PNG path; when given, the figure is also saved
    """
    if not data:
        print("没有找到有效的表达式数据")
        return
    # Bucket the tuples by expression type
    type_groups = {}
    for item in data:
        expr_type = item[3]
        if expr_type not in type_groups:
            type_groups[expr_type] = []
        type_groups[expr_type].append(item)
    # Pick axis tick granularity based on the overall time span
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # more than 30 days: monthly major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: daily major ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hourly ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)
    fig, ax = plt.subplots(figsize=(12, 8))
    # Assign each type a distinct colour from the tab10 palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
    for i, (expr_type, type_data) in enumerate(type_groups.items()):
        create_dates = [item[0] for item in type_data]
        counts = [item[1] for item in type_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]
        ax.scatter(
            dates,
            counts,
            alpha=0.7,
            s=40,
            c=[colors[i]],
            label=f"{expr_type} ({len(type_data)}个)",
            edgecolors="black",
            linewidth=0.5,
        )
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
    # Apply the span-dependent date formatting chosen above
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # Print per-type summary statistics to the console
    print("\n=== 类型统计 ===")
    for expr_type, type_data in type_groups.items():
        counts = [item[1] for item in type_data]
        print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
    # Save to disk when a path was provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n类型散点图已保存到: {save_path}")
    plt.show()
def main():
    """Load expression data and generate all three scatter-plot variants."""
    print("开始分析表达式数据...")
    # Fetch plottable rows (rows without create_date are skipped upstream)
    data = get_expression_data()
    if not data:
        print("没有找到有效的表达式数据create_date不为空的数据")
        return
    print(f"找到 {len(data)} 条有效数据")
    # Ensure the output directory exists
    output_dir = os.path.join(project_root, "data", "temp")
    os.makedirs(output_dir, exist_ok=True)
    # Timestamp used in every output filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # 1. Basic scatter plot
    print("\n1. 创建基础散点图...")
    create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
    # 2. Scatter plot grouped by chat
    print("\n2. 创建按聊天分组的散点图...")
    create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
    # 3. Scatter plot grouped by expression type
    print("\n3. 创建按类型分组的散点图...")
    create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
    print("\n分析完成!")
if __name__ == "__main__":
    main()

View File

@ -1,559 +0,0 @@
"""
分析expression库中situation和style的相似度
用法:
python scripts/expression_similarity_analysis.py
或指定阈值:
python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
"""
import sys
import os
import argparse
from typing import List, Tuple
from collections import defaultdict
from difflib import SequenceMatcher
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.config.config import global_config # noqa: E402
from src.chat.message_receive.chat_stream import get_chat_manager # noqa: E402
class TeeOutput:
    """Write-through stream that mirrors everything to stdout and a file.

    Intended to temporarily replace sys.stdout; supports use as a context
    manager, closing only the file on exit (stdout is left untouched).
    """
    def __init__(self, file_path: str):
        # Keep a handle to the real stdout so output still reaches the console
        self.file = open(file_path, "w", encoding="utf-8")
        self.console = sys.stdout

    def write(self, text: str):
        """Write text to both the console and the file."""
        for stream in (self.console, self.file):
            stream.write(text)
        # Flush the file immediately so it stays current even on crashes
        self.file.flush()

    def flush(self):
        """Flush both destinations."""
        for stream in (self.console, self.file):
            stream.flush()

    def close(self):
        """Close the underlying file, if still open."""
        if self.file:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""
解析'platform:id:type'为chat_id直接复用 ChatManager 的逻辑
"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
except Exception:
return None
def build_chat_id_groups() -> dict[str, set[str]]:
    """
    Build a mapping from each configured chat_id to the set of chat_ids it
    shares expressions with, based on the expression_groups config.

    Returns:
        dict: {chat_id: set of related chat_ids (always including itself)}
    """
    groups = global_config.expression.expression_groups
    mapping: dict[str, set[str]] = {}
    if any("*" in group for group in groups):
        # A global sharing group exists: every chat_id mentioned anywhere in
        # the config is related to every other configured chat_id.
        everyone: set[str] = set()
        for group in groups:
            for entry in group:
                if entry == "*":
                    continue
                if resolved := _parse_stream_config_to_chat_id(entry):
                    everyone.add(resolved)
        for chat_id in everyone:
            mapping[chat_id] = everyone.copy()
    else:
        # Regular groups: members of the same group are mutually related.
        for group in groups:
            members: set[str] = set()
            for entry in group:
                if resolved := _parse_stream_config_to_chat_id(entry):
                    members.add(resolved)
            for chat_id in members:
                mapping.setdefault(chat_id, set()).update(members)
    # Every chat_id is related at least to itself.
    for chat_id in mapping:
        mapping[chat_id].add(chat_id)
    return mapping
def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
    """
    Decide whether two chat_ids are related (identical, or in the same group).

    Args:
        chat_id1: first chat_id
        chat_id2: second chat_id
        chat_id_groups: mapping from chat_id to its set of related chat_ids

    Returns:
        bool: True when the ids are equal or chat_id2 belongs to chat_id1's group.
    """
    if chat_id1 == chat_id2:
        return True
    related = chat_id_groups.get(chat_id1)
    if related is not None:
        return chat_id2 in related
    # chat_id1 is in no group, so it is only related to itself
    return False
def get_chat_name(chat_id: str) -> str:
    """Return a display name for chat_id, looked up from the ChatStreams table."""
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"
        # Group name first, then fall back to the private-chat nickname
        if stream.group_name:
            return f"{stream.group_name}"
        if stream.user_nickname:
            return f"{stream.user_nickname}的私聊"
        return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        # Any lookup failure degrades to a truncated-id placeholder
        return f"查询失败 ({chat_id[:8]}...)"
def text_similarity(text1: str, text2: str) -> float:
    """
    Score the similarity of two texts in [0, 1] via SequenceMatcher.

    The filler words "使用" and "句式" are stripped before comparison; if either
    input is empty (before or after stripping), the score is 0.0.
    """
    if not text1 or not text2:
        return 0.0

    def _strip_fillers(text: str) -> str:
        # Remove the two ignored words, then trim surrounding whitespace
        return text.replace("使用", "").replace("句式", "").strip()

    left = _strip_fillers(text1)
    right = _strip_fillers(text2)
    if not left or not right:
        return 0.0
    return SequenceMatcher(None, left, right).ratio()
def find_similar_pairs(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    max_pairs: int | None = None
) -> List[Tuple[int, int, float, str, str]]:
    """Find pairs of expressions whose *field_name* texts are similar.

    Args:
        expressions: list of Expression objects
        field_name: attribute to compare ('situation' or 'style')
        threshold: similarity threshold in [0, 1]
        max_pairs: maximum number of pairs to return; None means no limit

    Returns:
        List of (index1, index2, similarity, text1, text2) tuples,
        sorted by similarity in descending order.
    """
    similar_pairs = []
    n = len(expressions)
    print(f"正在分析 {field_name} 字段的相似度...")
    print(f"总共需要比较 {n * (n - 1) // 2} 对...")
    for i in range(n):
        if (i + 1) % 100 == 0:
            print(f" 已处理 {i + 1}/{n} 个项目...")
        # Hoist the i-side text out of the inner loop.
        text1 = getattr(expressions[i], field_name, "")
        for j in range(i + 1, n):
            text2 = getattr(expressions[j], field_name, "")
            similarity = text_similarity(text1, text2)
            if similarity >= threshold:
                similar_pairs.append((i, j, similarity, text1, text2))
    # Sort by similarity, highest first.
    similar_pairs.sort(key=lambda x: x[2], reverse=True)
    # Bug fix: the old `if max_pairs:` treated max_pairs=0 as "no limit";
    # only None should mean unlimited (matches the docstring).
    if max_pairs is not None:
        similar_pairs = similar_pairs[:max_pairs]
    return similar_pairs
def group_similar_items(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    chat_id_groups: dict[str, set[str]]
) -> List[List[int]]:
    """Group similar expressions; only items with the same or related chat_id are compared.

    Args:
        expressions: list of Expression objects
        field_name: attribute to compare ('situation' or 'style')
        threshold: similarity threshold (0-1)
        chat_id_groups: mapping from chat_id to its set of related chat_ids

    Returns:
        List of groups, each group is a list of indices into *expressions*
        (only groups with more than one member, largest first).
    """
    n = len(expressions)
    # Union-find over indices: pairs above the threshold get merged into one set.
    parent = list(range(n))

    def find(x):
        # Find with path compression.
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    print(f"正在对 {field_name} 字段进行分组仅比较相同chat_id或同组的项目...")
    # First O(n^2) pass only counts comparable pairs for the progress message below.
    total_pairs = 0
    for i in range(n):
        for j in range(i + 1, n):
            if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
                total_pairs += 1
    print(f"总共需要比较 {total_pairs}已过滤不同chat_id且不同组的项目...")
    compared_pairs = 0  # counted for potential debugging; not reported afterwards
    for i in range(n):
        if (i + 1) % 100 == 0:
            print(f" 已处理 {i + 1}/{n} 个项目...")
        expr1 = expressions[i]
        text1 = getattr(expr1, field_name, "")
        for j in range(i + 1, n):
            expr2 = expressions[j]
            # Skip pairs whose chats are unrelated (different chat and different group).
            if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
                continue
            compared_pairs += 1
            text2 = getattr(expr2, field_name, "")
            similarity = text_similarity(text1, text2)
            if similarity >= threshold:
                union(i, j)
    # Collect members per union-find root.
    groups = defaultdict(list)
    for i in range(n):
        root = find(i)
        groups[root].append(i)
    # Keep only non-trivial groups, biggest first.
    result = [group for group in groups.values() if len(group) > 1]
    result.sort(key=len, reverse=True)
    return result
def print_similarity_analysis(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    chat_id_groups: dict[str, set[str]],
    show_details: bool = True,
    max_groups: int = 20
):
    """Print a similarity-analysis report for one field (stats, then per-group details)."""
    print("\n" + "=" * 80)
    print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
    print("=" * 80)
    # Cluster similar expressions first; everything below is reporting.
    groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)
    total_items = len(expressions)
    similar_items_count = sum(len(group) for group in groups)
    unique_groups = len(groups)
    print("\n📊 统计信息:")
    print(f" 总项目数: {total_items}")
    print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
    print(f" 相似组数: {unique_groups}")
    print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0")
    if not groups:
        print(f"\n未找到相似度 >= {threshold} 的项目组")
        return
    print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
    print()
    for group_idx, group in enumerate(groups[:max_groups], 1):
        print(f"组 {group_idx} (共 {len(group)} 个项目):")
        if show_details:
            # List every member of the group with its chat name and usage count.
            for idx in group:
                expr = expressions[idx]
                text = getattr(expr, field_name, "")
                chat_name = get_chat_name(expr.chat_id)
                # Truncate long texts for display.
                display_text = text[:60] + "..." if len(text) > 60 else text
                print(f" [{expr.id}] {display_text}")
                print(f" 聊天: {chat_name}, Count: {expr.count}")
            # Re-compute pairwise similarities within the group to show why it formed.
            if len(group) > 1:
                similarities = []
                above_threshold_pairs = []  # pairs that directly exceeded the threshold
                above_threshold_count = 0
                for i in range(len(group)):
                    for j in range(i + 1, len(group)):
                        text1 = getattr(expressions[group[i]], field_name, "")
                        text2 = getattr(expressions[group[j]], field_name, "")
                        sim = text_similarity(text1, text2)
                        similarities.append(sim)
                        if sim >= threshold:
                            above_threshold_count += 1
                            # Keep the pair for the "directly similar" listing below.
                            expr1 = expressions[group[i]]
                            expr2 = expressions[group[j]]
                            display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
                            display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
                            above_threshold_pairs.append((
                                expr1.id, display_text1,
                                expr2.id, display_text2,
                                sim
                            ))
                if similarities:
                    avg_sim = sum(similarities) / len(similarities)
                    min_sim = min(similarities)
                    max_sim = max(similarities)
                    above_threshold_ratio = above_threshold_count / len(similarities) * 100
                    print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
                    print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")
                    # The directly-similar pairs explain the union-find merges.
                    if above_threshold_pairs:
                        print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
                        # Highest similarity first; cap the listing at 10 pairs.
                        above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
                        for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]:
                            print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}")
                            print(f" \"{text1}\"\"{text2}\"")
                        if len(above_threshold_pairs) > 10:
                            print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
                    else:
                        # No direct pair passed the threshold: the group only exists via transitivity.
                        print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
        else:
            # Summary mode: show one sample member and the remaining count.
            expr = expressions[group[0]]
            text = getattr(expr, field_name, "")
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f" 示例: {display_text}")
            print(f" ... 还有 {len(group) - 1} 个相似项目")
        print()
    if len(groups) > max_groups:
        print(f"... 还有 {len(groups) - max_groups} 组未显示")
def main():
    """CLI entry point: parse arguments, tee stdout to a report file, run the analysis."""
    parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
    parser.add_argument(
        "--situation-threshold",
        type=float,
        default=0.7,
        help="situation相似度阈值 (0-1, 默认: 0.7)"
    )
    parser.add_argument(
        "--style-threshold",
        type=float,
        default=0.7,
        help="style相似度阈值 (0-1, 默认: 0.7)"
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计"
    )
    parser.add_argument(
        "--max-groups",
        type=int,
        default=20,
        help="最多显示的组数 (默认: 20)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)"
    )
    args = parser.parse_args()
    # Reject thresholds outside [0, 1].
    if not 0 <= args.situation_threshold <= 1:
        print("错误: situation-threshold 必须在 0-1 之间")
        return
    if not 0 <= args.style_threshold <= 1:
        print("错误: style-threshold 必须在 0-1 之间")
        return
    # Resolve the output file path.
    if args.output:
        output_file = args.output
    else:
        # Default: timestamped file under data/temp.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")
    # TeeOutput mirrors everything printed to both the console and the file.
    with TeeOutput(output_file) as tee:
        # Temporarily redirect sys.stdout through the tee.
        original_stdout = sys.stdout
        sys.stdout = tee
        try:
            print("=" * 80)
            print("Expression 相似度分析工具")
            print("=" * 80)
            print(f"输出文件: {output_file}")
            print()
            _run_analysis(args)
        finally:
            # Always restore the original stdout.
            sys.stdout = original_stdout
    print(f"\n✅ 分析结果已保存到: {output_file}")
def _run_analysis(args):
    """Core analysis: load expressions, build the chat-id group map, analyze both fields."""
    # Load every Expression record from the database.
    print("正在从数据库加载Expression数据...")
    try:
        expressions = list(Expression.select())
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return
    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return
    print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    print()
    # Build the chat_id relation map from the expression_groups configuration.
    print("正在构建chat_id分组映射根据expression_groups配置...")
    try:
        chat_id_groups = build_chat_id_groups()
        print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
        if chat_id_groups:
            # Summarize how interconnected the chats are.
            total_related = sum(len(related) for related in chat_id_groups.values())
            avg_related = total_related / len(chat_id_groups)
            print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关包括自身")
        print()
    except Exception as e:
        # Fall back to an empty map: only identical chat_ids will compare.
        print(f"⚠️ 构建chat_id分组映射失败: {e}")
        print(" 将使用默认行为只比较相同chat_id的项目")
        chat_id_groups = {}
    # Analyze the 'situation' field.
    print_similarity_analysis(
        expressions,
        "situation",
        args.situation_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups
    )
    # Analyze the 'style' field.
    print_similarity_analysis(
        expressions,
        "style",
        args.style_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups
    )
    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)
# Script entry point.
if __name__ == "__main__":
    main()

View File

@ -1,196 +0,0 @@
import os
import sys
import time
from typing import Dict, List

# Add the project root to sys.path BEFORE importing any `src.*` module —
# the original imported the database model first, which fails whenever the
# script is launched from outside the repository root.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
def get_chat_name(chat_id: str) -> str:
    """Resolve a display name for *chat_id* by querying the ChatStreams table directly."""
    unknown = f"未知聊天 ({chat_id})"
    try:
        # Look the stream up straight in the database.
        record = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if record is None:
            return unknown
        # Group chats show the group name; private chats show the user nickname.
        if record.group_name:
            return f"{record.group_name} ({chat_id})"
        if record.user_nickname:
            return f"{record.user_nickname}的私聊 ({chat_id})"
        return unknown
    except Exception:
        return f"查询失败 ({chat_id})"
def calculate_time_distribution(expressions) -> Dict[str, int]:
    """Bucket expressions by how many days ago they were last active."""
    # (upper bound in days, bucket label) — order defines the output key order.
    buckets = [
        (1, "0-1天"),
        (3, "1-3天"),
        (7, "3-7天"),
        (14, "7-14天"),
        (30, "14-30天"),
        (60, "30-60天"),
        (90, "60-90天"),
    ]
    overflow_label = "90+天"
    distribution = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0
    now = time.time()
    for expr in expressions:
        diff_days = (now - expr.last_active_time) / (24 * 3600)
        for upper, label in buckets:
            if diff_days < upper:
                distribution[label] += 1
                break
        else:
            # Older than every bucket boundary.
            distribution[overflow_label] += 1
    return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]:
    """Bucket expressions by their usage count."""
    # (exclusive upper bound, bucket label) — order defines the output key order.
    bounds = [(1, "0-1"), (2, "1-2"), (3, "2-3"), (4, "3-4"), (5, "4-5"), (10, "5-10")]
    overflow_label = "10+"
    distribution = {label: 0 for _, label in bounds}
    distribution[overflow_label] = 0
    for expr in expressions:
        for upper, label in bounds:
            if expr.count < upper:
                distribution[label] += 1
                break
        else:
            # Count of 10 or more.
            distribution[overflow_label] += 1
    return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
    """Get top N most used expressions for a specific chat_id"""
    # Peewee query: filter by chat, order by usage count descending, cap at top_n.
    # NOTE(review): this returns a lazily-evaluated peewee ModelSelect, not a list —
    # the List[Expression] annotation is aspirational; callers only iterate it.
    return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
def show_overall_statistics(expressions, total: int) -> None:
    """Print aggregate statistics (time and count distributions) over all expressions."""
    print("\n=== 总体统计 ===")
    print(f"总表达式数量: {total}")

    def _dump_distribution(header: str, dist: Dict[str, int]) -> None:
        # One section: header, then every bucket with its percentage of *total*.
        print(header)
        for label, bucket_count in dist.items():
            print(f"{label}: {bucket_count} ({bucket_count / total * 100:.2f}%)")

    _dump_distribution("\n上次激活时间分布:", calculate_time_distribution(expressions))
    _dump_distribution("\ncount分布:", calculate_count_distribution(expressions))
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
    """Print per-chat statistics: distributions plus the 10 most used expressions."""
    chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
    chat_total = len(chat_exprs)
    print(f"\n=== {chat_name} ===")
    print(f"表达式数量: {chat_total}")
    if chat_total == 0:
        print("该聊天没有表达式数据")
        return
    # Last-active-time distribution for this chat (empty buckets are skipped).
    time_dist = calculate_time_distribution(chat_exprs)
    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        if count > 0:
            print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
    # Usage-count distribution for this chat (empty buckets are skipped).
    count_dist = calculate_count_distribution(chat_exprs)
    print("\ncount分布:")
    for range_, count in count_dist.items():
        if count > 0:
            print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
    # Most frequently used expressions of this chat.
    print("\nTop 10使用最多的表达式:")
    top_exprs = get_top_expressions_by_chat(chat_id, 10)
    for i, expr in enumerate(top_exprs, 1):
        print(f"{i}. [{expr.type}] Count: {expr.count}")
        print(f" Situation: {expr.situation}")
        print(f" Style: {expr.style}")
        print()
def interactive_menu() -> None:
    """Interactive console menu: overall stats, per-chat stats, or quit."""
    # Load everything once; menu entries are derived from this snapshot.
    expressions = list(Expression.select())
    if not expressions:
        print("数据库中没有找到表达式")
        return
    total = len(expressions)
    # Unique chat ids paired with their resolved display names, sorted by name.
    chat_ids = list(set(expr.chat_id for expr in expressions))
    chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
    chat_info.sort(key=lambda x: x[1])  # Sort by chat name
    while True:
        print("\n" + "=" * 50)
        print("表达式统计分析")
        print("=" * 50)
        print("0. 显示总体统计")
        for i, (chat_id, chat_name) in enumerate(chat_info, 1):
            chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
            print(f"{i}. {chat_name} ({chat_count}个表达式)")
        print("q. 退出")
        choice = input("\n请选择要查看的统计 (输入序号): ").strip()
        if choice.lower() == "q":
            print("再见!")
            break
        try:
            choice_num = int(choice)
            if choice_num == 0:
                show_overall_statistics(expressions, total)
            elif 1 <= choice_num <= len(chat_info):
                chat_id, chat_name = chat_info[choice_num - 1]
                show_chat_statistics(chat_id, chat_name)
            else:
                print("无效的选择,请重新输入")
        except ValueError:
            print("请输入有效的数字")
        input("\n按回车键继续...")
# Script entry point.
if __name__ == "__main__":
    interactive_menu()

View File

@ -470,7 +470,7 @@ def _run_embedding_helper() -> None:
test_path.rename(archive_path)
except Exception as exc: # pragma: no cover - 防御性兜底
logger.error("归档 embedding_model_test.json 失败: %s", exc)
print(f"[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
print("[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
return
print(

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,447 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from typing import Optional, Dict, Any
from datetime import datetime
# Force UTF-8 on stdout/stderr to avoid console encoding errors (best effort).
try:
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8")
except Exception:
    # Non-critical: continue with the platform default encoding.
    pass
# Make `src.*` importable when this script is run from the scripts/ directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
    """Dynamically import src.memory_system.memory_retrieval via importlib.

    Importing that module at the top of this file triggers circular imports,
    so this helper loads it lazily and defensively, returning the three
    callables the test harness needs.

    Returns:
        Tuple of (init_memory_retrieval_prompt, _react_agent_solve_question,
        _process_single_question).

    Raises:
        ImportError / AttributeError when the module cannot be loaded.
    """
    try:
        # Check whether the retrieval prompts were already registered; if so,
        # the module was probably initialized through another import path.
        from src.chat.utils.prompt_builder import global_prompt_manager
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
        module_name = "src.memory_system.memory_retrieval"
        # Fast path: prompts registered and module fully loaded — reuse it.
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, 'init_memory_retrieval_prompt'):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                    existing_module._process_single_question,
                )
        # A partially-initialized module (circular-import casualty) must be
        # evicted from sys.modules before we can import it cleanly again.
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Also evict sibling memory_system submodules that may be half-loaded.
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith('src.memory_system.') and key != 'src.memory_system':
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass
        # Pre-load every module that could trigger the circular import, so that
        # they are fully initialized before memory_retrieval itself is imported.
        try:
            import src.config.config
            import src.chat.utils.prompt_builder
            # These repliers may import memory_retrieval at module level;
            # importing them first lets them finish initializing (failures tolerated).
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")
        # Finally import memory_retrieval itself; if this still raises a
        # circular-import error, some other module imports it at module level.
        memory_retrieval_module = importlib.import_module(module_name)
        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
    """Build a throwaway ChatStream (test platform, user, and group) for retrieval tests."""
    return ChatStream(
        stream_id=chat_id,
        platform="test",
        user_info=UserInfo(
            platform="test",
            user_id="test_user",
            user_nickname="测试用户",
        ),
        group_info=GroupInfo(
            platform="test",
            group_id="test_group",
            group_name="测试群组",
        ),
    )
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Aggregate memory-related LLM token usage recorded since *start_time*.

    Args:
        start_time: starting UNIX timestamp

    Returns:
        Dict with total prompt/completion/total tokens, total cost, request
        count, and a per-model breakdown; zeroed-out totals on failure.
    """
    try:
        start_datetime = datetime.fromtimestamp(start_time)
        # All memory-related usage rows since start_time, oldest first.
        records = (
            LLMUsage.select()
            .where(
                (LLMUsage.timestamp >= start_datetime)
                & (
                    (LLMUsage.request_type.like("%memory%"))
                    | (LLMUsage.request_type == "memory.question")
                    | (LLMUsage.request_type == "memory.react")
                    | (LLMUsage.request_type == "memory.react.final")
                )
            )
            .order_by(LLMUsage.timestamp.asc())
        )
        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0
        total_cost = 0.0
        request_count = 0
        model_usage = {}  # per-model accumulators
        for record in records:
            # Nullable columns default to zero.
            total_prompt_tokens += record.prompt_tokens or 0
            total_completion_tokens += record.completion_tokens or 0
            total_tokens += record.total_tokens or 0
            total_cost += record.cost or 0.0
            request_count += 1
            # Accumulate the same figures per model name.
            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                }
            model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
            model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
            model_usage[model_name]["total_tokens"] += record.total_tokens or 0
            model_usage[model_name]["cost"] += record.cost or 0.0
            model_usage[model_name]["request_count"] += 1
        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "request_count": request_count,
            "model_usage": model_usage,
        }
    except Exception as e:
        # Reporting must not break the test run; return empty totals.
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }
def format_thinking_steps(thinking_steps: list) -> str:
    """Format ReAct thinking steps into a human-readable multi-line string.

    Args:
        thinking_steps: list of step dicts with optional keys
            "iteration", "thought", "actions", "observations".

    Returns:
        A readable report, or "无思考步骤" when the list is empty.
    """
    if not thinking_steps:
        return "无思考步骤"

    def _truncate(text: str, limit: int = 200) -> str:
        # Append the ellipsis only when content was actually cut off.
        return text[:limit] + "..." if len(text) > limit else text

    lines = []
    for step in thinking_steps:
        iteration = step.get("iteration", "?")
        thought = step.get("thought", "")
        actions = step.get("actions", [])
        observations = step.get("observations", [])
        lines.append(f"\n--- 迭代 {iteration} ---")
        if thought:
            # Bug fix: the old code appended "..." unconditionally, even for
            # thoughts shorter than 200 chars (inconsistent with observations).
            lines.append(f"思考: {_truncate(thought)}")
        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
        if observations:
            lines.append("观察:")
            for obs in observations:
                lines.append(f" - {_truncate(str(obs))}")
    return "\n".join(lines)
async def test_memory_retrieval(
    question: str,
    chat_id: str = "test_memory_retrieval",
    context: str = "",
    max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
    """Run one memory-retrieval test and report iterations, timing and token usage.

    Args:
        question: the question to retrieve an answer for
        chat_id: chat id to run the retrieval under
        context: extra context — NOTE(review): currently accepted but never
            forwarded to the retrieval call; confirm whether it should be
            passed as initial_info.
        max_iterations: maximum ReAct iterations (None = config default)

    Returns:
        Dict with the question, answer, timeout flag, elapsed time,
        thinking steps, iteration count, and token usage stats.
    """
    print("\n" + "=" * 80)
    print(f"[测试] 记忆检索测试")
    print(f"[问题] {question}")
    print("=" * 80)
    # Mark the start so token usage can be aggregated from this point on.
    start_time = time.time()
    # Lazy import + prompt initialization (loads global_config as a side effect).
    # Must happen inside the function to avoid module-level circular imports.
    try:
        init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
        # Initialize the retrieval prompts only once per process.
        from src.chat.utils.prompt_builder import global_prompt_manager
        if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
            init_memory_retrieval_prompt()
        else:
            logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
    except Exception as e:
        logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
        raise
    # global_config is guaranteed to be loaded after the import above.
    from src.config.config import global_config
    # Call _react_agent_solve_question directly to get detailed iteration info.
    if max_iterations is None:
        max_iterations = global_config.memory.max_agent_iterations
    timeout = global_config.memory.agent_timeout_seconds
    print(f"\n[配置]")
    print(f" 最大迭代次数: {max_iterations}")
    print(f" 超时时间: {timeout}秒")
    print(f" 聊天ID: {chat_id}")
    # Run the retrieval agent.
    print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=timeout,
        initial_info="",
    )
    # Wall-clock duration of the retrieval.
    end_time = time.time()
    elapsed_time = end_time - start_time
    # Aggregate memory-related token usage recorded during the run.
    token_usage = get_token_usage_since(start_time)
    # Assemble the result payload (also returned for optional JSON export).
    result = {
        "question": question,
        "found_answer": found_answer,
        "answer": answer,
        "is_timeout": is_timeout,
        "elapsed_time": elapsed_time,
        "thinking_steps": thinking_steps,
        "iteration_count": len(thinking_steps),
        "token_usage": token_usage,
    }
    # Human-readable report.
    print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print(f"\n[结果]")
    print(f" 是否找到答案: {'是' if found_answer else '否'}")
    if found_answer and answer:
        print(f" 答案: {answer}")
    else:
        print(f" 答案: (未找到答案)")
    print(f" 是否超时: {'是' if is_timeout else '否'}")
    print(f" 迭代次数: {len(thinking_steps)}")
    print(f" 总耗时: {elapsed_time:.2f}秒")
    print(f"\n[Token使用情况]")
    print(f" 总请求数: {token_usage['request_count']}")
    print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
    print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
    print(f" 总Tokens: {token_usage['total_tokens']:,}")
    print(f" 总成本: ${token_usage['total_cost']:.6f}")
    if token_usage['model_usage']:
        print(f"\n[按模型统计]")
        for model_name, usage in token_usage['model_usage'].items():
            print(f" {model_name}:")
            print(f" 请求数: {usage['request_count']}")
            print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
            print(f" Completion Tokens: {usage['completion_tokens']:,}")
            print(f" 总Tokens: {usage['total_tokens']:,}")
            print(f" 成本: ${usage['cost']:.6f}")
    print(f"\n[迭代详情]")
    print(format_thinking_steps(thinking_steps))
    print("\n" + "=" * 80)
    return result
def main() -> None:
    """CLI entry point: collect inputs interactively, run the retrieval test, optionally save JSON."""
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题脚本会使用记忆检索的逻辑进行检索并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID默认: test_memory_retrieval",
    )
    parser.add_argument(
        "--context",
        default="",
        help="上下文信息(可选)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件可选",
    )
    args = parser.parse_args()
    # Keep logging quiet so the interactive output stays readable.
    initialize_logging(verbose=False)
    # Ask for the question interactively.
    print("\n" + "=" * 80)
    print("记忆检索测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return
    # Ask for the max iteration count; blank or invalid falls back to config default.
    max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0将使用配置默认值")
                max_iterations = None
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
            max_iterations = None
    # Open the database connection before running the async test.
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return
    # Run the retrieval test.
    try:
        result = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )
        # Optionally persist the full result as JSON.
        if args.output:
            # NOTE(review): thinking_steps is assumed to be JSON-serializable here.
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Always release the database connection.
        try:
            db.close()
        except Exception:
            pass
# Script entry point.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,252 @@
"""
表达方式自动检查定时任务
功能
1. 定期随机选取指定数量的表达方式
2. 使用LLM进行评估
3. 通过评估的rejected=0, checked=1
4. 未通过评估的rejected=1, checked=1
"""
import asyncio
import json
import random
from typing import List
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask
logger = get_logger("expression_auto_check_task")
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM evaluation prompt for one expression.

    Args:
        situation: the usage situation / condition
        style: the expression / speaking style

    Returns:
        The prompt string sent to the judging model.
    """
    # Built-in evaluation criteria, optionally extended with custom ones from config.
    all_criteria = [
        "表达方式或言语风格 是否与使用条件或使用情景 匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]
    extra_criteria = global_config.expression.expression_auto_check_custom_criteria
    if extra_criteria:
        all_criteria.extend(extra_criteria)
    # Numbered criteria list, one per line.
    criteria_list = "\n".join(f"{idx}. {criterion}" for idx, criterion in enumerate(all_criteria, start=1))
    prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景{situation}
表达方式或言语风格{style}
请从以下方面进行评估
{criteria_list}
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
    return prompt
# Module-level judge LLM shared by all expression checks; uses the tool_use task model.
judge_llm = LLMRequest(
    model_set=model_config.model_task_config.tool_use,
    request_type="expression_check"
)
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
    """Run a single LLM evaluation of one expression.

    Args:
        situation: the usage situation / condition
        style: the expression / speaking style

    Returns:
        (suitable, reason, error): on success `error` is None; on failure
        `suitable` is False and `error` holds the error message.
        (Annotation fixed: the third element was declared `str` but is None
        on the success path.)
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        # Only the response text is used; reasoning/model metadata is discarded.
        response, _ = await judge_llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict; fall back to regex extraction when the model
        # wraps the JSON object in extra prose.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        # Any failure counts as "not suitable" with the error surfaced to the caller.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
class ExpressionAutoCheckTask(AsyncTask):
    """Periodic task that auto-evaluates unchecked expressions with an LLM.

    Each run samples up to N unchecked expressions, judges them via
    single_expression_check, and marks them checked (rejected when they fail).
    """

    def __init__(self):
        # Run interval and batch size come from the global expression config.
        check_interval = global_config.expression.expression_auto_check_interval
        super().__init__(
            task_name="Expression Auto Check Task",
            wait_before_start=60,  # wait 60s after startup before the first check
            run_interval=check_interval
        )

    async def _select_expressions(self, count: int) -> List[Expression]:
        """Randomly pick up to *count* expressions that have not been checked yet.

        Args:
            count: number of expressions to select

        Returns:
            The selected expressions (empty list when none exist or on error).
        """
        try:
            # All expressions with checked=False.
            unevaluated_expressions = list(
                Expression.select().where(~Expression.checked)
            )
            if not unevaluated_expressions:
                logger.info("没有未检查的表达方式")
                return []
            # Sample without replacement, capped at the available amount.
            selected_count = min(count, len(unevaluated_expressions))
            selected = random.sample(unevaluated_expressions, selected_count)
            logger.info(f"{len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count}")
            return selected
        except Exception as e:
            logger.error(f"选择表达方式时出错: {e}")
            return []

    async def _evaluate_expression(self, expression: Expression) -> bool:
        """Evaluate one expression and persist the verdict.

        Args:
            expression: the expression to evaluate

        Returns:
            True when the expression passed, False otherwise (including save errors).
        """
        suitable, reason, error = await single_expression_check(
            expression.situation,
            expression.style,
        )
        # Persist: checked=1 always; rejected mirrors the verdict.
        try:
            expression.checked = True
            expression.rejected = not suitable  # pass -> rejected=0, fail -> rejected=1
            expression.modified_by = 'ai'  # mark as AI-reviewed
            expression.save()
            status = "通过" if suitable else "不通过"
            logger.info(
                f"表达方式评估完成 [ID: {expression.id}] - {status} | "
                f"Situation: {expression.situation}... | "
                f"Style: {expression.style}... | "
                f"Reason: {reason[:50]}..."
            )
            if error:
                logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
            return suitable
        except Exception as e:
            logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}")
            return False

    async def run(self):
        """One scheduled pass: select a batch and evaluate it sequentially."""
        try:
            # Feature gate: do nothing when self-reflection is disabled.
            if not global_config.expression.expression_self_reflect:
                logger.debug("表达方式自动检查未启用,跳过本次执行")
                return
            check_count = global_config.expression.expression_auto_check_count
            if check_count <= 0:
                logger.warning(f"检查数量配置无效: {check_count},跳过本次执行")
                return
            logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count}")
            # Pick the batch to evaluate.
            expressions = await self._select_expressions(check_count)
            if not expressions:
                logger.info("没有需要检查的表达方式")
                return
            # Evaluate one by one, tallying the verdicts.
            passed_count = 0
            failed_count = 0
            for i, expression in enumerate(expressions, 1):
                logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
                if await self._evaluate_expression(expression):
                    passed_count += 1
                else:
                    failed_count += 1
                # Throttle successive LLM requests.
                await asyncio.sleep(0.3)
            logger.info(
                f"表达方式自动检查完成: 总计 {len(expressions)} 条,"
                f"通过 {passed_count} 条,不通过 {failed_count}"
            )
        except Exception as e:
            logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)

View File

@ -18,10 +18,13 @@ from src.bw_learner.learner_utils import (
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
calculate_style_similarity,
calculate_similarity,
parse_expression_response,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
from src.bw_learner.expression_auto_check_task import (
single_expression_check,
)
# MAX_EXPRESSION_COUNT = 300
@ -89,8 +92,9 @@ class ExpressionLearner:
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
model_set=model_config.model_task_config.tool_use, request_type="expression.summary"
)
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
@ -136,11 +140,21 @@ class ExpressionLearner:
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
expressions: List[Tuple[str, str, str]]
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
expressions, jargon_entries = parse_expression_response(response)
# 从缓存中检查 jargon 是否出现在 messages 中
cached_jargon_entries = self._check_cached_jargons_in_messages(random_msg)
if cached_jargon_entries:
# 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
existing_contents = {content for content, _ in jargon_entries}
for content, source_id in cached_jargon_entries:
if content not in existing_contents:
jargon_entries.append((content, source_id))
existing_contents.add(content)
logger.info(f"从缓存中检查到黑话: {content}")
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
if len(expressions) > 20:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = []
@ -155,7 +169,7 @@ class ExpressionLearner:
# 如果没有表达方式,直接返回
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
logger.info("解析后没有可用的表达方式")
return []
logger.info(f"学习的prompt: {prompt}")
@ -163,9 +177,60 @@ class ExpressionLearner:
logger.info(f"学习的jargon_entries: {jargon_entries}")
logger.info(f"学习的response: {response}")
# 直接根据 source_id 在 random_msg 中溯源,获取 context
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
learnt_expressions = self._filter_expressions(expressions, random_msg)
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (situation,style) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (situation,style) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
current_time=current_time,
)
return learnt_expressions
def _filter_expressions(
self,
expressions: List[Tuple[str, str, str]],
messages: List[Any],
) -> List[Tuple[str, str, str]]:
"""
过滤表达方式移除不符合条件的条目
Args:
expressions: 表达方式列表每个元素是 (situation, style, source_id)
messages: 原始消息列表用于溯源和验证
Returns:
过滤后的表达方式列表每个元素是 (situation, style, context)
"""
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
# 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
for situation, style, source_id in expressions:
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
@ -173,12 +238,12 @@ class ExpressionLearner:
continue
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
if line_index < 0 or line_index >= len(random_msg):
if line_index < 0 or line_index >= len(messages):
# 超出范围,跳过
continue
# 当前行的原始内容
current_msg = random_msg[line_index]
current_msg = messages[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
@ -195,251 +260,53 @@ class ExpressionLearner:
)
continue
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
return learnt_expressions
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
"""
解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表
期望的 JSON 结构
[
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
{"content": "词条", "source_id": "12"}, // 黑话
...
]
Returns:
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
第一个列表是表达方式 (situation, style, source_id)
第二个列表是黑话 (content, source_id)
"""
if not response:
return [], []
raw = response.strip()
# 尝试提取 ```json 代码块
json_block_pattern = r"```json\s*(.*?)\s*```"
match = re.search(json_block_pattern, raw, re.DOTALL)
if match:
raw = match.group(1).strip()
else:
# 去掉可能存在的通用 ``` 包裹
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
raw = raw.strip()
parsed = None
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
try:
# 优先尝试直接解析
if raw.startswith("[") and raw.endswith("]"):
parsed = json.loads(raw)
else:
repaired = repair_json(raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as parse_error:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
    """Escape Chinese curly double quotes inside JSON string values.

    Walks the text with a tiny state machine: an unescaped ASCII double
    quote toggles the "inside a string value" state, and while inside a
    string the Chinese quotes U+201C / U+201D are rewritten to an escaped
    ASCII quote (\\") so the result can be parsed by a strict JSON parser.
    Everything outside string values is copied through unchanged.

    Fix: the two replacement comparisons had degenerated to the ASCII
    quote character '"', which is consumed by the earlier toggle branch,
    so the replacement branches were unreachable dead code. They are
    restored to the curly-quote code points named in the original
    comments (U+201C and U+201D).

    Args:
        text: raw (possibly broken) JSON text.

    Returns:
        str: text with in-string curly quotes escaped.
    """
    result = []
    i = 0
    in_string = False
    escape_next = False
    while i < len(text):
        char = text[i]
        if escape_next:
            # Character right after a backslash is copied verbatim.
            result.append(char)
            escape_next = False
            i += 1
            continue
        if char == "\\":
            # Backslash: copy it and mark the next character as escaped.
            result.append(char)
            escape_next = True
            i += 1
            continue
        if char == '"':
            # Unescaped ASCII quote toggles the in-string state.
            in_string = not in_string
            result.append(char)
            i += 1
            continue
        if in_string:
            # Inside a string value: turn curly quotes into escaped quotes.
            if char == "\u201c":  # Chinese left double quote U+201C
                result.append('\\"')
            elif char == "\u201d":  # Chinese right double quote U+201D
                result.append('\\"')
            else:
                result.append(char)
        else:
            # Outside string values: copy through unchanged.
            result.append(char)
        i += 1
    return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw)
else:
repaired = repair_json(fixed_raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as fix_error:
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
return [], []
if isinstance(parsed, dict):
parsed_list = [parsed]
elif isinstance(parsed, list):
parsed_list = parsed
else:
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
return [], []
for item in parsed_list:
if not isinstance(item, dict):
# 过滤掉 style 与机器人名称/昵称重复的表达
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() in banned_casefold:
logger.debug(
f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
# 过滤掉包含 "表情:" 或 "表情:" 的内容
if "表情:" in (situation or "") or "表情:" in (situation or "") or \
"表情:" in (style or "") or "表情:" in (style or "") or \
"表情:" in context or "表情:" in context:
logger.info(
f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
elif item.get("content"):
# 黑话条目(有 content 字段)
content = str(item.get("content", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if content and source_id:
jargon_entries.append((content, source_id))
# 过滤掉包含 "[图片" 的内容
if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context:
logger.info(
f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
return expressions, jargon_entries
filtered_expressions.append((situation, style))
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str, str]] = []
removed_count = 0
for situation, style, source_id in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style, source_id))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
return filtered_expressions
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
# 第一层:检查是否有完全一致的 style检查 style 字段和 style_list
expr_obj = await self._find_exact_style_match(style)
# 检查是否有相似的 situation相似度 >= 0.75,检查 content_list
# 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理
expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75)
if expr_obj:
# 找到完全匹配的 style合并到现有记录不使用 LLM 总结)
# 根据相似度决定是否使用 LLM 总结
# 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
use_llm_summary = similarity < 1.0
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=False,
)
return
# 第二层:检查是否有相似的 style相似度 >= 0.75,检查 style 字段和 style_list
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
if similar_expr_obj:
# 找到相似的 style合并到现有记录使用 LLM 总结)
await self._update_existing_expression(
expr_obj=similar_expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=True,
use_llm_summary=use_llm_summary,
)
return
@ -447,7 +314,6 @@ class ExpressionLearner:
await self._create_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
@ -455,7 +321,6 @@ class ExpressionLearner:
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
content_list = [situation]
@ -466,26 +331,22 @@ class ExpressionLearner:
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
style_list=None, # 新记录初始时 style_list 为空
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
style: str,
context: str,
current_time: float,
use_llm_summary: bool = True,
) -> None:
"""
更新现有 Expression 记录style 完全匹配或相似的情况
将新的 situation 添加到 content_list将新的 style 添加到 style_list如果不同
更新现有 Expression 记录situation 完全匹配或相似的情况
将新的 situation 添加到 content_list不合并 style
Args:
use_llm_summary: 是否使用 LLM 进行总结完全匹配时为 False相似匹配时为 True
@ -495,43 +356,24 @@ class ExpressionLearner:
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
# 更新 style_list如果 style 不同,添加到 style_list
style_list = self._parse_style_list(expr_obj.style_list)
# 将原有的 style 也加入 style_list如果还没有的话
if expr_obj.style and expr_obj.style not in style_list:
style_list.append(expr_obj.style)
# 如果新的 style 不在 style_list 中,添加它
if style not in style_list:
style_list.append(style)
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
# 更新其他字段
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.checked = False # count 增加时重置 checked 为 False
expr_obj.last_active_time = current_time
expr_obj.context = context
if use_llm_summary:
# 相似匹配时,使用 LLM 重新组合 situation 和 style
# 相似匹配时,使用 LLM 重新组合 situation
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
new_style = await self._compose_style_text(
style_list=style_list,
count=expr_obj.count,
fallback=expr_obj.style or style,
)
expr_obj.style = new_style
else:
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
# 只更新 content_list 和 style_list
pass
expr_obj.save()
# count 增加后,立即进行一次检查
await self._check_expression_immediately(expr_obj)
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
@ -541,49 +383,19 @@ class ExpressionLearner:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
async def _find_similar_situation_expression(self, situation: str, similarity_threshold: float = 0.75) -> Tuple[Optional[Expression], float]:
"""
查找具有完全匹配 style Expression 记录
只检查 style_list 中的每一项不检查 style 字段因为 style 可能是总结后的概括性描述
查找具有相似 situation Expression 记录
检查 content_list 中的每一项
Args:
style: 要查找的 style
Returns:
找到的 Expression 对象如果没有找到则返回 None
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
for expr in all_expressions:
# 只检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
if style in style_list:
return expr
return None
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
"""
查找具有相似 style Expression 记录
只检查 style_list 中的每一项不检查 style 字段因为 style 可能是总结后的概括性描述
Args:
style: 要查找的 style
situation: 要查找的 situation
similarity_threshold: 相似度阈值默认 0.75
Returns:
找到的最相似的 Expression 对象如果没有找到则返回 None
Tuple[Optional[Expression], float]:
- 找到的最相似的 Expression 对象如果没有找到则返回 None
- 相似度值如果找到匹配范围在 similarity_threshold 1.0 之间
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
@ -592,96 +404,28 @@ class ExpressionLearner:
best_similarity = 0.0
for expr in all_expressions:
# 只检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(style, existing_style)
# 检查 content_list 中的每一项
content_list = self._parse_content_list(expr.content_list)
for existing_situation in content_list:
similarity = calculate_similarity(situation, existing_situation)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
logger.debug(f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'")
return best_match
return best_match, best_similarity
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
"""
组合 style 文本如果 style_list 有多个元素则尝试总结
"""
sanitized = [s.strip() for s in style_list if s.strip()]
if len(sanitized) > 1:
# 只有当有多个 style 时才尝试总结
summary = await self._summarize_styles(sanitized)
if summary:
return summary
# 如果只有一个或总结失败,返回第一个或 fallback
return sanitized[0] if sanitized else fallback
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
"""总结多个 style生成一个概括性的 style 描述"""
if not styles or len(styles) <= 1:
return None
# 计算输入列表中最长项目的长度
max_input_length = max(len(s) for s in styles) if styles else 0
max_summary_length = max_input_length * 2
# 最多重试3次
max_retries = 3
retry_count = 0
while retry_count < max_retries:
# 如果是重试,在 prompt 中强调要更简洁
length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字"
prompt = (
"请阅读以下多个语言风格/表达方式,对其进行总结。"
"不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
f"{length_hint}保留共同特点:\n"
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
# 检查总结长度是否超过限制
if len(summary) <= max_summary_length:
return summary
else:
retry_count += 1
logger.debug(
f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
f"(输入最长项长度: {max_input_length}),重试第 {retry_count}"
)
continue
except Exception as e:
logger.error(f"概括表达风格失败: {e}")
return None
# 如果重试多次后仍然超过长度,返回 None不进行总结
logger.warning(
f"总结多次后仍超过长度限制,放弃总结。"
f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
)
return None
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
if not sanitized:
return fallback
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
@ -691,7 +435,126 @@ class ExpressionLearner:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
return "/".join(sanitized) if sanitized else fallback
async def _init_check_model(self) -> None:
    """Lazily create the LLM instance used for expression checking.

    No-op when the instance already exists; construction failures are
    logged and swallowed, leaving self.check_model as None.
    """
    if self.check_model is not None:
        return
    try:
        self.check_model = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="expression.check"
        )
        logger.debug("检查用 LLM 实例初始化成功")
    except Exception as e:
        logger.error(f"创建检查用 LLM 实例失败: {e}")
async def _check_expression_immediately(self, expr_obj: Expression) -> None:
    """
    Immediately check one expression (called right after its count increases).

    Runs an LLM-based suitability evaluation and persists the verdict on the
    record: ``checked`` is set True, ``rejected`` is the inverse of the
    verdict. Any failure leaves ``checked`` as-is so the periodic auto-check
    task can pick the record up later.

    Args:
        expr_obj: the expression record to evaluate (peewee model instance).
    """
    try:
        # Feature gate: skip entirely when self-reflection checking is disabled.
        if not global_config.expression.expression_self_reflect:
            logger.debug("表达方式自动检查未启用,跳过立即检查")
            return
        # Lazily initialize the checking LLM; may silently fail.
        await self._init_check_model()
        if self.check_model is None:
            logger.warning("检查用 LLM 实例初始化失败,跳过立即检查")
            return
        # LLM evaluation of the (situation, style) pair.
        # NOTE(review): single_expression_check presumably returns
        # (suitable: bool, reason: str, error: Optional[str]) — confirm at its definition.
        suitable, reason, error = await single_expression_check(
            expr_obj.situation,
            expr_obj.style
        )
        # Persist the verdict on the record.
        expr_obj.checked = True
        expr_obj.rejected = not suitable  # passed -> rejected=False; failed -> rejected=True
        expr_obj.save()
        status = "通过" if suitable else "不通过"
        logger.info(
            f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | "
            f"Situation: {expr_obj.situation[:30]}... | "
            f"Style: {expr_obj.style[:30]}... | "
            f"Reason: {reason[:50] if reason else ''}..."
        )
        # An error alongside a verdict is non-fatal: the verdict was still saved.
        if error:
            logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}")
    except Exception as e:
        logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True)
        # On failure, keep checked=False and let the periodic auto-check task handle it.
def _check_cached_jargons_in_messages(self, messages: List[Any]) -> List[Tuple[str, str]]:
    """
    Check whether any cached jargon terms appear in the given messages.

    Improvement over the original: the regex for each cached term (the
    CJK-detection test, re.escape, and boundary wrapping) is invariant
    across messages, so it is now built and compiled once up front
    instead of being reconstructed for every message × term pair.

    Args:
        messages: message list, in the same order used by
            build_anonymous_messages (whose numbering starts at 1).

    Returns:
        List[Tuple[str, str]]: matched entries as (content, source_id),
        where source_id is the 1-based message index. A term matched in
        several messages yields one entry per message, preserving
        message order.
    """
    if not messages:
        return []
    # Fetch this chat's miner and its cached candidate terms.
    jargon_miner = miner_manager.get_miner(self.chat_id)
    cached_jargons = jargon_miner.get_cached_jargons()
    if not cached_jargons:
        return []
    # Pre-compile one pattern per cached term (loop-invariant work hoisted
    # out of the per-message loop).
    compiled_patterns: List[Tuple[str, Any]] = []
    for jargon in cached_jargons:
        if not jargon or not jargon.strip():
            continue
        jargon_content = jargon.strip()
        escaped = re.escape(jargon_content)
        if re.search(r"[\u4e00-\u9fff]", jargon_content):
            # Contains CJK characters: \b boundaries are unreliable, match loosely.
            search_pattern = escaped
        else:
            # Pure ASCII/numeric term: require word boundaries.
            search_pattern = r"\b" + escaped + r"\b"
        compiled_patterns.append((jargon_content, re.compile(search_pattern, re.IGNORECASE)))
    if not compiled_patterns:
        return []
    matched_entries: List[Tuple[str, str]] = []
    for i, msg in enumerate(messages):
        # Skip the bot's own messages.
        if is_bot_message(msg):
            continue
        # Message text; falls back to empty when the attribute is missing.
        msg_text = (
            getattr(msg, "processed_plain_text", None) or
            ""
        ).strip()
        if not msg_text:
            continue
        for jargon_content, pattern in compiled_patterns:
            if pattern.search(msg_text):
                # source_id is 1-based to match build_anonymous_messages numbering.
                matched_entries.append((jargon_content, str(i + 1)))
    return matched_entries
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
"""

View File

@ -28,11 +28,11 @@ class ExpressionReflector:
try:
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
if not global_config.expression.reflect:
if not global_config.expression.expression_self_reflect:
logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
return False
operator_config = global_config.expression.reflect_operator_id
operator_config = global_config.expression.manual_reflect_operator_id
if not operator_config:
logger.debug("[Expression Reflection] Operator ID 未配置,跳过")
return False

View File

@ -45,7 +45,7 @@ def init_prompt():
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
@ -123,9 +123,11 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
style_exprs = [
{
@ -202,7 +204,11 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
style_exprs = [
{
@ -295,7 +301,11 @@ class ExpressionSelector:
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
all_style_exprs = [
{
@ -407,8 +417,8 @@ class ExpressionSelector:
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
print(prompt)
print(content)
# print(prompt)
# print(content)
if not content:
logger.warning("LLM返回空结果")

View File

@ -45,7 +45,7 @@ class JargonExplainer:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
model_set=model_config.model_task_config.tool_use,
request_type="jargon.explain",
)
@ -341,7 +341,7 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
meaning = result.get("meaning", "").strip()
if found_content and meaning:
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
results.append("".join(output_parts))
results.append("\n".join(output_parts)) # 换行分隔每个jargon解释
logger.info(f"在jargon库中找到匹配模糊搜索: {concept},找到{len(jargon_results)}条结果")
else:
# 精确匹配
@ -350,7 +350,8 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
meaning = result.get("meaning", "").strip()
if meaning:
output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
results.append("".join(output_parts) if len(output_parts) > 1 else output_parts[0])
# 换行分隔每个jargon解释
results.append("\n".join(output_parts) if len(output_parts) > 1 else output_parts[0])
exact_matches.append(concept) # 收集精确匹配的概念,稍后统一打印
else:
# 未找到,不返回占位信息,只记录日志
@ -361,5 +362,5 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")
if results:
return "【概念检索结果】\n" + "\n".join(results) + "\n"
return "你了解以下词语可能的含义:\n" + "\n".join(results) + "\n"
return ""

View File

@ -2,7 +2,7 @@ import json
import asyncio
import random
from collections import OrderedDict
from typing import List, Dict, Optional, Any, Callable
from typing import List, Dict, Optional, Callable
from json_repair import repair_json
from peewee import fn
@ -11,14 +11,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
update_chat_id_list,
@ -51,32 +45,32 @@ def _is_single_char_jargon(content: str) -> bool:
)
def _init_prompt() -> None:
prompt_str = """
**聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID**
{chat_str}
# def _init_prompt() -> None:
# prompt_str = """
# **聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
# {chat_str}
请从上面这段聊天内容中提取"可能是黑话"的候选项黑话/俚语/网络缩写/口头禅
- 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语没有明确含义的词语请不要选择有明确含义或者含义清晰的词语
- 排除人名@表情包/图片中的内容纯标点常规功能词如的啊等
- 每个词条长度建议 2-8 个字符不强制尽量短小
# 请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)
# - 必须为对话中真实出现过的短词或短语
# - 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
# - 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等
# - 每个词条长度建议 2-8 个字符(不强制),尽量短小
黑话必须为以下几种类型
- 由字母构成的汉语拼音首字母的简写词例如nbyydsxswl
- 英文词语的缩写用英文字母概括一个词汇或含义例如CPUGPUAPI
- 中文词语的缩写用几个汉字概括一个词汇或含义例如社死内卷
# 黑话必须为以下几种类型
# - 由字母构成的汉语拼音首字母的简写词例如nb、yyds、xswl
# - 英文词语的缩写用英文字母概括一个词汇或含义例如CPU、GPU、API
# - 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
JSON 数组输出元素为对象严格按以下结构
请你提取出可能的黑话最多30个黑话请尽量提取所有
[
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
{{"content": "词条2", "msg_id": "m15"}}
]
# 以 JSON 数组输出,元素为对象(严格按以下结构)
# 请你提取出可能的黑话最多30个黑话请尽量提取所有
# [
# {{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
# {{"content": "词条2", "msg_id": "m15"}}
# ]
现在请输出
"""
Prompt(prompt_str, "extract_jargon_prompt")
# 现在请输出
# """
# Prompt(prompt_str, "extract_jargon_prompt")
def _init_inference_prompts() -> None:
@ -142,7 +136,6 @@ def _init_inference_prompts() -> None:
Prompt(prompt3_str, "jargon_compare_inference_prompt")
_init_prompt()
_init_inference_prompts()
@ -229,34 +222,9 @@ class JargonMiner:
if len(self.cache) > self.cache_limit:
self.cache.popitem(last=False)
def _collect_cached_entries(self, messages: List[Any]) -> List[Dict[str, List[str]]]:
    """Scan the current message window for cached jargon terms and build
    a context paragraph for each (term, message) hit.

    Returns entries shaped like {"content": term, "raw_content": [paragraph]};
    a given (term, message-index) pair is emitted at most once, and hits with
    an empty context paragraph are dropped.
    """
    if not self.cache or not messages:
        return []
    hits: List[Dict[str, List[str]]] = []
    seen_pairs = set()
    for index, message in enumerate(messages):
        text = (
            getattr(message, "display_message", None)
            or getattr(message, "processed_plain_text", None)
            or ""
        ).strip()
        if not text or is_bot_message(message):
            continue
        for term in self.cache:
            if not term or (term, index) in seen_pairs:
                continue
            if term not in text:
                continue
            context = build_context_paragraph(messages, index)
            if not context:
                continue
            hits.append({"content": term, "raw_content": [context]})
            seen_pairs.add((term, index))
    return hits
def get_cached_jargons(self) -> List[str]:
    """Return all jargon terms currently held in the cache, as a list."""
    return [term for term in self.cache]
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
"""通过ID加载对象并推断"""
@ -480,263 +448,6 @@ class JargonMiner:
traceback.print_exc()
async def run_once(
    self,
    messages: List[Any],
    person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
    """
    Run one round of jargon extraction over the given message window.

    Renders the messages into a readable transcript, asks the LLM to
    extract candidate jargon terms, validates each candidate against the
    original messages, merges them with cache re-hits, and upserts the
    results into the Jargon table (optionally triggering meaning
    inference when a record's count crosses its threshold).

    Args:
        messages: externally supplied message list (required).
        person_name_filter: optional predicate; candidates for which it
            returns True (i.e. containing a person's name) are dropped.
    """
    # Serialize extraction: the async lock prevents concurrent runs.
    async with self._extraction_lock:
        try:
            if not messages:
                return
            # Sort by timestamp so line numbering stays consistent with context.
            messages = sorted(messages, key=lambda msg: msg.time or 0)
            chat_str, message_id_list = build_readable_messages_with_id(
                messages=messages,
                replace_bot_name=True,
                timestamp_mode="relative",
                truncate=False,
                show_actions=False,
                show_pic=True,
                pic_single=True,
            )
            if not chat_str.strip():
                return
            # Map rendered message ids back to indices into `messages`.
            msg_id_to_index: Dict[str, int] = {}
            for idx, (msg_id, _msg) in enumerate(message_id_list or []):
                if not msg_id:
                    continue
                msg_id_to_index[msg_id] = idx
            if not msg_id_to_index:
                logger.warning("未能生成消息ID映射跳过本次提取")
                return
            prompt: str = await global_prompt_manager.format_prompt(
                "extract_jargon_prompt",
                bot_name=global_config.bot.nickname,
                chat_str=chat_str,
            )
            response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
            if not response:
                return
            if global_config.debug.show_jargon_prompt:
                logger.info(f"jargon提取提示词: {prompt}")
                logger.info(f"jargon提取结果: {response}")
            # Parse the LLM output as JSON (repairing mildly broken JSON).
            entries: List[dict] = []
            try:
                resp = response.strip()
                parsed = None
                if resp.startswith("[") and resp.endswith("]"):
                    parsed = json.loads(resp)
                else:
                    repaired = repair_json(resp)
                    if isinstance(repaired, str):
                        parsed = json.loads(repaired)
                    else:
                        parsed = repaired
                if isinstance(parsed, dict):
                    parsed = [parsed]
                if not isinstance(parsed, list):
                    return
                for item in parsed:
                    if not isinstance(item, dict):
                        continue
                    content = str(item.get("content", "")).strip()
                    msg_id_value = item.get("msg_id")
                    if not content:
                        continue
                    # Drop candidates containing the bot's own nickname/aliases.
                    if contains_bot_self_name(content):
                        logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
                        continue
                    # Drop candidates containing a known person's name.
                    if person_name_filter and person_name_filter(content):
                        logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
                        continue
                    msg_id_str = str(msg_id_value or "").strip()
                    if not msg_id_str:
                        logger.warning(f"解析jargon失败msg_id缺失content={content}")
                        continue
                    msg_index = msg_id_to_index.get(msg_id_str)
                    if msg_index is None:
                        logger.warning(f"解析jargon失败msg_id未找到content={content}, msg_id={msg_id_str}")
                        continue
                    target_msg = messages[msg_index]
                    # Drop candidates sourced from the bot's own messages.
                    if is_bot_message(target_msg):
                        logger.info(f"解析阶段跳过引用机器人自身消息的词条: content={content}, msg_id={msg_id_str}")
                        continue
                    context_paragraph = build_context_paragraph(messages, msg_index)
                    if not context_paragraph:
                        logger.warning(f"解析jargon失败上下文为空content={content}, msg_id={msg_id_str}")
                        continue
                    entries.append({"content": content, "raw_content": [context_paragraph]})
                # Also pick up cached jargon terms re-appearing in this window.
                cached_entries = self._collect_cached_entries(messages)
                if cached_entries:
                    entries.extend(cached_entries)
            except Exception as e:
                logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
                return
            if not entries:
                return
            # Deduplicate and merge raw_content, aggregating by content.
            merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
            for entry in entries:
                content_key = entry["content"]
                raw_list = entry.get("raw_content", []) or []
                if content_key in merged_entries:
                    merged_entries[content_key]["raw_content"].extend(raw_list)
                else:
                    merged_entries[content_key] = {
                        "content": content_key,
                        "raw_content": list(raw_list),
                    }
            uniq_entries = []
            for merged_entry in merged_entries.values():
                raw_content_list = merged_entry["raw_content"]
                if raw_content_list:
                    # dict.fromkeys dedupes while keeping first-occurrence order.
                    merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
                uniq_entries.append(merged_entry)
            saved = 0
            updated = 0
            for entry in uniq_entries:
                content = entry["content"]
                raw_content_list = entry["raw_content"]  # already a list
                try:
                    # Query every record with matching content.
                    query = Jargon.select().where(Jargon.content == content)
                    # Pick the record this chat is allowed to update.
                    matched_obj = None
                    for obj in query:
                        if global_config.expression.all_global_jargon:
                            # all_global on: any record with matching content qualifies.
                            matched_obj = obj
                            break
                        else:
                            # all_global off: the record's chat_id list must contain this chat.
                            chat_id_list = parse_chat_id_list(obj.chat_id)
                            if chat_id_list_contains(chat_id_list, self.chat_id):
                                matched_obj = obj
                                break
                    if matched_obj:
                        obj = matched_obj
                        try:
                            obj.count = (obj.count or 0) + 1
                        except Exception:
                            obj.count = 1
                        # Merge raw_content: read existing list, append new values, dedupe.
                        existing_raw_content = []
                        if obj.raw_content:
                            try:
                                existing_raw_content = (
                                    json.loads(obj.raw_content)
                                    if isinstance(obj.raw_content, str)
                                    else obj.raw_content
                                )
                                if not isinstance(existing_raw_content, list):
                                    existing_raw_content = [existing_raw_content] if existing_raw_content else []
                            except (json.JSONDecodeError, TypeError):
                                existing_raw_content = [obj.raw_content] if obj.raw_content else []
                        # Merge and deduplicate, preserving order.
                        merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
                        obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
                        # Bump this chat's counter in the record's chat_id list.
                        chat_id_list = parse_chat_id_list(obj.chat_id)
                        updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
                        obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
                        # With all_global on, force the record to is_global=True;
                        # with it off, leave is_global untouched.
                        if global_config.expression.all_global_jargon:
                            obj.is_global = True
                        obj.save()
                        # Trigger meaning inference when count crosses the threshold
                        # and exceeds the last evaluated value.
                        if _should_infer_meaning(obj):
                            # Fire-and-forget; the task reloads by id for fresh data.
                            jargon_id = obj.id
                            asyncio.create_task(self._infer_meaning_by_id(jargon_id))
                        updated += 1
                    else:
                        # No matching record: create a new one.
                        if global_config.expression.all_global_jargon:
                            # all_global on: new records default to is_global=True.
                            is_global_new = True
                        else:
                            # all_global off: new records are chat-local.
                            is_global_new = False
                        # chat_id uses the list format: [[chat_id, count]].
                        chat_id_list = [[self.chat_id, 1]]
                        chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
                        Jargon.create(
                            content=content,
                            raw_content=json.dumps(raw_content_list, ensure_ascii=False),
                            chat_id=chat_id_json,
                            is_global=is_global_new,
                            count=1,
                        )
                        saved += 1
                except Exception as e:
                    logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
                    continue
                finally:
                    # Cache every processed term, even on per-entry failure.
                    self._add_to_cache(content)
            # Always log the extracted candidates in readable form when any exist.
            if uniq_entries:
                jargon_list = [entry["content"] for entry in uniq_entries]
                jargon_str = ",".join(jargon_list)
                # logger.info picks up the jargon module's log coloring automatically.
                logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
            if saved or updated:
                logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated}chat_id={self.chat_id}")
        except Exception as e:
            logger.error(f"JargonMiner 运行失败: {e}")
            # Keep timestamps updated even on failure to avoid rapid retries.
async def process_extracted_entries(
self,
entries: List[Dict[str, List[str]]],

View File

@ -2,8 +2,7 @@ import re
import difflib
import random
import json
from datetime import datetime
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
@ -11,6 +10,7 @@ from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.chat.utils.utils import parse_platform_accounts
from json_repair import repair_json
logger = get_logger("learner_utils")
@ -88,33 +88,15 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间
count越高权重越高但最多为基础权重的5倍
如果表达已checked权重会再乘以3倍
"""
if not population:
return []
counts = []
checked_flags = []
for item in population:
count = item.get("count", 1)
try:
@ -122,29 +104,19 @@ def _compute_weights(population: List[Dict]) -> List[float]:
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
# 获取checked状态
checked = item.get("checked", False)
checked_flags.append(bool(checked))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
base_weights = [1.0 for _ in counts]
weights = [1.0 for _ in counts]
else:
base_weights = []
weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
base_weights.append(1.0 + normalized * 4.0) # 1~5
weights.append(1.0 + normalized * 4.0) # 1~5
# 如果checked权重乘以3
weights = []
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
if checked:
weights.append(base_weight * 3.0)
else:
weights.append(base_weight)
return weights
@ -378,3 +350,149 @@ def is_bot_message(msg: Any) -> bool:
bot_account = bot_accounts.get(platform)
return bool(bot_account and user_id == bot_account)
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """
    解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表

    期望的 JSON 结构
    [
        {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  // 表达方式
        {"content": "词条", "source_id": "12"},  // 黑话
        ...
    ]

    Args:
        response: LLM 原始输出,可能包裹在 ```json 代码块中,
            字符串值内部可能混入中文引号(U+201C / U+201D

    Returns:
        Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
            第一个列表是表达方式 (situation, style, source_id)
            第二个列表是黑话 (content, source_id)
            输入为空或无法解析时两个列表均为空
    """
    if not response:
        return [], []

    raw = response.strip()
    # 尝试提取 ```json 代码块;否则剥离通用 ``` 包裹
    match = re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL)
    if match:
        raw = match.group(1).strip()
    else:
        raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
        raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
        raw = raw.strip()

    def _attempt_parse(text: str) -> Any:
        """解析 JSON形似完整数组时直接解析否则先用 json_repair 修复。"""
        if text.startswith("[") and text.endswith("]"):
            return json.loads(text)
        repaired = repair_json(text)
        return json.loads(repaired) if isinstance(repaired, str) else repaired

    def _escape_chinese_quotes(text: str) -> str:
        """状态机:仅在 JSON 字符串值内部把中文引号替换为转义的英文引号。"""
        out: List[str] = []
        in_string = False
        escape_next = False
        for ch in text:
            if escape_next:
                # 转义字符后的字符原样保留
                out.append(ch)
                escape_next = False
                continue
            if ch == "\\":
                out.append(ch)
                escape_next = True
                continue
            if ch == '"':
                # 未转义的英文引号切换字符串状态escape_next 此处必为 False
                in_string = not in_string
                out.append(ch)
                continue
            if in_string and ch in ("\u201c", "\u201d"):
                # 中文左/右引号 → 转义的英文引号
                out.append('\\"')
            else:
                out.append(ch)
        return "".join(out)

    # 优先直接解析;失败则修复中文引号后重试
    try:
        parsed = _attempt_parse(raw)
    except Exception as parse_error:
        try:
            parsed = _attempt_parse(_escape_chinese_quotes(raw))
        except Exception as fix_error:
            logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
            logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
            logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
            logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
            return [], []

    # 统一为 dict 列表
    if isinstance(parsed, dict):
        parsed_list = [parsed]
    elif isinstance(parsed, list):
        parsed_list = parsed
    else:
        logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
        return [], []

    expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
    jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)
    for item in parsed_list:
        if not isinstance(item, dict):
            continue
        situation = str(item.get("situation", "")).strip()
        style = str(item.get("style", "")).strip()
        source_id = str(item.get("source_id", "")).strip()
        if situation and style and source_id:
            # 表达方式条目(同时具有 situation 与 style
            expressions.append((situation, style, source_id))
        elif item.get("content"):
            # 黑话条目(有 content 字段source_id 复用上面已清洗的值
            content = str(item.get("content", "")).strip()
            if content and source_id:
                jargon_entries.append((content, source_id))

    return expressions, jargon_entries

View File

@ -116,20 +116,12 @@ class MessageRecorder:
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
# 触发 expression_learner 和 jargon_miner 的处理
if self.enable_expression_learning:
asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
self._trigger_expression_learning(messages)
)
# 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning:
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
@ -138,7 +130,7 @@ class MessageRecorder:
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
self, messages: List[Any]
) -> None:
"""
触发 expression 学习使用指定的消息列表
@ -162,27 +154,6 @@ class MessageRecorder:
traceback.print_exc()
async def _trigger_jargon_extraction(
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 jargon 提取使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息给 JargonMiner避免它重复获取
await self.jargon_miner.run_once(messages=messages)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
class MessageRecorderManager:
"""MessageRecorder 管理器"""

View File

@ -28,7 +28,7 @@ class ReflectTracker:
self.max_duration = 15 * 60 # 15 minutes
# LLM for judging response
self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
self.judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker")
self._init_prompts()
@ -134,12 +134,14 @@ class ReflectTracker:
if judgment == "Approve":
self.expression.checked = True
self.expression.rejected = False
self.expression.modified_by = 'ai' # 通过LLM判断也标记为ai
self.expression.save()
logger.info(f"Expression {self.expression.id} approved by operator.")
return True
elif judgment == "Reject":
self.expression.checked = True
self.expression.modified_by = 'ai' # 通过LLM判断也标记为ai
corrected_situation = json_obj.get("corrected_situation")
corrected_style = json_obj.get("corrected_style")

View File

@ -11,6 +11,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
@ -261,6 +262,7 @@ class BrainPlanner:
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作ReAct模式
"""
plan_start = time.perf_counter()
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@ -298,6 +300,7 @@ class BrainPlanner:
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
prompt_build_start = time.perf_counter()
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
prompt_key = "brain_planner_prompt_react"
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
@ -308,9 +311,10 @@ class BrainPlanner:
message_id_list=message_id_list,
prompt_key=prompt_key,
)
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
# 调用LLM获取决策
reasoning, actions = await self._execute_main_planner(
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
@ -324,6 +328,25 @@ class BrainPlanner:
)
self.add_plan_log(reasoning, actions)
try:
PlanReplyLogger.log_plan(
chat_id=self.chat_id,
prompt=prompt,
reasoning=reasoning,
raw_output=llm_raw_output,
raw_reasoning=llm_reasoning,
actions=actions,
timing={
"prompt_build_ms": round(prompt_build_ms, 2),
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
"loop_start_time": loop_start_time,
},
extra=None,
)
except Exception:
logger.exception(f"{self.log_prefix}记录plan日志失败")
return actions
async def build_planner_prompt(
@ -421,7 +444,7 @@ class BrainPlanner:
if action_info.activation_type == ActionActivationType.NEVER:
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
continue
elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
elif action_info.activation_type == ActionActivationType.ALWAYS:
filtered_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.RANDOM:
if random.random() < action_info.random_activation_probability:
@ -479,15 +502,20 @@ class BrainPlanner:
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> Tuple[str, List[ActionPlannerInfo]]:
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
extracted_reasoning = ""
llm_reasoning = None
llm_duration_ms = None
try:
# 调用LLM
llm_start = time.perf_counter()
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
llm_reasoning = reasoning_content
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
@ -514,7 +542,7 @@ class BrainPlanner:
action_message=None,
available_actions=available_actions,
)
]
], llm_content, llm_reasoning, llm_duration_ms
# 解析LLM响应
if llm_content:
@ -553,7 +581,7 @@ class BrainPlanner:
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
return extracted_reasoning, actions
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]

View File

@ -382,7 +382,7 @@ class EmojiManager:
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
self.llm_emotion_judge = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
)
self.emoji_num = 0
self.emoji_num_max = global_config.emoji.max_reg_num

View File

@ -30,7 +30,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
)
from src.chat.utils.utils import record_replyer_action_temp
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
@ -244,12 +244,14 @@ class HeartFChatting:
thinking_id,
actions,
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
@ -526,15 +528,26 @@ class HeartFChatting:
reply_set: "ReplySetModel",
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 3)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
# 根据 llm_quote 配置决定是否使用 quote_message 参数
if global_config.chat.llm_quote:
# 如果配置为 true使用 llm_quote 参数决定是否引用回复
if quote_message is None:
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
need_reply = False
else:
need_reply = quote_message
if need_reply:
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
else:
# 如果配置为 false使用原来的模式
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 3)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
reply_text = ""
first_replied = False
@ -640,6 +653,7 @@ class HeartFChatting:
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
unknown_words = None
quote_message = None
if isinstance(action_planner_info.action_data, dict):
uw = action_planner_info.action_data.get("unknown_words")
if isinstance(uw, list):
@ -651,6 +665,19 @@ class HeartFChatting:
cleaned_uw.append(s)
if cleaned_uw:
unknown_words = cleaned_uw
# 从 Planner 的 action_data 中提取 quote_message 参数
qm = action_planner_info.action_data.get("quote")
if qm is not None:
# 支持多种格式true/false, "true"/"false", 1/0
if isinstance(qm, bool):
quote_message = qm
elif isinstance(qm, str):
quote_message = qm.lower() in ("true", "1", "yes")
elif isinstance(qm, (int, float)):
quote_message = bool(qm)
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
@ -682,12 +709,13 @@ class HeartFChatting:
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"result": f"回复内容{reply_text}",
"result": f"使用reply动作' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
"loop_info": loop_info,
}

View File

@ -7,9 +7,8 @@ from .kg_manager import KGManager
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
from src.config.config import global_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@ -22,7 +21,6 @@ class QAManager:
):
self.embed_manager = embed_manager
self.kg_manager = kg_manager
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
async def process_query(
self, question: str

View File

@ -0,0 +1,139 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4
from src.config.config import global_config
class PlanReplyLogger:
    """独立的Plan/Reply日志记录器负责落盘和容量控制。

    每条 plan/reply 记录写入 logs/plan/<chat_id>/ 或 logs/reply/<chat_id>/
    下的独立 JSON 文件,超过每聊天流上限时批量删除最老的文件。
    """

    _BASE_DIR = Path("logs")
    _PLAN_DIR = _BASE_DIR / "plan"
    _REPLY_DIR = _BASE_DIR / "reply"
    # 每次裁剪删除的最老文件数量
    _TRIM_COUNT = 100

    @classmethod
    def _get_max_per_chat(cls) -> int:
        """从配置中获取每个聊天流最大保存的日志数量(缺省 1000"""
        return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)

    @classmethod
    def log_plan(
        cls,
        chat_id: str,
        prompt: str,
        reasoning: str,
        raw_output: Optional[str],
        raw_reasoning: Optional[str],
        actions: List[Any],
        timing: Optional[Dict[str, Any]] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """记录一次 plan 决策prompt、推理、LLM 原始输出、动作列表与耗时。"""
        payload = {
            "type": "plan",
            "chat_id": chat_id,
            "timestamp": time.time(),
            "prompt": prompt,
            "reasoning": reasoning,
            "raw_output": raw_output,
            "raw_reasoning": raw_reasoning,
            "actions": [cls._serialize_action(action) for action in actions],
            "timing": timing or {},
            "extra": cls._safe_data(extra),
        }
        cls._write_json(cls._PLAN_DIR, chat_id, payload)

    @classmethod
    def log_reply(
        cls,
        chat_id: str,
        prompt: str,
        output: Optional[str],
        processed_output: Optional[List[Any]],
        model: Optional[str],
        timing: Optional[Dict[str, Any]] = None,
        reasoning: Optional[str] = None,
        think_level: Optional[int] = None,
        error: Optional[str] = None,
        success: bool = True,
    ) -> None:
        """记录一次 reply 生成;失败时 success=False 并携带 error 信息。"""
        payload = {
            "type": "reply",
            "chat_id": chat_id,
            "timestamp": time.time(),
            "prompt": prompt,
            "output": output,
            "processed_output": cls._safe_data(processed_output),
            "model": model,
            "reasoning": reasoning,
            "think_level": think_level,
            "timing": timing or {},
            # 成功时不保留 error 字段内容
            "error": error if not success else None,
            "success": success,
        }
        cls._write_json(cls._REPLY_DIR, chat_id, payload)

    @classmethod
    def _write_json(cls, base_dir: Path, chat_id: str, payload: Dict[str, Any]) -> None:
        """将 payload 序列化为 base_dir/chat_id 下的唯一 JSON 文件并触发容量裁剪。"""
        chat_dir = base_dir / chat_id
        chat_dir.mkdir(parents=True, exist_ok=True)
        # 毫秒时间戳 + 短 uuid保证文件名可按时间排序且不冲突
        file_path = chat_dir / f"{int(time.time() * 1000)}_{uuid4().hex[:8]}.json"
        try:
            with file_path.open("w", encoding="utf-8") as f:
                json.dump(cls._safe_data(payload), f, ensure_ascii=False, indent=2)
        except Exception:
            # 序列化中途失败会留下截断的、不可读的 JSON 文件,删除后再抛出
            file_path.unlink(missing_ok=True)
            raise
        finally:
            cls._trim_overflow(chat_dir)

    @classmethod
    def _trim_overflow(cls, chat_dir: Path) -> None:
        """超过阈值时删除最老的若干文件,避免目录无限增长。"""

        def _mtime(path: Path) -> float:
            # 文件可能在 glob 与 stat 之间被并发删除,视为最老(排前面)
            try:
                return path.stat().st_mtime
            except OSError:
                return 0.0

        files = sorted(chat_dir.glob("*.json"), key=_mtime)
        if len(files) <= cls._get_max_per_chat():
            return
        # 删除最老的 _TRIM_COUNT 条
        for old_file in files[: cls._TRIM_COUNT]:
            try:
                old_file.unlink()
            except FileNotFoundError:
                continue

    @classmethod
    def _serialize_action(cls, action: Any) -> Dict[str, Any]:
        """ActionPlannerInfo 结构的轻量序列化,避免引用复杂对象。"""
        message_info = None
        action_message = getattr(action, "action_message", None)
        if action_message:
            user_info = getattr(action_message, "user_info", None)
            message_info = {
                "message_id": getattr(action_message, "message_id", None),
                "user_id": getattr(user_info, "user_id", None) if user_info else None,
                "platform": getattr(user_info, "platform", None) if user_info else None,
                "text": getattr(action_message, "processed_plain_text", None),
            }
        return {
            "action_type": getattr(action, "action_type", None),
            "reasoning": getattr(action, "reasoning", None),
            "action_data": cls._safe_data(getattr(action, "action_data", None)),
            "action_message": message_info,
            "available_actions": cls._safe_data(getattr(action, "available_actions", None)),
            "action_reasoning": getattr(action, "action_reasoning", None),
        }

    @classmethod
    def _safe_data(cls, value: Any) -> Any:
        """递归转换为 JSON 可序列化结构:容器降级为 dict/list其余复杂类型转字符串。"""
        if isinstance(value, (str, int, float, bool)) or value is None:
            return value
        if isinstance(value, dict):
            return {str(k): cls._safe_data(v) for k, v in value.items()}
        if isinstance(value, (list, tuple, set)):
            return [cls._safe_data(v) for v in value]
        if isinstance(value, Path):
            return str(value)
        # 其他复杂类型兜底转为字符串
        return str(value)

View File

@ -1,12 +1,9 @@
import random
import asyncio
import hashlib
import time
from typing import List, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
@ -35,14 +32,6 @@ class ActionModifier:
self.action_manager = action_manager
# 用于LLM判定的小模型
self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
# 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果
self._cache_expiry_time = 30 # 缓存过期时间(秒)
self._last_context_hash = None # 上次上下文的哈希值
async def modify_actions(
self,
message_content: str = "",
@ -159,9 +148,6 @@ class ActionModifier:
"""
deactivated_actions = []
# 分类处理不同激活类型的actions
llm_judge_actions: Dict[str, ActionInfo] = {}
actions_to_check = list(actions_with_info.items())
random.shuffle(actions_to_check)
@ -185,9 +171,6 @@ class ActionModifier:
deactivated_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
elif activation_type == ActionActivationType.LLM_JUDGE:
llm_judge_actions[action_name] = action_info
elif activation_type == ActionActivationType.NEVER:
reason = "激活类型为never"
deactivated_actions.append((action_name, reason))
@ -196,194 +179,8 @@ class ActionModifier:
else:
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
# 并行处理LLM_JUDGE类型
if llm_judge_actions:
llm_results = await self._process_llm_judge_actions_parallel(
llm_judge_actions,
chat_content,
)
for action_name, should_activate in llm_results.items():
if not should_activate:
reason = "LLM判定未激活"
deactivated_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
return deactivated_actions
def _generate_context_hash(self, chat_content: str) -> str:
"""生成上下文的哈希值用于缓存"""
context_content = f"{chat_content}"
return hashlib.md5(context_content.encode("utf-8")).hexdigest()
async def _process_llm_judge_actions_parallel(
self,
llm_judge_actions: Dict[str, ActionInfo],
chat_content: str = "",
) -> Dict[str, bool]:
"""
并行处理LLM判定actions支持智能缓存
Args:
llm_judge_actions: 需要LLM判定的actions
chat_content: 聊天内容
Returns:
Dict[str, bool]: action名称到激活结果的映射
"""
# 生成当前上下文的哈希值
current_context_hash = self._generate_context_hash(chat_content)
current_time = time.time()
results = {}
tasks_to_run: Dict[str, ActionInfo] = {}
# 检查缓存
for action_name, action_info in llm_judge_actions.items():
cache_key = f"{action_name}_{current_context_hash}"
# 检查是否有有效的缓存
if (
cache_key in self._llm_judge_cache
and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
):
results[action_name] = self._llm_judge_cache[cache_key]["result"]
logger.debug(
f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
)
else:
# 需要进行LLM判定
tasks_to_run[action_name] = action_info
# 如果有需要运行的任务,并行执行
if tasks_to_run:
logger.debug(f"{self.log_prefix}并行执行LLM判定任务数: {len(tasks_to_run)}")
# 创建并行任务
tasks = []
task_names = []
for action_name, action_info in tasks_to_run.items():
task = self._llm_judge_action(
action_name,
action_info,
chat_content,
)
tasks.append(task)
task_names.append(action_name)
# 并行执行所有任务
try:
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果并更新缓存
for action_name, result in zip(task_names, task_results, strict=False):
if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False
else:
results[action_name] = result
# 更新缓存
cache_key = f"{action_name}_{current_context_hash}"
self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}
logger.debug(f"{self.log_prefix}并行LLM判定完成耗时: {time.time() - current_time:.2f}s")
except Exception as e:
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
# 如果并行执行失败为所有任务返回False
for action_name in tasks_to_run:
results[action_name] = False
# 清理过期缓存
self._cleanup_expired_cache(current_time)
return results
def _cleanup_expired_cache(self, current_time: float):
"""清理过期的缓存条目"""
expired_keys = []
expired_keys.extend(
cache_key
for cache_key, cache_data in self._llm_judge_cache.items()
if current_time - cache_data["timestamp"] > self._cache_expiry_time
)
for key in expired_keys:
del self._llm_judge_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")
async def _llm_judge_action(
self,
action_name: str,
action_info: ActionInfo,
chat_content: str = "",
) -> bool: # sourcery skip: move-assign-in-block, use-named-expression
"""
使用LLM判定是否应该激活某个action
Args:
action_name: 动作名称
action_info: 动作信息
observed_messages_str: 观察到的聊天消息
chat_context: 聊天上下文
extra_context: 额外上下文
Returns:
bool: 是否应该激活此action
"""
try:
# 构建判定提示词
action_description = action_info.description
action_require = action_info.action_require
custom_prompt = action_info.llm_judge_prompt
# 构建基础判定提示词
base_prompt = f"""
你需要判断在当前聊天情况下是否应该激活名为"{action_name}"的动作
动作描述{action_description}
动作使用场景
"""
for req in action_require:
base_prompt += f"- {req}\n"
if custom_prompt:
base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"
if chat_content:
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"
base_prompt += """
请根据以上信息判断是否应该激活这个动作
只需要回答""""不要有其他内容
"""
# 调用LLM进行判定
response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)
# 解析响应
response = response.strip().lower()
# print(base_prompt)
# print(f"LLM判定动作 {action_name}:响应='{response}'")
should_activate = "" in response or "yes" in response or "true" in response
logger.debug(
f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
)
return should_activate
except Exception as e:
logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
# 出错时默认不激活
return False
def _check_keyword_activation(
self,
action_name: str,

View File

@ -10,6 +10,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
@ -17,7 +18,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
@ -52,7 +53,7 @@ reply
4.不要选择回复你自己发送的消息
5.不要单独对表情包进行回复
6.将上下文中所有含义不明的疑似黑话的缩写词均写入unknown_words中
7.用一句简单的话来描述当前回复场景不超过10个字
7.如果你对上下文存在疑问有需要查询的问题写入question中
{reply_action_example}
no_reply
@ -223,6 +224,25 @@ class ActionPlanner:
else:
reasoning = "未提供原因"
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
# 验证和清理 question
if "question" in action_data:
q = action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
action_data["question"] = cleaned_q
else:
# 如果清理后为空字符串,移除该字段
action_data.pop("question", None)
elif q is None:
# 如果为 None移除该字段
action_data.pop("question", None)
else:
# 如果不是字符串类型,记录警告并移除
logger.warning(f"{self.log_prefix}question 格式不正确,应为字符串类型,已忽略")
action_data.pop("question", None)
# 非no_reply动作需要target_message_id
target_message = None
@ -291,11 +311,9 @@ class ActionPlanner:
return action_planner_infos
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
"""判断消息是否由机器人自身发送"""
"""判断消息是否由机器人自身发送(支持多平台,包括 WebUI"""
try:
return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
message.user_info.platform or ""
) == (global_config.bot.platform or "")
return is_bot_self(message.user_info.platform or "", str(message.user_info.user_id))
except AttributeError:
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
return False
@ -310,6 +328,7 @@ class ActionPlanner:
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
"""
plan_start = time.perf_counter()
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@ -345,6 +364,7 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
prompt_build_start = time.perf_counter()
# 构建包含所有动作的提示词
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat,
@ -353,9 +373,10 @@ class ActionPlanner:
chat_content_block=chat_content_block,
message_id_list=message_id_list,
)
prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000
# 调用LLM获取决策
reasoning, actions = await self._execute_main_planner(
reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
@ -397,6 +418,25 @@ class ActionPlanner:
self.add_plan_log(reasoning, actions)
try:
PlanReplyLogger.log_plan(
chat_id=self.chat_id,
prompt=prompt,
reasoning=reasoning,
raw_output=llm_raw_output,
raw_reasoning=llm_reasoning,
actions=actions,
timing={
"prompt_build_ms": round(prompt_build_ms, 2),
"llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
"total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
"loop_start_time": loop_start_time,
},
extra=None,
)
except Exception:
logger.exception(f"{self.log_prefix}记录plan日志失败")
return actions
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
@ -480,19 +520,34 @@ class ActionPlanner:
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
# 根据 think_mode 配置决定 reply action 的示例 JSON
# 在 JSON 中直接作为 action 参数携带 unknown_words
# 在 JSON 中直接作为 action 参数携带 unknown_words 和 question
if global_config.chat.think_mode == "classic":
reply_action_example = (
reply_action_example = ""
if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"]}}'
'"unknown_words":["词语1","词语2"], '
'"question":"需要查询的问题"'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
reply_action_example += "}"
else:
reply_action_example = (
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
+ '{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"]}}'
)
if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"], '
'"question":"需要查询的问题"'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
reply_action_example += "}"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
@ -547,7 +602,7 @@ class ActionPlanner:
if action_info.activation_type == ActionActivationType.NEVER:
logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
continue
elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
elif action_info.activation_type == ActionActivationType.ALWAYS:
filtered_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.RANDOM:
if random.random() < action_info.random_activation_probability:
@ -610,14 +665,19 @@ class ActionPlanner:
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> Tuple[str, List[ActionPlannerInfo]]:
) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
llm_reasoning = None
llm_duration_ms = None
try:
# 调用LLM
llm_start = time.perf_counter()
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
llm_reasoning = reasoning_content
if global_config.debug.show_planner_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
@ -640,7 +700,7 @@ class ActionPlanner:
action_message=None,
available_actions=available_actions,
)
]
], llm_content, llm_reasoning, llm_duration_ms
# 解析LLM响应
extracted_reasoning = ""
@ -685,7 +745,7 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning, actions
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply"""

View File

@ -16,7 +16,7 @@ from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, Message
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
@ -31,6 +31,7 @@ from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
@ -74,6 +75,7 @@ class DefaultReplyer:
reply_time_point: Optional[float] = time.time(),
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@ -92,6 +94,9 @@ class DefaultReplyer:
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
"""
overall_start = time.perf_counter()
prompt_duration_ms: Optional[float] = None
llm_duration_ms: Optional[float] = None
prompt = None
selected_expressions: Optional[List[int]] = None
llm_response = LLMGenerationDataModel()
@ -101,6 +106,7 @@ class DefaultReplyer:
# 3. 构建 Prompt
timing_logs = []
almost_zero_str = ""
prompt_start = time.perf_counter()
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt, selected_expressions, timing_logs, almost_zero_str = await self.build_prompt_reply_context(
extra_info=extra_info,
@ -113,11 +119,37 @@ class DefaultReplyer:
think_level=think_level,
unknown_words=unknown_words,
)
prompt_duration_ms = (time.perf_counter() - prompt_start) * 1000
llm_response.prompt = prompt
llm_response.selected_expressions = selected_expressions
llm_response.timing = {
"prompt_ms": round(prompt_duration_ms or 0.0, 2),
"overall_ms": None, # 占位,稍后写入
}
llm_response.timing_logs = timing_logs
llm_response.timing["timing_logs"] = timing_logs
if not prompt:
logger.warning("构建prompt失败跳过回复生成")
llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
llm_response.timing["almost_zero"] = almost_zero_str
llm_response.timing["timing_logs"] = timing_logs
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
prompt="",
output=None,
processed_output=None,
model=None,
timing=llm_response.timing,
reasoning=None,
think_level=think_level,
error="build_prompt_failed",
success=False,
)
except Exception:
logger.exception("记录reply日志失败")
return False, llm_response
from src.plugin_system.core.events_manager import events_manager
@ -137,7 +169,9 @@ class DefaultReplyer:
model_name = "unknown_model"
try:
llm_start = time.perf_counter()
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
llm_duration_ms = (time.perf_counter() - llm_start) * 1000
# logger.debug(f"replyer生成内容: {content}")
# 统一输出所有日志信息使用try-except确保即使某个步骤出错也能输出
@ -161,6 +195,26 @@ class DefaultReplyer:
llm_response.reasoning = reasoning_content
llm_response.model = model_name
llm_response.tool_calls = tool_call
llm_response.timing["llm_ms"] = round(llm_duration_ms or 0.0, 2)
llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
llm_response.timing_logs = timing_logs
llm_response.timing["timing_logs"] = timing_logs
llm_response.timing["almost_zero"] = almost_zero_str
try:
if log_reply:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
prompt=prompt,
output=content,
processed_output=None,
model=model_name,
timing=llm_response.timing,
reasoning=reasoning_content,
think_level=think_level,
success=True,
)
except Exception:
logger.exception("记录reply日志失败")
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
)
@ -194,6 +248,27 @@ class DefaultReplyer:
except Exception as log_e:
logger.warning(f"输出日志时出错: {log_e}")
llm_response.timing["llm_ms"] = round(llm_duration_ms or 0.0, 2)
llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
llm_response.timing_logs = timing_logs
llm_response.timing["timing_logs"] = timing_logs
llm_response.timing["almost_zero"] = almost_zero_str
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
prompt=prompt or "",
output=None,
processed_output=None,
model=model_name,
timing=llm_response.timing,
reasoning=None,
think_level=think_level,
error=str(llm_e),
success=False,
)
except Exception:
logger.exception("记录reply日志失败")
return False, llm_response # LLM 调用失败则无法生成回复
return True, llm_response
@ -541,107 +616,6 @@ class DefaultReplyer:
logger.error(f"上下文黑话解释失败: {e}")
return ""
def build_chat_history_prompts(
    self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> str:
    """Build a readable background-dialogue prompt from recent chat history.

    Fix: the previous annotation declared ``Tuple[str, str]`` but the function
    returns a single string (and callers consume it as one); the annotation
    now matches the actual behavior.

    Note: ``target_user_id`` and ``sender`` are currently unused; they are
    kept in the signature for backward compatibility with existing callers.

    Args:
        message_list_before_now: full message history, oldest first.
        target_user_id: target user's id (unused).
        sender: display name of the target user (unused).

    Returns:
        str: formatted background-dialogue prompt, or "" when there is no history.
    """
    all_dialogue_prompt = ""
    if message_list_before_now:
        # Bound prompt length: keep only the most recent max_context_size messages.
        latest_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
        all_dialogue_prompt = build_readable_messages(
            latest_msgs,
            replace_bot_name=True,
            timestamp_mode="normal_no_YMD",
            truncate=True,
        )
    return all_dialogue_prompt
def core_background_build_chat_history_prompts(
    self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
    """Split chat history into a "core" prompt (bot <-> target dialogue)
    and a "background" prompt (all users' messages).

    Args:
        message_list_before_now: full message history, oldest first.
        target_user_id: id of the user currently being talked to.
        sender: display name of the target user, embedded in the core header.

    Returns:
        Tuple[str, str]: (core dialogue prompt, background dialogue prompt);
        either may be "" when there is nothing to show.
    """
    bot_id = str(global_config.bot.qq_account)

    # Collect the bot<->target dialogue: messages sent by the target user,
    # plus bot messages that reply to the target user.
    core_dialogue_list: List[DatabaseMessages] = []
    for msg in message_list_before_now:
        try:
            sender_id = str(msg.user_info.user_id)
            _platform, replied_user_id = self._parse_reply_target(msg.reply_to)
            bot_replied_target = sender_id == bot_id and replied_user_id == target_user_id
            if bot_replied_target or sender_id == target_user_id:
                core_dialogue_list.append(msg)
        except Exception as e:
            logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")

    core_dialogue_prompt = ""
    if core_dialogue_list:
        # Only render the core section when the bot itself spoke within the
        # five most recent core messages; otherwise leave it empty.
        recent_core = core_dialogue_list[-5:]
        if any(str(m.user_info.user_id) == bot_id for m in recent_core):
            # Cap the core dialogue at 60% of the context window.
            trimmed_core = core_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
            core_dialogue_prompt_str = build_readable_messages(
                trimmed_core,
                replace_bot_name=True,
                timestamp_mode="normal_no_YMD",
                read_mark=0.0,
                truncate=True,
                show_actions=True,
            )
            core_dialogue_prompt = f"""--------------------------------
这是上述中你和{sender}的对话摘要内容从上面的对话中截取便于你理解
{core_dialogue_prompt_str}
--------------------------------
"""

    # Background prompt: the latest max_context_size messages from everyone.
    all_dialogue_prompt = ""
    if message_list_before_now:
        window = message_list_before_now[-int(global_config.chat.max_context_size) :]
        all_dialogue_prompt_str = build_readable_messages(
            window,
            replace_bot_name=True,
            timestamp_mode="normal_no_YMD",
            truncate=True,
        )
        # Prefix the background section only when a core section exists,
        # so the two blocks read as distinct parts of the prompt.
        if core_dialogue_prompt:
            all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
        else:
            all_dialogue_prompt = f"{all_dialogue_prompt_str}"

    return core_dialogue_prompt, all_dialogue_prompt
async def build_actions_prompt(
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str:
@ -841,10 +815,8 @@ class DefaultReplyer:
person_list_short: List[Person] = []
for msg in message_list_before_short:
if (
global_config.bot.qq_account == msg.user_info.user_id
and global_config.bot.platform == msg.user_info.platform
):
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
continue
if (
reply_message
@ -865,6 +837,7 @@ class DefaultReplyer:
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
long_time_notice=True,
)
# 统一黑话解释构建:根据配置选择上下文或 Planner 模式
@ -872,6 +845,18 @@ class DefaultReplyer:
chat_id, message_list_before_short, chat_talking_prompt_short, unknown_words
)
# 从 chosen_actions 中提取 question仅在 reply 动作中)
question = None
if chosen_actions:
for action_info in chosen_actions:
if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
q = action_info.action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
question = cleaned_q
break
# 并行执行构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
@ -886,7 +871,7 @@ class DefaultReplyer:
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level, unknown_words=unknown_words, question=question
),
"memory_retrieval",
),
@ -960,8 +945,16 @@ class DefaultReplyer:
else:
reply_target_block = ""
# 构建分离的对话 prompt
dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)
if message_list_before_now_long:
latest_msgs = message_list_before_now_long[-int(global_config.chat.max_context_size) :]
dialogue_prompt = build_readable_messages(
latest_msgs,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=True,
long_time_notice=True,
)
# 获取匹配的额外prompt
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)

View File

@ -16,7 +16,7 @@ from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, Message
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
@ -76,6 +76,7 @@ class PrivateReplyer:
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@ -109,6 +110,7 @@ class PrivateReplyer:
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
unknown_words=unknown_words,
)
llm_response.prompt = prompt
llm_response.selected_expressions = selected_expressions
@ -610,6 +612,7 @@ class PrivateReplyer:
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
unknown_words: Optional[List[str]] = None,
) -> Tuple[str, List[int]]:
"""
构建回复器上下文
@ -664,6 +667,7 @@ class PrivateReplyer:
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
long_time_notice=True
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
@ -675,10 +679,8 @@ class PrivateReplyer:
person_list_short: List[Person] = []
for msg in message_list_before_short:
if (
global_config.bot.qq_account == msg.user_info.user_id
and global_config.bot.platform == msg.user_info.platform
):
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
continue
if (
reply_message
@ -708,12 +710,24 @@ class PrivateReplyer:
else:
jargon_coroutine = self._build_disabled_jargon_explanation()
# 从 chosen_actions 中提取 question仅在 reply 动作中)
question = None
if chosen_actions:
for action_info in chosen_actions:
if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
q = action_info.action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
question = cleaned_q
break
# 并行执行九个构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
),
self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
# self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
@ -722,7 +736,7 @@ class PrivateReplyer:
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=1, unknown_words=unknown_words, question=question
),
"memory_retrieval",
),
@ -759,7 +773,7 @@ class PrivateReplyer:
expression_habits_block, selected_expressions = results_dict["expression_habits"]
expression_habits_block: str
selected_expressions: List[int]
relation_info: str = results_dict["relation_info"]
relation_info: str = results_dict.get("relation_info") or ""
tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
@ -796,7 +810,19 @@ class PrivateReplyer:
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
try:
reply_style = random.choice(list(multi_styles))
except Exception:
# 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
reply_style = global_config.personality.reply_style
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(platform, user_id):
return await global_prompt_manager.format_prompt(
"private_replyer_self_prompt",
expression_habits_block=expression_habits_block,
@ -812,7 +838,7 @@ class PrivateReplyer:
target=target,
reason=reply_reason,
sender_name=sender,
reply_style=global_config.personality.reply_style,
reply_style=reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
memory_retrieval=memory_retrieval,
@ -832,7 +858,7 @@ class PrivateReplyer:
jargon_explanation=jargon_explanation,
time_block=time_block,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
reply_style=reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
sender_name=sender,
@ -917,6 +943,17 @@ class PrivateReplyer:
template_name = "default_expressor_prompt"
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
try:
reply_style = random.choice(list(multi_styles))
except Exception:
# 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
reply_style = global_config.personality.reply_style
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
@ -929,7 +966,7 @@ class PrivateReplyer:
reply_target_block=reply_target_block,
raw_reply=raw_reply,
reason=reason,
reply_style=global_config.personality.reply_style,
reply_style=reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
)

View File

@ -13,7 +13,7 @@ from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
install(extra_lines=3)
logger = get_logger("chat_message_builder")
@ -43,12 +43,9 @@ def replace_user_references(
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己(支持多平台)
if replace_bot_name:
if platform == "qq" and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
return f"{global_config.bot.nickname}(你)"
# 检查是否是机器人自己(支持多平台,包括 WebUI
if replace_bot_name and is_bot_self(platform, user_id):
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_id # type: ignore
@ -61,8 +58,8 @@ def replace_user_references(
aaa = match[1]
bbb = match[2]
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
# 检查是否是机器人自己(支持多平台,包括 WebUI
if replace_bot_name and is_bot_self(platform, bbb):
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = name_resolver(platform, bbb) or aaa
@ -370,6 +367,7 @@ def _build_readable_messages_internal(
show_pic: bool = True,
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
pic_single: bool = False,
long_time_notice: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
# sourcery skip: use-getitem-for-re-match-groups
"""
@ -467,10 +465,8 @@ def _build_readable_messages_internal(
person_name = (
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
)
if replace_bot_name and (
(platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
):
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if replace_bot_name and is_bot_self(platform, user_id):
person_name = f"{global_config.bot.nickname}(你)"
# 使用独立函数处理用户引用格式
@ -523,7 +519,30 @@ def _build_readable_messages_internal(
# 3: 格式化为字符串
output_lines: List[str] = []
prev_timestamp: Optional[float] = None
for timestamp, name, content, is_action in detailed_message:
# 检查是否需要插入长时间间隔提示
if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp
time_diff_hours = time_diff / 3600
# 检查是否跨天
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
is_cross_day = prev_date != current_date
# 如果间隔大于8小时或跨天插入提示
if time_diff_hours > 8 or is_cross_day:
# 格式化日期为中文格式xxxx年xx月xx日去掉前导零
current_time_struct = time.localtime(timestamp)
year = current_time_struct.tm_year
month = current_time_struct.tm_mon
day = current_time_struct.tm_mday
date_str = f"{year}{month}{day}"
hours_str = f"{int(time_diff_hours)}h"
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
output_lines.append(notice)
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 查找消息id如果有并构建id_prefix
@ -536,6 +555,8 @@ def _build_readable_messages_internal(
else:
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
prev_timestamp = timestamp
formatted_string = "".join(output_lines).strip()
@ -651,6 +672,7 @@ async def build_readable_messages_with_list(
show_pic=True,
message_id_list=None,
pic_single=pic_single,
long_time_notice=False,
)
if not pic_single:
@ -704,6 +726,7 @@ def build_readable_messages(
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
remove_emoji_stickers: bool = False,
pic_single: bool = False,
long_time_notice: bool = False,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式
@ -719,6 +742,7 @@ def build_readable_messages(
truncate: 是否截断长消息
show_actions: 是否显示动作记录
remove_emoji_stickers: 是否移除表情包并过滤空消息
long_time_notice: 是否在消息间隔过长>8小时或跨天时插入时间提示
"""
# WIP HERE and BELOW ----------------------------------------------
# 创建messages的深拷贝避免修改原始列表
@ -812,6 +836,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
pic_single=pic_single,
long_time_notice=long_time_notice,
)
if not pic_single:
@ -839,6 +864,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
pic_single=pic_single,
long_time_notice=long_time_notice,
)
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
messages_after_mark,
@ -850,6 +876,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
pic_single=pic_single,
long_time_notice=long_time_notice,
)
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"

View File

@ -743,13 +743,13 @@ class StatisticOutputTask(AsyncTask):
"""
if stats[TOTAL_REQ_CNT] <= 0:
return ""
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数 每次调用平均Token",
]
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
@ -764,6 +764,9 @@ class StatisticOutputTask(AsyncTask):
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 计算每次调用平均token
avg_tokens_per_call = tokens / count if count > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@ -771,6 +774,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens_per_call = _format_large_number(avg_tokens_per_call) if count > 0 else "N/A"
output.append(
data_fmt.format(
@ -784,6 +788,7 @@ class StatisticOutputTask(AsyncTask):
std_time_cost,
formatted_avg_count,
formatted_avg_tokens,
formatted_avg_tokens_per_call,
)
)
@ -797,13 +802,13 @@ class StatisticOutputTask(AsyncTask):
"""
if stats[TOTAL_REQ_CNT] <= 0:
return ""
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数 每次调用平均Token",
]
for module_name, count in sorted(stats[REQ_CNT_BY_MODULE].items()):
name = f"{module_name[:29]}..." if len(module_name) > 32 else module_name
@ -818,6 +823,9 @@ class StatisticOutputTask(AsyncTask):
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 计算每次调用平均token
avg_tokens_per_call = tokens / count if count > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@ -825,6 +833,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens_per_call = _format_large_number(avg_tokens_per_call) if count > 0 else "N/A"
output.append(
data_fmt.format(
@ -838,6 +847,7 @@ class StatisticOutputTask(AsyncTask):
std_time_cost,
formatted_avg_count,
formatted_avg_tokens,
formatted_avg_tokens_per_call,
)
)
@ -935,11 +945,12 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / count, html=True) if count > 0 else 'N/A'}</td>"
f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
]
if stat_data[REQ_CNT_BY_MODEL]
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按请求类型分类统计
type_rows = "\n".join(
@ -955,11 +966,12 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / count, html=True) if count > 0 else 'N/A'}</td>"
f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
]
if stat_data[REQ_CNT_BY_TYPE]
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 按模块分类统计
module_rows = "\n".join(
@ -975,11 +987,12 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / count, html=True) if count > 0 else 'N/A'}</td>"
f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
]
if stat_data[REQ_CNT_BY_MODULE]
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
)
# 聊天消息统计
@ -1054,7 +1067,7 @@ class StatisticOutputTask(AsyncTask):
<h2>按模型分类统计</h2>
<div class=\"table-wrap\">
<table>
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr></thead>
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr></thead>
<tbody>
{model_rows}
</tbody>
@ -1065,7 +1078,7 @@ class StatisticOutputTask(AsyncTask):
<div class=\"table-wrap\">
<table>
<thead>
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr>
</thead>
<tbody>
{module_rows}
@ -1077,7 +1090,7 @@ class StatisticOutputTask(AsyncTask):
<div class=\"table-wrap\">
<table>
<thead>
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时()</th><th>标准差()</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr>
</thead>
<tbody>
{type_rows}

View File

@ -67,6 +67,53 @@ def get_current_platform_account(platform: str, platform_accounts: dict[str, str
return platform_accounts.get(platform, "")
def is_bot_self(platform: str, user_id: str) -> bool:
    """Return True when (platform, user_id) identifies the bot itself.

    Unifies bot self-identification across all platforms — QQ, Telegram,
    WebUI, and any additional platforms listed in the ``bot.platforms``
    configuration.

    Args:
        platform: message platform, e.g. "qq", "telegram", "webui".
        user_id: user id to check (converted to str for comparison).

    Returns:
        bool: True if the id belongs to the bot on that platform, else False.
    """
    if not platform or not user_id:
        return False

    # Compare as strings so int and str ids behave identically.
    candidate = str(user_id)
    qq_account = str(global_config.bot.qq_account or "")

    # QQ compares against the primary account; WebUI replies are also sent
    # under the QQ account, so both platforms share the same check.
    if platform in ("qq", "webui"):
        return candidate == qq_account

    # Resolve per-platform accounts from the bot.platforms configuration.
    platforms_list = getattr(global_config.bot, "platforms", []) or []
    platform_accounts = parse_platform_accounts(platforms_list)

    if platform == "telegram":
        tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
        return candidate == tg_account if tg_account else False

    # Other platforms: use the configured account when one exists.
    configured = platform_accounts.get(platform, "")
    if configured:
        return candidate == configured

    # Fallback for unknown platforms: compare against the primary QQ
    # account for backward compatibility.
    return candidate == qq_account
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Any
from . import BaseDataModel
@ -17,3 +17,6 @@ class LLMGenerationDataModel(BaseDataModel):
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional["ReplySetModel"] = None
timing: Optional[Dict[str, Any]] = None
processed_output: Optional[List[str]] = None
timing_logs: Optional[List[str]] = None

View File

@ -321,18 +321,14 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
# new mode fields
context = TextField(null=True)
content_list = TextField(null=True)
style_list = TextField(null=True) # 存储相似的 style格式与 content_list 相同JSON 数组)
count = IntegerField(default=1)
last_active_time = FloatField()
chat_id = TextField(index=True)
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
checked = BooleanField(default=False) # 是否已检查
rejected = BooleanField(default=False) # 是否被拒绝但未更新
modified_by = TextField(null=True) # 最后修改来源:'ai' 或 'user',为空表示未检查
class Meta:
table_name = "expression"

View File

@ -97,6 +97,9 @@ class TaskConfig(ConfigBase):
slow_threshold: float = 15.0
"""慢请求阈值(秒),超过此值会输出警告日志"""
selection_strategy: str = field(default="balance")
"""模型选择策略balance负载均衡或 random随机选择"""
@dataclass
class ModelTaskConfig(ConfigBase):
@ -105,9 +108,6 @@ class ModelTaskConfig(ConfigBase):
utils: TaskConfig
"""组件模型配置"""
utils_small: TaskConfig
"""组件小模型配置"""
replyer: TaskConfig
"""normal_chat首要回复模型模型配置"""
@ -132,9 +132,6 @@ class ModelTaskConfig(ConfigBase):
lpmm_rdf_build: TaskConfig
"""LPMM RDF构建模型配置"""
lpmm_qa: TaskConfig
"""LPMM问答模型配置"""
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
if hasattr(self, task_name):

View File

@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.12.0"
MMC_VERSION = "0.12.1"
def get_key_comment(toml_table, key):

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Union
import types
T = TypeVar("T", bound="ConfigBase")
@ -108,6 +109,39 @@ class ConfigBase:
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理 Union/Optional 类型(包括 float | None 这种 Python 3.10+ 语法)
# 注意:
# - Optional[float] 等价于 Union[float, None]get_origin() 返回 typing.Union
# - float | None 是 types.UnionTypeget_origin() 返回 None
is_union_type = (
field_origin_type is Union # typing.Optional / typing.Union
or isinstance(field_type, types.UnionType) # Python 3.10+ 的 | 语法
)
if is_union_type:
union_args = field_type_args if field_type_args else get_args(field_type)
# 安全检查:只允许 T | None 形式的 Optional 类型,禁止 float | str 这种多类型 Union
non_none_types = [arg for arg in union_args if arg is not type(None)]
if len(non_none_types) > 1:
raise TypeError(
f"配置字段不支持多类型 Union如 float | str只支持 Optional 类型(如 float | None"
f"当前类型: {field_type}"
)
# 如果值是 None 且 None 在 Union 中,直接返回
if value is None and type(None) in union_args:
return None
# 尝试转换为非 None 的类型
for arg in union_args:
if arg is not type(None):
try:
return cls._convert_field(value, arg)
except (ValueError, TypeError):
continue
# 如果所有类型都转换失败,抛出异常
raise TypeError("Cannot convert value to any type in Union")
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None

View File

@ -122,6 +122,12 @@ class ChatConfig(ConfigBase):
- dynamic: think_level由planner动态给出根据planner返回的think_level决定
"""
plan_reply_log_max_per_chat: int = 1024
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
llm_quote: bool = False
"""是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
try:
@ -260,12 +266,32 @@ class MemoryConfig(ConfigBase):
agent_timeout_seconds: float = 120.0
"""Agent超时时间"""
enable_jargon_detection: bool = True
"""记忆检索过程中是否启用黑话识别"""
global_memory: bool = False
"""是否允许记忆检索在聊天记录中进行全局查询忽略当前chat_id仅对 search_chat_history 等工具生效)"""
global_memory_blacklist: list[str] = field(default_factory=lambda: [])
"""
全局记忆黑名单当启用全局记忆时不将特定聊天流纳入检索
格式: ["platform:id:type", ...]
示例:
[
"qq:1919810:private", # 排除特定私聊
"qq:114514:group", # 排除特定群聊
]
说明:
- 当启用全局记忆时黑名单中的聊天流不会被检索
- 当在黑名单中的聊天流进行查询时仅使用该聊天流的本地记忆
"""
planner_question: bool = True
"""
是否使用 Planner 提供的 question 作为记忆检索问题
- True: Planner reply 动作中提供了 question 直接使用该问题进行记忆检索跳过 LLM 生成问题的步骤
- False: 沿用旧模式使用 LLM 生成问题
"""
def __post_init__(self):
"""验证配置值"""
if self.max_agent_iterations < 1:
@ -303,10 +329,13 @@ class ExpressionConfig(ConfigBase):
格式: [["qq:12345:group", "qq:67890:private"]]
"""
reflect: bool = False
"""是否启用表达反思"""
expression_self_reflect: bool = False
"""是否启用自动表达优化"""
expression_manual_reflect: bool = False
"""是否启用手动表达优化"""
reflect_operator_id: str = ""
manual_reflect_operator_id: str = ""
"""表达反思操作员ID"""
allow_reflect: list[str] = field(default_factory=list)
@ -330,6 +359,34 @@ class ExpressionConfig(ConfigBase):
- "planner": 仅使用 Planner reply 动作中给出的 unknown_words 列表进行黑话检索
"""
expression_checked_only: bool = False
"""
是否仅选择已检查且未拒绝的表达方式
当设置为 true 只有 checked=True rejected=False 的表达方式才会被选择
当设置为 false 保留旧的筛选原则仅排除 rejected=True 的表达方式
"""
expression_auto_check_interval: int = 3600
"""
表达方式自动检查的间隔时间单位
默认值36001小时
"""
expression_auto_check_count: int = 10
"""
每次自动检查时随机选取的表达方式数量
默认值10
"""
expression_auto_check_custom_criteria: list[str] = field(default_factory=list)
"""
表达方式自动检查的额外自定义评估标准
格式: ["标准1", "标准2", "标准3", ...]
这些标准会被添加到评估提示词中作为额外的评估要求
默认值空列表
"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""
解析流配置字符串并生成对应的 chat_id

View File

@ -44,18 +44,6 @@ def get_random_dream_styles(count: int = 2) -> List[str]:
"""从梦境风格列表中随机选择指定数量的风格"""
return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))
def get_dream_summary_model() -> LLMRequest:
    """Return the LLMRequest instance (utils model set) used to generate dream summaries.

    Lazily constructed on first call and cached in the module-level
    ``_dream_summary_model`` global, so all callers share one instance.
    """
    global _dream_summary_model
    if _dream_summary_model is None:
        # Build once; request_type labels these calls as "dream.summary"
        # (presumably for logging/usage accounting — confirm with LLMRequest).
        _dream_summary_model = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="dream.summary",
        )
    return _dream_summary_model
def init_dream_summary_prompt() -> None:
"""初始化梦境总结的提示词"""
Prompt(
@ -186,10 +174,12 @@ async def generate_dream_summary(
)
# 调用 utils 模型生成梦境
summary_model = get_dream_summary_model()
summary_model = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="dream.summary",
)
dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
dream_prompt,
max_tokens=512,
temperature=0.8,
)

View File

@ -1,362 +0,0 @@
"""
记忆遗忘任务
每5分钟进行一次遗忘检查根据不同的遗忘阶段删除记忆
"""
import time
import random
from typing import List
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.manager.async_task_manager import AsyncTask
logger = get_logger("memory_forget_task")
class MemoryForgetTask(AsyncTask):
    """Periodic memory-forgetting task, executed every 5 minutes.

    Each forgetting stage selects ChatHistory rows whose ``forget_times``
    equals the previous stage number and whose ``end_time`` is older than a
    stage-specific threshold, deletes a fraction from both the highest-count
    and lowest-count ends of the ranking, then bumps ``forget_times`` on the
    survivors so they graduate to the next stage.

    Stages:
        1: age > 30 min,  delete top/bottom 25%
        2: age > 8 h,     delete top/bottom 7%
        3: age > 48 h,    delete top/bottom 5%
        4: age > 7 days,  delete top/bottom 2%
    """

    def __init__(self):
        # Run every 5 minutes (300 seconds) with no startup delay.
        super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)

    async def run(self):
        """Perform one forgetting pass (currently disabled — stage calls are commented out)."""
        try:
            current_time = time.time()  # kept for when the stages below are re-enabled
            # await self._forget_stage_1(current_time)
            # await self._forget_stage_2(current_time)
            # await self._forget_stage_3(current_time)
            # await self._forget_stage_4(current_time)
        except Exception as e:
            logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)

    async def _forget_stage_1(self, current_time: float):
        """Stage 1: forget_times=0, older than 30 minutes, delete top/bottom 25%."""
        await self._run_forget_stage(current_time, stage=1, min_age_seconds=1800, fraction=0.25)

    async def _forget_stage_2(self, current_time: float):
        """Stage 2: forget_times=1, older than 8 hours, delete top/bottom 7%."""
        await self._run_forget_stage(current_time, stage=2, min_age_seconds=28800, fraction=0.07)

    async def _forget_stage_3(self, current_time: float):
        """Stage 3: forget_times=2, older than 48 hours, delete top/bottom 5%."""
        await self._run_forget_stage(current_time, stage=3, min_age_seconds=172800, fraction=0.05)

    async def _forget_stage_4(self, current_time: float):
        """Stage 4: forget_times=3, older than 7 days, delete top/bottom 2%."""
        await self._run_forget_stage(current_time, stage=4, min_age_seconds=604800, fraction=0.02)

    async def _run_forget_stage(self, current_time: float, stage: int, min_age_seconds: int, fraction: float):
        """Shared implementation of one forgetting stage.

        Replaces four near-identical copy-paste methods; dedup is done by
        primary key for all stages (the old stages 2-4 used ``set()`` dedup,
        which depends on model-instance hashability).

        Args:
            current_time: current unix timestamp.
            stage: stage number (1-4); candidates must have forget_times == stage - 1.
            min_age_seconds: minimum age (seconds) before this stage applies.
            fraction: fraction deleted from EACH end of the count ranking.
        """
        tag = f"[记忆遗忘-阶段{stage}]"
        try:
            time_threshold = current_time - min_age_seconds

            # Candidates: completed the previous stage and old enough for this one.
            candidates = list(
                ChatHistory.select().where(
                    (ChatHistory.forget_times == stage - 1) & (ChatHistory.end_time < time_threshold)
                )
            )
            if not candidates:
                logger.debug(f"{tag} 没有符合条件的记忆")
                return

            logger.info(f"{tag} 找到 {len(candidates)} 条符合条件的记忆")

            # Rank by count, highest first.
            candidates.sort(key=lambda x: x.count, reverse=True)

            delete_count = int(len(candidates) * fraction)
            if delete_count == 0:
                logger.debug(f"{tag} 删除数量为0跳过")
                return

            # Pick from both ends of the ranking (ties broken randomly).
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # Deduplicate by primary key — the high/low selections may overlap
            # when the candidate pool is small.
            seen_ids = set()
            unique_to_delete = []
            for record in to_delete:
                if record.id not in seen_ids:
                    seen_ids.add(record.id)
                    unique_to_delete.append(record)
            to_delete = unique_to_delete

            # Delete selected records one by one so a single failure doesn't
            # abort the whole stage.
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"{tag} 删除记录失败: {e}")

            # Survivors graduate to this stage number in one batch update.
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=stage).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"{tag} 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为{stage}"
            )
        except Exception as e:
            logger.error(f"{tag} 执行失败: {e}", exc_info=True)

    def _handle_same_count_random(
        self, candidates: List[ChatHistory], delete_count: int, mode: str
    ) -> List[ChatHistory]:
        """Select up to ``delete_count`` records from one end of the count ranking.

        Records with equal ``count`` form a tie group; when a whole tie group
        fits in the remaining quota it is taken entirely, otherwise the needed
        number is sampled randomly from the group so ties are broken fairly.

        Args:
            candidates: candidate records, already sorted by count descending.
            delete_count: number of records to select.
            mode: "high" selects from the highest-count end, anything else
                  ("low") from the lowest-count end.

        Returns:
            The selected records (possibly fewer than delete_count if the
            candidate list is exhausted).
        """
        if not candidates or delete_count == 0:
            return []

        to_delete = []

        if mode == "high":
            # Walk forward from the highest-count records.
            start_idx = 0
            while start_idx < len(candidates) and len(to_delete) < delete_count:
                # Collect the tie group sharing the current count value.
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx < len(candidates) and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx += 1
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    # Whole tie group fits in the quota.
                    to_delete.extend(same_count_records)
                else:
                    # Quota smaller than the tie group: sample randomly.
                    to_delete.extend(random.sample(same_count_records, needed))
                start_idx = idx
        else:  # mode == "low"
            # Walk backward from the lowest-count records.
            start_idx = len(candidates) - 1
            while start_idx >= 0 and len(to_delete) < delete_count:
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx >= 0 and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx -= 1
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    to_delete.extend(same_count_records)
                else:
                    to_delete.extend(random.sample(same_count_records, needed))
                start_idx = idx

        return to_delete

View File

@ -447,8 +447,20 @@ def _default_normal_response_parser(
for call in message_part.tool_calls:
try:
arguments = json.loads(repair_json(call.function.arguments))
# 【新增修复逻辑】如果解析出来还是字符串,说明发生了双重编码,尝试二次解析
if isinstance(arguments, str):
try:
# 尝试对字符串内容再次进行修复和解析
arguments = json.loads(repair_json(arguments))
except Exception:
# 如果二次解析失败,保留原值,让下方的 isinstance(dict) 抛出更具体的错误
pass
if not isinstance(arguments, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
# 此时为了调试方便,建议打印出 arguments 的类型
raise RespParseException(
resp,
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
)
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e:
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e

View File

@ -1,6 +1,7 @@
import re
import asyncio
import time
import random
from enum import Enum
from rich.traceback import install
@ -266,7 +267,7 @@ class LLMRequest:
def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
根据配置的策略选择模型balance负载均衡 random随机选择
"""
available_models = {
model: scores
@ -276,15 +277,30 @@ class LLMRequest:
if not available_models:
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
least_used_model_name = min(
available_models,
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
)
model_info = model_config.get_model_info(least_used_model_name)
strategy = self.model_for_task.selection_strategy.lower()
if strategy == "random":
# 随机选择策略
selected_model_name = random.choice(list(available_models.keys()))
elif strategy == "balance":
# 负载均衡策略根据总tokens和惩罚值选择
selected_model_name = min(
available_models,
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
)
else:
# 默认使用负载均衡策略
logger.warning(f"未知的选择策略 '{strategy}',使用默认的负载均衡策略")
selected_model_name = min(
available_models,
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
)
model_info = model_config.get_model_info(selected_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name}")
logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
return model_info, api_provider, client

View File

@ -24,6 +24,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
from src.dream.dream_agent import start_dream_scheduler
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
# 插件系统现在使用统一的插件加载器
@ -87,16 +88,11 @@ class MainSystem:
# 添加统计信息输出任务
await async_task_manager.add_task(StatisticOutputTask())
# 添加聊天流统计任务每5分钟生成一次报告统计最近30天的数据
# await async_task_manager.add_task(TokenStatisticsTask())
# 添加遥测心跳任务
await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加记忆遗忘任务
from src.hippo_memorizer.memory_forget_task import MemoryForgetTask
await async_task_manager.add_task(MemoryForgetTask())
# 添加表达方式自动检查任务
await async_task_manager.add_task(ExpressionAutoCheckTask())
# 启动API服务器
# start_api_server()

View File

@ -15,10 +15,11 @@ from json_repair import repair_json
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import message_api
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@ -415,11 +416,11 @@ class ChatHistorySummarizer:
# 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推)
# 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言,
# 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False
for msg in messages:
if msg.user_info.user_id == bot_user_id:
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
has_bot_message = True
break
@ -848,11 +849,7 @@ class ChatHistorySummarizer:
)
try:
response, _ = await self.summarizer_llm.generate_response_async(
prompt=prompt,
temperature=0.3,
max_tokens=500,
)
response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt)
# 解析JSON响应
json_str = response.strip()
@ -912,8 +909,11 @@ class ChatHistorySummarizer:
result = _parse_with_quote_fix(extracted_json)
keywords = result.get("keywords", [])
summary = result.get("summary", "无概括")
summary = result.get("summary", "")
key_point = result.get("key_point", [])
if not (keywords and summary) and key_point:
logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少字段原文\n{response}")
# 确保keywords和key_point是列表
if isinstance(keywords, str):

View File

@ -2,7 +2,7 @@ import time
import json
import asyncio
import re
from typing import List, Dict, Any, Optional, Tuple, Set
from typing import List, Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@ -11,7 +11,8 @@ from src.common.database.database_model import ThinkingBack
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.memory_utils import parse_questions_json
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
@ -100,6 +101,7 @@ def init_memory_retrieval_prompt():
**工具说明**
- 如果涉及过往事件或者查询某个过去可能提到过的概念或者某段时间发生的事件可以使用聊天记录查询工具查询过往事件
- 如果涉及人物可以使用人物信息查询工具查询人物信息
- 如果遇到不熟悉的词语缩写黑话或网络用语可以使用query_words工具查询其含义
- 如果没有可靠信息且查询时间充足或者不确定查询类别也可以使用lpmm知识库查询作为辅助信息
**思考**
@ -202,7 +204,6 @@ async def _react_agent_solve_question(
max_iterations: int = 5,
timeout: float = 30.0,
initial_info: str = "",
initial_jargon_concepts: Optional[List[str]] = None,
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
"""使用ReAct架构的Agent来解决问题
@ -211,28 +212,29 @@ async def _react_agent_solve_question(
chat_id: 聊天ID
max_iterations: 最大迭代次数
timeout: 超时时间
initial_info: 初始信息如概念检索结果将作为collected_info的初始值
initial_jargon_concepts: 预先已解析过的黑话列表避免重复解释
initial_info: 初始信息将作为collected_info的初始值
Returns:
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
"""
start_time = time.time()
collected_info = initial_info if initial_info else ""
enable_jargon_detection = global_config.memory.enable_jargon_detection
seen_jargon_concepts: Set[str] = set()
if enable_jargon_detection and initial_jargon_concepts:
for concept in initial_jargon_concepts:
concept = (concept or "").strip()
if concept:
seen_jargon_concepts.add(concept)
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
try:
chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
except Exception:
chat_name = chat_id
react_log_prefix = f"[{chat_name}] "
thinking_steps = []
is_timeout = False
conversation_messages: List[Message] = []
first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt用于日志显示
last_tool_name: Optional[str] = None # 记录最后一次使用的工具名称
# 正常迭代max_iterations 次(最终评估单独处理,不算在迭代中)
for iteration in range(max_iterations):
# 使用 while 循环,支持额外迭代
iteration = 0
max_iterations_with_extra = max_iterations
while iteration < max_iterations_with_extra:
# 检查超时
if time.time() - start_time > timeout:
logger.warning(f"ReAct Agent超时已迭代{iteration}")
@ -475,7 +477,7 @@ async def _react_agent_solve_question(
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(
f"ReAct Agent {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
@ -488,7 +490,7 @@ async def _react_agent_solve_question(
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(
f"ReAct Agent {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
@ -497,7 +499,9 @@ async def _react_agent_solve_question(
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案"
)
_log_conversation_messages(
conversation_messages,
@ -509,12 +513,15 @@ async def _react_agent_solve_question(
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 思考完成但未调用工具: {response}"
)
collected_info += f"思考: {response}"
else:
logger.warning(f"ReAct Agent {iteration + 1} 次迭代 无工具调用且无响应")
logger.warning(f"{react_log_prefix}{iteration + 1} 次迭代 无工具调用且无响应")
step["observations"] = ["无响应且无工具调用"]
thinking_steps.append(step)
iteration += 1 # 在continue之前增加迭代计数避免跳过iteration += 1
continue
# 处理工具调用
@ -541,7 +548,7 @@ async def _react_agent_solve_question(
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(
f"ReAct Agent {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
@ -554,14 +561,16 @@ async def _react_agent_solve_question(
else:
# found_answer为True但没有提供answer视为错误
logger.warning(
f"ReAct Agent {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具判断未找到答案"
)
_log_conversation_messages(
conversation_messages,
@ -578,13 +587,16 @@ async def _react_agent_solve_question(
tool_args = tool_call.args or {}
logger.debug(
f"ReAct Agent {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
f"{react_log_prefix}{iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
)
# 跳过finish_search工具调用已经在上面处理过了
if tool_name == "finish_search":
continue
# 记录最后一次使用的工具名称(用于判断是否需要额外迭代)
last_tool_name = tool_name
# 普通工具调用
tool = tool_registry.get_tool(tool_name)
if tool:
@ -604,14 +616,18 @@ async def _react_agent_solve_question(
return f"查询{tool_name_str}({param_str})的结果:{observation}"
except Exception as e:
error_msg = f"工具执行失败: {str(e)}"
logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}")
logger.error(
f"{react_log_prefix}{iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}"
)
return f"查询{tool_name_str}失败: {error_msg}"
tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration))
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
else:
error_msg = f"未知的工具类型: {tool_name}"
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
logger.warning(
f"{react_log_prefix}{iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}"
)
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
# 并行执行所有工具
@ -622,31 +638,16 @@ async def _react_agent_solve_question(
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
if isinstance(observation, Exception):
observation = f"工具执行异常: {str(observation)}"
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
logger.error(
f"{react_log_prefix}{iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}"
)
observation_text = observation if isinstance(observation, str) else str(observation)
stripped_observation = observation_text.strip()
step["observations"].append(observation_text)
collected_info += f"\n{observation_text}\n"
if stripped_observation:
# 检查工具输出中是否有新的jargon如果有则追加到工具结果中
if enable_jargon_detection:
jargon_concepts = match_jargon_from_text(stripped_observation, chat_id)
if jargon_concepts:
new_concepts = []
for concept in jargon_concepts:
normalized_concept = concept.strip()
if normalized_concept and normalized_concept not in seen_jargon_concepts:
new_concepts.append(normalized_concept)
seen_jargon_concepts.add(normalized_concept)
if new_concepts:
jargon_info = await retrieve_concepts_with_jargon(new_concepts, chat_id)
if jargon_info:
# 将jargon查询结果追加到工具结果中
observation_text += f"\n\n{jargon_info}"
collected_info += f"\n{jargon_info}\n"
logger.info(f"工具输出触发黑话解析: {new_concepts}")
# 不再自动检测工具输出中的jargon改为通过 query_words 工具主动查询
tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(observation_text)
@ -655,15 +656,24 @@ async def _react_agent_solve_question(
thinking_steps.append(step)
# 检查是否需要额外迭代:如果最后一次使用的工具是 search_chat_history 且达到最大迭代次数,额外增加一回合
if iteration + 1 >= max_iterations and last_tool_name == "search_chat_history" and not is_timeout:
max_iterations_with_extra = max_iterations + 1
logger.info(
f"{react_log_prefix}达到最大迭代次数(已迭代{iteration + 1}次),最后一次使用工具为 search_chat_history额外增加一回合尝试"
)
iteration += 1
# 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估
# 最终评估单独处理,不算在迭代中
should_do_final_evaluation = False
if is_timeout:
should_do_final_evaluation = True
logger.warning(f"ReAct Agent超时已迭代{iteration + 1}次,进入最终评估")
elif iteration + 1 >= max_iterations:
logger.warning(f"{react_log_prefix}超时,已迭代{iteration}次,进入最终评估")
elif iteration >= max_iterations:
should_do_final_evaluation = True
logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估")
logger.info(f"{react_log_prefix}达到最大迭代次数(已迭代{iteration}次),进入最终评估")
if should_do_final_evaluation:
# 获取必要变量用于最终评估
@ -766,8 +776,8 @@ async def _react_agent_solve_question(
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}")
logger.info(f"{react_log_prefix}最终评估响应: {eval_response}")
# 从最终评估响应中提取found_answer或not_enough_info
found_answer_content = None
@ -998,7 +1008,6 @@ async def _process_single_question(
chat_id: str,
context: str,
initial_info: str = "",
initial_jargon_concepts: Optional[List[str]] = None,
max_iterations: Optional[int] = None,
) -> Optional[str]:
"""处理单个问题的查询
@ -1007,12 +1016,17 @@ async def _process_single_question(
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
initial_info: 初始信息如概念检索结果将传递给ReAct Agent
initial_jargon_concepts: 已经处理过的黑话概念列表用于ReAct阶段的去重
initial_info: 初始信息将传递给ReAct Agent
max_iterations: 最大迭代次数
Returns:
Optional[str]: 如果找到答案返回格式化的结果字符串否则返回None
"""
# 如果question为空或None直接返回None不进行查询
if not question or not question.strip():
logger.debug("问题为空,跳过查询")
return None
# logger.info(f"开始处理问题: {question}")
_cleanup_stale_not_found_thinking_back()
@ -1022,8 +1036,6 @@ async def _process_single_question(
# 直接使用ReAct Agent查询不再从thinking_back获取缓存
# logger.info(f"使用ReAct Agent查询问题: {question[:50]}...")
jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None
# 如果未指定max_iterations使用配置的默认值
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
@ -1034,7 +1046,6 @@ async def _process_single_question(
max_iterations=max_iterations,
timeout=global_config.memory.agent_timeout_seconds,
initial_info=question_initial_info,
initial_jargon_concepts=jargon_concepts_for_agent,
)
# 存储查询历史到数据库(超时时不存储)
@ -1062,6 +1073,8 @@ async def build_memory_retrieval_prompt(
target: str,
chat_stream,
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
question: Optional[str] = None,
) -> str:
"""构建记忆检索提示
使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案
@ -1071,14 +1084,33 @@ async def build_memory_retrieval_prompt(
sender: 发送者名称
target: 目标消息内容
chat_stream: 聊天流对象
tool_executor: 工具执行器保留参数以兼容接口
think_level: 思考深度等级
unknown_words: Planner 提供的未知词语列表优先使用此列表而不是从聊天记录匹配
question: Planner 提供的问题 planner_question 配置开启时直接使用此问题进行检索
Returns:
str: 记忆检索结果字符串
"""
start_time = time.time()
logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流(优先群名称/用户昵称)
try:
group_info = chat_stream.group_info
user_info = chat_stream.user_info
# 群聊优先使用群名称
if group_info is not None and getattr(group_info, "group_name", None):
stream_name = group_info.group_name.strip() or str(group_info.group_id)
# 私聊使用用户昵称
elif user_info is not None and getattr(user_info, "user_nickname", None):
stream_name = user_info.user_nickname.strip() or str(user_info.user_id)
# 兜底使用 stream_id
else:
stream_name = chat_stream.stream_id
except Exception:
stream_name = chat_stream.stream_id
log_prefix = f"[{stream_name}] " if stream_name else ""
logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
try:
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
@ -1089,66 +1121,81 @@ async def build_memory_retrieval_prompt(
if not recent_query_history:
recent_query_history = "最近没有查询记录。"
# 第一步:生成问题
question_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_question_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=message,
recent_query_history=recent_query_history,
sender=sender,
target_message=target,
)
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
question_prompt,
model_config=model_config.model_task_config.tool_use,
request_type="memory.question",
)
if global_config.debug.show_memory_prompt:
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
# logger.info(f"记忆检索问题生成响应: {response}")
if not success:
logger.error(f"LLM生成问题失败: {response}")
return ""
# 解析概念列表和问题列表
_, questions = parse_questions_json(response)
if questions:
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
enable_jargon_detection = global_config.memory.enable_jargon_detection
concepts: List[str] = []
if enable_jargon_detection:
# 使用匹配逻辑自动识别聊天中的黑话概念
concepts = match_jargon_from_text(message, chat_id)
if concepts:
logger.info(f"黑话匹配命中 {len(concepts)} 个概念: {concepts}")
# 第一步:生成问题或使用 Planner 提供的问题
single_question: Optional[str] = None
# 如果 planner_question 配置开启,只使用 Planner 提供的问题,不使用旧模式
if global_config.memory.planner_question:
if question and isinstance(question, str) and question.strip():
# 清理和验证 question
single_question = question.strip()
logger.info(f"{log_prefix}使用 Planner 提供的 question: {single_question}")
else:
logger.debug("黑话匹配未命中任何概念")
# planner_question 开启但没有提供 question跳过记忆检索
logger.debug(f"{log_prefix}planner_question 已开启但未提供 question跳过记忆检索")
end_time = time.time()
logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
return ""
else:
logger.debug("已禁用记忆检索中的黑话识别")
# planner_question 关闭使用旧模式LLM 生成问题
question_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_question_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=message,
recent_query_history=recent_query_history,
sender=sender,
target_message=target,
)
# 对匹配到的概念进行jargon检索作为初始信息
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
question_prompt,
model_config=model_config.model_task_config.tool_use,
request_type="memory.question",
)
if global_config.debug.show_memory_prompt:
logger.info(f"{log_prefix}记忆检索问题生成提示词: {question_prompt}")
# logger.info(f"记忆检索问题生成响应: {response}")
if not success:
logger.error(f"{log_prefix}LLM生成问题失败: {response}")
return ""
# 解析概念列表和问题列表,只取第一个问题
_, questions = parse_questions_json(response)
if questions and len(questions) > 0:
single_question = questions[0].strip()
logger.info(f"{log_prefix}解析到问题: {single_question}")
# 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供)
initial_info = ""
if enable_jargon_detection and concepts:
concept_info = await retrieve_concepts_with_jargon(concepts, chat_id)
if concept_info:
initial_info += concept_info
logger.debug(f"概念检索完成,结果: {concept_info}")
else:
logger.debug("概念检索未找到任何结果")
if unknown_words and len(unknown_words) > 0:
# 清理和去重 unknown_words
cleaned_concepts = []
for word in unknown_words:
if isinstance(word, str):
cleaned = word.strip()
if cleaned:
cleaned_concepts.append(cleaned)
if cleaned_concepts:
# 对匹配到的概念进行jargon检索作为初始信息
concept_info = await retrieve_concepts_with_jargon(cleaned_concepts, chat_id)
if concept_info:
initial_info += concept_info
logger.info(
f"{log_prefix}使用 Planner 提供的 unknown_words{len(cleaned_concepts)} 个概念,检索结果: {concept_info[:100]}..."
)
else:
logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果")
if not questions:
logger.debug("模型认为不需要检索记忆或解析失败,不返回任何查询结果")
if not single_question:
logger.debug(f"{log_prefix}模型认为不需要检索记忆或解析失败,不返回任何查询结果")
end_time = time.time()
logger.info(f"无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
return ""
# 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间)
# 第二步:处理问题(使用配置的最大迭代次数和超时时间)
base_max_iterations = global_config.memory.max_agent_iterations
# 根据think_level调整迭代次数think_level=1时不变think_level=0时减半
if think_level == 0:
@ -1157,32 +1204,21 @@ async def build_memory_retrieval_prompt(
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
f"{log_prefix}问题: {single_question}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
_process_single_question(
question=question,
# 处理单个问题
try:
result = await _process_single_question(
question=single_question,
chat_id=chat_id,
context=message,
initial_info=initial_info,
initial_jargon_concepts=concepts if enable_jargon_detection else None,
max_iterations=max_iterations,
)
for question in questions
]
# 并行执行所有查询任务
results = await asyncio.gather(*question_tasks, return_exceptions=True)
# 收集所有有效结果
question_results: List[str] = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}")
elif result is not None:
question_results.append(result)
except Exception as e:
logger.error(f"{log_prefix}处理问题 '{single_question}' 时发生异常: {e}")
result = None
# 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
@ -1191,39 +1227,39 @@ async def build_memory_retrieval_prompt(
all_results = []
# 先添加当前查询的结果
current_questions = set()
for result in question_results:
current_question = None
if result:
all_results.append(result)
# 提取问题(格式为 "问题xxx\n答案xxx"
if result.startswith("问题:"):
question_end = result.find("\n答案:")
if question_end != -1:
current_questions.add(result[4:question_end])
all_results.append(result)
current_question = result[4:question_end]
# 添加缓存答案(排除当前查询中已存在的问题)
# 添加缓存答案(排除当前查询的问题)
for cached_answer in cached_answers:
if cached_answer.startswith("问题:"):
question_end = cached_answer.find("\n答案:")
if question_end != -1:
cached_question = cached_answer[4:question_end]
if cached_question not in current_questions:
if cached_question != current_question:
all_results.append(cached_answer)
end_time = time.time()
if all_results:
retrieved_memory = "\n\n".join(all_results)
current_count = len(question_results)
current_count = 1 if result else 0
cached_count = len(all_results) - current_count
logger.info(
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,"
f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}秒,"
f"当前查询 {current_count} 条记忆,缓存 {cached_count} 条记忆,共 {len(all_results)} 条记忆"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
logger.debug("所有问题未找到答案,且无缓存答案")
logger.debug(f"{log_prefix}问题未找到答案,且无缓存答案")
return ""
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
logger.error(f"{log_prefix}记忆检索时发生异常: {str(e)}")
return ""

View File

@ -14,6 +14,7 @@ from .tool_registry import (
from .query_chat_history import register_tool as register_query_chat_history
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from .query_words import register_tool as register_query_words
from .found_answer import register_tool as register_finish_search
from src.config.config import global_config
@ -22,6 +23,7 @@ def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_chat_history()
register_query_person_info()
register_query_words() # 注册query_words工具
register_finish_search() # 注册finish_search工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":

View File

@ -4,7 +4,7 @@
"""
import json
from typing import Optional
from typing import Optional, Set
from datetime import datetime
from src.common.logger import get_logger
@ -16,35 +16,182 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
"""将黑名单配置platform:id:type格式转换为chat_id集合
Args:
blacklist: 黑名单配置列表格式为 ["platform:id:type", ...]
Returns:
Set[str]: chat_id集合
"""
chat_ids = set()
if not blacklist:
return chat_ids
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
for blacklist_item in blacklist:
if not isinstance(blacklist_item, str):
continue
try:
parts = blacklist_item.split(":")
if len(parts) != 3:
logger.warning(f"黑名单配置格式错误,应为 platform:id:type实际: {blacklist_item}")
continue
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
# 判断是否为群聊
is_group = stream_type == "group"
# 转换为chat_id
chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group)
if chat_id:
chat_ids.add(chat_id)
else:
logger.warning(f"无法将黑名单配置转换为chat_id: {blacklist_item}")
except Exception as e:
logger.warning(f"解析黑名单配置失败: {blacklist_item}, 错误: {e}")
except Exception as e:
logger.error(f"初始化黑名单chat_id集合失败: {e}")
return chat_ids
def _is_chat_id_in_blacklist(chat_id: str) -> bool:
    """Check whether a chat_id belongs to the global memory blacklist.

    Args:
        chat_id: The chat_id to test.

    Returns:
        bool: True when the chat_id is blacklisted, False otherwise.
    """
    configured = getattr(global_config.memory, "global_memory_blacklist", [])
    # No configured entries means nothing can be blacklisted.
    return bool(configured) and chat_id in _parse_blacklist_to_chat_ids(configured)
async def search_chat_history(
chat_id: str,
keyword: Optional[str] = None,
participant: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
chat_id: 聊天ID
keyword: 关键词可选支持多个关键词可用空格逗号等分隔匹配规则如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配
participant: 参与人昵称可选
start_time: 开始时间可选格式如'2025-01-01' '2025-01-01 12:00:00' '2025/01/01'如果只提供start_time查询该时间点之后的记录
end_time: 结束时间可选格式如'2025-01-01' '2025-01-01 12:00:00' '2025/01/01'如果只提供end_time查询该时间点之前的记录如果同时提供start_time和end_time查询该时间段内的记录
Returns:
str: 查询结果包含记忆idtheme和keywords
"""
try:
# 检查参数
if not keyword and not participant:
return "未指定查询参数需要提供keyword或participant之一"
if not keyword and not participant and not start_time and not end_time:
return "未指定查询参数需要提供keyword、participant、start_time或end_time之一"
# 解析时间参数
start_timestamp = None
end_timestamp = None
if start_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
if end_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
# 验证时间范围
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
return "开始时间不能晚于结束时间"
# 构建查询条件
# 检查当前chat_id是否在黑名单中
is_current_chat_in_blacklist = _is_chat_id_in_blacklist(chat_id)
# 根据配置决定是否限制在当前 chat_id 内查询
use_global_search = global_config.memory.global_memory
# 如果当前chat_id在黑名单中强制使用本地查询
use_global_search = global_config.memory.global_memory and not is_current_chat_in_blacklist
if use_global_search:
# 全局查询所有聊天记录
query = ChatHistory.select()
logger.debug(
f"search_chat_history 启用全局查询模式,忽略 chat_id 过滤keyword={keyword}, participant={participant}"
)
# 全局查询所有聊天记录,但排除黑名单中的聊天流
blacklist_chat_ids = _parse_blacklist_to_chat_ids(global_config.memory.global_memory_blacklist)
if blacklist_chat_ids:
# 排除黑名单中的chat_id
query = ChatHistory.select().where(~(ChatHistory.chat_id.in_(blacklist_chat_ids)))
logger.debug(
f"search_chat_history 启用全局查询模式(排除黑名单 {len(blacklist_chat_ids)} 个聊天流keyword={keyword}, participant={participant}"
)
else:
# 没有黑名单,查询所有
query = ChatHistory.select()
logger.debug(
f"search_chat_history 启用全局查询模式,忽略 chat_id 过滤keyword={keyword}, participant={participant}"
)
else:
# 仅在当前聊天流内查询
if is_current_chat_in_blacklist:
logger.debug(
f"search_chat_history 当前聊天流在黑名单中强制使用本地查询chat_id={chat_id}, keyword={keyword}, participant={participant}"
)
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 添加时间过滤条件
if start_timestamp is not None and end_timestamp is not None:
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where(
(
(ChatHistory.start_time >= start_timestamp)
& (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内
| (
(ChatHistory.end_time >= start_timestamp)
& (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内
| (
(ChatHistory.start_time <= start_timestamp)
& (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段
)
logger.debug(
f"search_chat_history 添加时间范围过滤: {start_timestamp} - {end_timestamp}, keyword={keyword}, participant={participant}"
)
elif start_timestamp is not None:
# 只提供开始时间,查询该时间点之后的记录(记录的开始时间或结束时间在该时间点之后)
query = query.where(ChatHistory.end_time >= start_timestamp)
logger.debug(
f"search_chat_history 添加开始时间过滤: >= {start_timestamp}, keyword={keyword}, participant={participant}"
)
elif end_timestamp is not None:
# 只提供结束时间,查询该时间点之前的记录(记录的开始时间或结束时间在该时间点之前)
query = query.where(ChatHistory.start_time <= end_timestamp)
logger.debug(
f"search_chat_history 添加结束时间过滤: <= {end_timestamp}, keyword={keyword}, participant={participant}"
)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
@ -134,21 +281,31 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
filtered_records.append(record)
if not filtered_records:
if keyword and participant:
keywords_str = "".join(parse_keywords_string(keyword) if keyword else [])
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
elif keyword:
# 构建查询条件描述
conditions = []
if keyword:
keywords_str = "".join(parse_keywords_string(keyword))
keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
return f"未找到参与人包含'{participant}'的聊天记录"
conditions.append(f"关键词'{keywords_str}'")
if participant:
conditions.append(f"参与人'{participant}'")
if start_timestamp or end_timestamp:
time_desc = ""
if start_timestamp and end_timestamp:
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间范围'{start_str}''{end_str}'"
elif start_timestamp:
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间>='{start_str}'"
elif end_timestamp:
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间<='{end_str}'"
if time_desc:
conditions.append(time_desc)
if conditions:
conditions_str = "".join(conditions)
return f"未找到满足条件({conditions_str})的聊天记录"
else:
return "未找到相关聊天记录"
@ -336,7 +493,7 @@ def register_tool():
# 注册工具1搜索记忆
register_memory_retrieval_tool(
name="search_chat_history",
description="根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配容错匹配",
description="根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配容错匹配支持按时间点或时间段进行查询。",
parameters=[
{
"name": "keyword",
@ -350,6 +507,18 @@ def register_tool():
"description": "参与人昵称(可选),用于查询包含该参与人的记忆",
"required": False,
},
{
"name": "start_time",
"type": "string",
"description": "开始时间(可选),格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供start_time查询该时间点之后的记录。如果同时提供start_time和end_time查询该时间段内的记录",
"required": False,
},
{
"name": "end_time",
"type": "string",
"description": "结束时间(可选),格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供end_time查询该时间点之前的记录。如果同时提供start_time和end_time查询该时间段内的记录",
"required": False,
},
],
execute_func=search_chat_history,
)

View File

@ -0,0 +1,79 @@
"""
查询黑话/概念含义 - 工具实现
用于在记忆检索过程中主动查询未知词语或黑话的含义
"""
from src.common.logger import get_logger
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_words(chat_id: str, words: str) -> str:
    """Look up the meaning of one or more words / slang terms.

    Args:
        chat_id: Chat ID used to scope the jargon lookup.
        words: Word(s) to query; multiple words may be separated by commas
            (ASCII or full-width) and/or whitespace in any combination.

    Returns:
        str: The explanation text, or a human-readable message when the input
        is invalid or nothing is found.
    """
    import re

    try:
        if not words or not words.strip():
            return "未提供要查询的词语"

        # Split on any mix of commas (ASCII/full-width) and whitespace in one pass.
        # The previous first-matching-separator loop could not handle mixed
        # separators, and a garbled empty-string separator made `sep in words`
        # always true, so str.split("") raised ValueError for every query.
        parts = [w.strip() for w in re.split(r"[,，\s]+", words) if w.strip()]

        # De-duplicate while preserving first-seen order.
        seen = set()
        unique_words = []
        for word in parts:
            if word not in seen:
                unique_words.append(word)
                seen.add(word)

        if not unique_words:
            return "未提供有效的词语"

        logger.info(f"查询词语含义: {unique_words}")

        # Delegate the actual lookup to the jargon retrieval backend.
        result = await retrieve_concepts_with_jargon(unique_words, chat_id)
        if result:
            return result
        return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
    except Exception as e:
        logger.error(f"查询词语含义失败: {e}")
        return f"查询失败: {str(e)}"
def register_tool():
    """Register the query_words tool with the memory-retrieval tool registry."""
    # Single string parameter accepting one or several words.
    words_param = {
        "name": "words",
        "type": "string",
        "description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS''YYDS,内卷,996'",
        "required": True,
    }
    register_memory_retrieval_tool(
        name="query_words",
        description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
        parameters=[words_param],
        execute_func=query_words,
    )

View File

@ -19,7 +19,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("person_info")
relation_selection_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation_selection"
model_set=model_config.model_task_config.tool_use, request_type="relation_selection"
)
@ -228,8 +228,59 @@ class Person:
return person
def _is_bot_self(self, platform: str, user_id: str) -> bool:
"""判断给定的平台和用户ID是否是机器人自己
这个函数统一处理所有平台包括 QQTelegramWebUI 的机器人识别逻辑
Args:
platform: 消息平台 "qq", "telegram", "webui"
user_id: 用户ID
Returns:
bool: 如果是机器人自己则返回 True否则返回 False
"""
if not platform or not user_id:
return False
# 将 user_id 转为字符串进行比较
user_id_str = str(user_id)
# 获取机器人的 QQ 账号(主账号)
qq_account = str(global_config.bot.qq_account or "")
# QQ 平台:直接比较 QQ 账号
if platform == "qq":
return user_id_str == qq_account
# WebUI 平台:机器人回复时使用的是 QQ 账号,所以也比较 QQ 账号
if platform == "webui":
return user_id_str == qq_account
# 获取各平台账号映射
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = {}
for platform_entry in platforms_list:
if ":" in platform_entry:
platform_name, account = platform_entry.split(":", 1)
platform_accounts[platform_name.strip()] = account.strip()
# Telegram 平台
if platform == "telegram":
tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
return user_id_str == tg_account if tg_account else False
# 其他平台:尝试从 platforms 配置中查找
platform_account = platform_accounts.get(platform, "")
if platform_account:
return user_id_str == platform_account
# 默认情况:与主 QQ 账号比较(兼容性)
return user_id_str == qq_account
def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
# 使用统一的机器人识别函数(支持多平台,包括 WebUI
if self._is_bot_self(platform, user_id):
self.is_known = True
self.person_id = get_person_id(platform, user_id)
self.user_id = user_id

View File

@ -56,7 +56,6 @@ from .apis import (
person_api,
plugin_manage_api,
send_api,
auto_talk_api,
register_plugin,
get_logger,
)

View File

@ -19,7 +19,6 @@ from src.plugin_system.apis import (
send_api,
tool_api,
frequency_api,
auto_talk_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@ -41,5 +40,4 @@ __all__ = [
"register_plugin",
"tool_api",
"frequency_api",
"auto_talk_api",
]

View File

@ -1,56 +0,0 @@
from src.common.logger import get_logger
logger = get_logger("auto_talk_api")
def set_question_probability_multiplier(chat_id: str, multiplier: float) -> bool:
    """Set the proactive-speech probability multiplier for a chat.

    Args:
        chat_id: Target chat stream ID.
        multiplier: Non-negative multiplier value; negatives are clamped to 0.

    Returns:
        bool: True only when the target chat is an existing group-chat
        (HeartFChatting) instance that supports the multiplier attribute.
    """
    try:
        if not isinstance(chat_id, str):
            raise TypeError("chat_id 必须是 str")
        if not isinstance(multiplier, (int, float)):
            raise TypeError("multiplier 必须是数值类型")

        # Deferred import to avoid a circular dependency.
        from src.chat.heart_flow.heartflow import heartflow as _heartflow

        chat = _heartflow.heartflow_chat_list.get(chat_id)
        if chat is None:
            logger.warning(f"未找到 chat_id={chat_id} 的心流实例,无法设置乘数")
            return False

        # Duck-typing check: only heart-flow instances exposing the attribute
        # (group chats) qualify; avoids importing the concrete class.
        if not hasattr(chat, "question_probability_multiplier"):
            logger.warning(f"chat_id={chat_id} 实例不支持主动发言乘数设置")
            return False

        # Constraint: negative values are not allowed.
        value = float(multiplier)
        if value < 0:
            value = 0.0

        chat.question_probability_multiplier = value
        logger.info(f"[auto_talk_api] chat_id={chat_id} 主动发言乘数已设为 {value}")
        return True
    except Exception as e:
        logger.error(f"设置主动发言乘数失败: {e}")
        return False
def get_question_probability_multiplier(chat_id: str) -> float:
    """Return the proactive-speech probability multiplier for a chat, or 0.0 when unavailable."""
    try:
        # Deferred import to avoid a circular dependency.
        from src.chat.heart_flow.heartflow import heartflow as _heartflow

        chat = _heartflow.heartflow_chat_list.get(chat_id)
        if chat is None:
            return 0.0
        return float(getattr(chat, "question_probability_multiplier", 0.0))
    except Exception:
        # Best-effort accessor: any failure (including import errors) yields 0.
        return 0.0

View File

@ -9,6 +9,7 @@
"""
import traceback
import time
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
@ -19,6 +20,7 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
from src.chat.logger.plan_reply_logger import PlanReplyLogger
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
@ -118,6 +120,10 @@ async def generate_reply(
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try:
# 如果 reply_time_point 未传入,设置为当前时间戳
if reply_time_point is None:
reply_time_point = time.time()
# 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
@ -152,21 +158,42 @@ async def generate_reply(
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
unknown_words=unknown_words,
unknown_words=unknown_words,
think_level=think_level,
from_plugin=from_plugin,
stream_id=chat_stream.stream_id if chat_stream else chat_id,
reply_time_point=reply_time_point,
log_reply=False,
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
return False, None
reply_set: Optional[ReplySetModel] = None
if content := llm_response.content:
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response
reply_set = ReplySetModel()
for text in processed_response:
reply_set.add_text_content(text)
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
# 统一在这里记录最终回复日志(包含分割后的 processed_output
try:
PlanReplyLogger.log_reply(
chat_id=chat_stream.stream_id if chat_stream else (chat_id or ""),
prompt=llm_response.prompt or "",
output=llm_response.content,
processed_output=llm_response.processed_output,
model=llm_response.model,
timing=llm_response.timing,
reasoning=llm_response.reasoning,
think_level=think_level,
success=True,
)
except Exception:
logger.exception("[GeneratorAPI] 记录reply日志失败")
return success, llm_response
except ValueError as ve:

View File

@ -12,7 +12,7 @@ import time
from typing import List, Dict, Any, Tuple, Optional
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Images
from src.config.config import global_config
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
@ -511,7 +511,8 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
Returns:
过滤后的消息列表
"""
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
def translate_pid_to_description(pid: str) -> str:

View File

@ -28,7 +28,6 @@ class BaseAction(ABC):
- keyword_case_sensitive: 关键词是否区分大小写
- parallel_action: 是否允许并行执行
- random_activation_probability: 随机激活概率
- llm_judge_prompt: LLM判断提示词
"""
def __init__(
@ -81,8 +80,6 @@ class BaseAction(ABC):
"""激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "") # 已弃用
"""协助LLM进行判断的Prompt"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
@ -504,7 +501,6 @@ class BaseAction(ABC):
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),
action_require=getattr(cls, "action_require", []).copy(),

View File

@ -33,7 +33,6 @@ class ActionActivationType(Enum):
NEVER = "never" # 从不激活(默认关闭)
ALWAYS = "always" # 默认参与到planner
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner
@ -128,7 +127,6 @@ class ActionInfo(ComponentInfo):
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
keyword_case_sensitive: bool = False
# 模式和并行设置

View File

@ -28,7 +28,7 @@
"type": "action",
"name": "tts_action",
"description": "将文本转换为语音进行播放",
"activation_modes": ["llm_judge", "keyword"],
"activation_modes": ["keyword"],
"keywords": ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
}
],

Some files were not shown because too many files have changed in this diff Show More