mirror of https://github.com/Mai-with-u/MaiBot.git
commit 0d685806a7
@@ -46,7 +46,7 @@

 ## 🔥 Updates and Installation

-**Latest version: v0.11.6** ([Changelog](changelogs/changelog.md))
+**Latest version: v0.12.0** ([Changelog](changelogs/changelog.md))

 The latest version can be downloaded from the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page
bot.py
@@ -59,6 +59,7 @@ def run_runner_process():
     while True:
         logger.info(f"Starting {script_file}...")
+        logger.info("Compiling shaders: 1/114514")

         # Launch the subprocess (worker)
         # Use sys.executable to ensure the same Python interpreter is used
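The hunk above touches a supervisor loop that restarts a worker script with the parent's own interpreter. A minimal self-contained sketch of that pattern (illustrative only; the function and variable names below are not from `bot.py`):

```python
import subprocess
import sys


def run_once(script_file: str) -> int:
    """Run a worker script with the same Python interpreter as the parent.

    Mirrors the sys.executable approach referenced in the diff; this is an
    illustrative sketch, not MaiBot's actual runner.
    """
    proc = subprocess.run([sys.executable, script_file])
    return proc.returncode
```

Using `sys.executable` rather than a bare `python` avoids picking up a different interpreter from `PATH` (e.g. outside the current virtualenv).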
changelogs/changelog.md
@@ -1,6 +1,18 @@

# Changelog

## [0.12.1] - 2025-12-31
### 🌟 Major updates
- Added a yearly summary! Viewable in the WebUI
- Optionally let the LLM decide whether to quote a message when replying
- Expression improvements! Automatic and manual evaluation now make expressions more accurate
- Reply and planning records! The WebUI can show the details of every reply and plan

## [0.12.0] - 2025-12-16
### Detail changes
- Improved the display of messages separated by overly long intervals
- enable_jargon_detection
- global_memory_blacklist: exclude selected group chats from global memory
- Removed the utils_small model and the deprecated lpmm model

## [0.12.0] - 2025-12-21
### 🌟 Major updates
- Added a thinking-effort mechanism that dynamically controls reply time and length
- The planner and replyer are now coordinated for better reply logic
@@ -9,12 +21,32 @@

- The mcp plugin is now bundled as a built-in plugin, disabled by default
- Added a global-memory configuration option; memory can now be made global

### 🌟 Major WebUI updates
- **Model preset marketplace finalized and released**: model configurations can now be shared in full; the share button is at the top right of the model configuration page
- **Comprehensive security hardening**: authentication added to all WebUI APIs and WebSocket endpoints; cookies now carry Secure and SameSite attributes, with environment-aware dynamic configuration
- **Frontend authentication rework**: migrated from localStorage to HttpOnly cookies; added a temporary WebSocket token mechanism to solve cookies not being sent in cross-origin development environments
- **Improved plugin configuration management**: raw TOML configs can be loaded and saved; the frontend can view and edit plugin configuration source files

### Detail changes
- Removed automatic frequency adjustment
- Removed the mood feature
- Improved memory-related timeout settings
- Fixed a bug with configuration values set to 0
- The branch to clone can now be chosen when installing a plugin
- Feedback questionnaire on the home page for submitting feedback and suggestions
- Jargon and expression extraction no longer captures content containing names
- The model page supports editing extra params fields
- The task-assignment subpage of the model page supports editing the slow-request detection threshold
- The model page supports per-model temperature and max tokens parameters
- All data cards on the home page auto-select units and can show detailed information
- WebUI chat room support for stickers, images, and rich-text messages
- The working-mode section of the MaiBot adapter configuration page is collapsible
- WebUI plugin configuration parsing supports dynamic list forms
- Dynamic lists in WebUI plugin configuration support toggle, slider, and dropdown types
- Added restart buttons to the plugin marketplace and plugin detail pages
- Strengthened security and privacy: added rate limiting to the login endpoint to prevent brute-force attacks; tightened CORS configuration (restricted allowed HTTP methods and request headers); improved path validation (validate_safe_path against directory traversal); fetchWithAuth supports FormData file uploads; added a robots.txt route and X-Robots-Tag response header, plus frontend meta noindex/nofollow tags, to keep search engines from indexing the site
- Fixed and polished issues in the chat room, model configuration, log viewer, jargon management, WebUI port conflicts, the setup wizard, home-page charts, duplicated chat-room messages, logs invisible on mobile, model provider deletion, line endings in the main config, the HTTP warning banner, the restart screen, LPMM configuration, persona info, plugin endpoint authentication, and WebSocket tokens, improving overall stability and experience
- Completed the rework of the main and model configuration pages; reworked model provider and MaiBot adapter configuration (introducing TOML validation); extracted WebSocket authentication into a shared module to unify WS endpoints; upgraded React to 19.2.1 and refreshed dependencies; migrated all WebUI settings and visualization into the main and model configs; and improved update prompts, the plugin detail page, path safety checks, and visualization and auto-validation for model, dream, and other configs

## [0.11.6] - 2025-12-2
### 🌟 Major updates
@@ -85,7 +85,6 @@ Action uses a **two-layer decision mechanism** to optimize performance and decision quality:

| ----------- | ---------------------------------------- | ---------------------- |
| [`NEVER`](#never-激活) | Never activated; the Action is invisible to MaiBot | Temporarily disabling an Action |
| [`ALWAYS`](#always-激活) | Always activated; the Action is always in MaiBot's candidate pool | Core features such as reply / no-reply |
| [`LLM_JUDGE`](#llm_judge-激活) | The LLM decides whether the current situation calls for this Action | Complex scenarios needing intelligent judgment |
| `RANDOM` | Activation decided by random probability | Features that add behavioral randomness |
| `KEYWORD` | Activated when specific keywords are detected | Features with explicit trigger conditions |
@@ -117,30 +116,6 @@ class AlwaysActivatedAction(BaseAction):
        return True, "Performed the core function"
```

#### `LLM_JUDGE` activation

`ActionActivationType.LLM_JUDGE` makes the Action join the candidate pool based on the LLM's judgment.

The LLM's judgment is based on the `llm_judge_prompt` preset in code and the automatically supplied chat context.

This approach therefore requires implementing the `llm_judge_prompt` attribute.

```python
class LLMJudgedAction(BaseAction):
    activation_type = ActionActivationType.LLM_JUDGE  # activated via LLM judgment
    # LLM judgment prompt
    llm_judge_prompt = (
        "Conditions for deciding whether to use this action:\n"
        "1. The user wants to invoke the XXX action\n"
        "...\n"
        "Answer \"yes\" or \"no\".\n"
    )

    async def execute(self) -> Tuple[bool, str]:
        # Execute based on the LLM's judgment
        return True, "Performed the LLM-judged function"
```

#### `RANDOM` activation

`ActionActivationType.RANDOM` makes the Action join the candidate pool based on random probability.
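The activation types above can be sketched as a single dispatch function. This is a self-contained illustration only: it uses plain strings instead of the real `ActionActivationType` enum, and a stub callback in place of a real LLM call — none of these names come from MaiBot.

```python
import random
from typing import Callable, Sequence


def is_activated(
    activation_type: str,
    *,
    probability: float = 0.3,
    keywords: Sequence[str] = (),
    message: str = "",
    llm_judge: Callable[[str], bool] = lambda _msg: False,
) -> bool:
    """Decide whether an Action enters the candidate pool (illustrative sketch)."""
    if activation_type == "NEVER":
        return False  # the action stays invisible
    if activation_type == "ALWAYS":
        return True  # always a candidate
    if activation_type == "LLM_JUDGE":
        return llm_judge(message)  # delegate to the LLM (stubbed here)
    if activation_type == "RANDOM":
        return random.random() < probability
    if activation_type == "KEYWORD":
        return any(k in message for k in keywords)
    raise ValueError(f"unknown activation type: {activation_type}")
```

The first layer (this gate) only decides candidacy; whether the action actually runs is the planner's second-layer decision.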
@@ -1,2 +0,0 @@
!napcat
!.env
@@ -1,45 +0,0 @@
stages:
  - build-image
  - package-helm-chart

# Run only on the helm-chart branch
workflow:
  rules:
    - if: '$CI_COMMIT_BRANCH == "helm-chart"'
    - when: never

# Build and push the preprocessor image
build-preprocessor:
  stage: build-image
  image: reg.mikumikumi.xyz/base/kaniko-builder:latest
  variables:
    BUILD_NO_CACHE: true
  rules:
    - changes:
        - helm-chart/preprocessor/**
  script:
    - export BUILD_CONTEXT=helm-chart/preprocessor
    - export TMP_DST=reg.mikumikumi.xyz/maibot/preprocessor
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - export BUILD_DESTINATION="${TMP_DST}:${CHART_VERSION}"
    - export BUILD_ARGS="--destination ${TMP_DST}:latest"
    - build

# Package and push the helm chart
package-helm-chart:
  stage: package-helm-chart
  image: reg.mikumikumi.xyz/mirror/helm:latest
  rules:
    - changes:
        - helm-chart/files/**
        - helm-chart/templates/**
        - helm-chart/.gitignore
        - helm-chart/.helmignore
        - helm-chart/Chart.yaml
        - helm-chart/README.md
        - helm-chart/values.yaml
  script:
    - export CHART_VERSION=$(cat helm-chart/Chart.yaml | grep '^version:' | cut -d' ' -f2)
    - helm registry login reg.mikumikumi.xyz --username ${CI_REGISTRY_USER} --password ${CI_REGISTRY_PASSWORD}
    - helm package helm-chart
    - helm push maibot-${CHART_VERSION}.tgz oci://reg.mikumikumi.xyz/maibot
@@ -1,2 +0,0 @@
preprocessor
.gitlab-ci.yml
@@ -1,6 +0,0 @@
apiVersion: v2
name: maibot
description: "Maimai Bot, a cyber friend dedicated to group chats"
type: application
version: 0.12.0
appVersion: 0.12.0
@@ -1,118 +0,0 @@
# MaiBot Helm Chart

This is MaiBot's Helm chart, which makes it easy to deploy MaiBot on a Kubernetes cluster.

The MaiBot version this chart targets is recorded in the `appVersion` field of `Chart.yaml`.

Detailed deployment documentation: [Kubernetes deployment](https://docs.mai-mai.org/manual/deployment/mmc_deploy_kubernetes.html)

## Available Helm chart versions

| Helm chart version | MaiBot version | Commit SHA |
|----------------|--------------|------------------------------------------|
| 0.12.0 | 0.12.0 | baa6e90be7b20050fe25dfc74c0c70653601d00e |
| 0.11.6-beta | 0.11.6-beta | 0bfff0457e6db3f7102fb7f77c58d972634fc93c |
| 0.11.5-beta | 0.11.5-beta | ad2df627001f18996802f23c405b263e78af0d0f |
| 0.11.3-beta | 0.11.3-beta | cd6dc18f546f81e08803d3b8dba48e504dad9295 |
| 0.11.2-beta | 0.11.2-beta | d3c8cea00dbb97f545350f2c3d5bcaf252443df2 |
| 0.11.1-beta | 0.11.1-beta | 94e079a340a43dff8a2bc178706932937fc10b11 |
| 0.11.0-beta | 0.11.0-beta | 16059532d8ef87ac28e2be0838ff8b3a34a91d0f |
| 0.10.3-beta | 0.10.3-beta | 7618937cd4fd0ab1a7bd8a31ab244a8b0742fced |
| 0.10.0-alpha.0 | 0.10.0-alpha | 4efebed10aad977155d3d9e0c24bc6e14e1260ab |

## TL;DR

```shell
helm install maimai \
  oci://reg.mikumikumi.xyz/maibot/maibot \
  --namespace bot \
  --version <MAIBOT_VERSION> \
  --values maibot.yaml
```

## Values reference

`values.yaml` is divided into several major sections.

1. `EULA` & `PRIVACY`: the user must agree to these terms for the deployment to succeed.

2. `pre_processor`: configuration of the pre-deployment preprocessing Job.

3. `adapter`: deployment configuration for MaiBot's adapter.

4. `core`: deployment configuration for MaiBot itself.

5. `statistics_dashboard`: deployment configuration for the runtime statistics dashboard.

   MaiBot periodically writes an HTML statistics report, which can be served as a dashboard.

   Disabled by default for privacy.

6. `napcat`: Napcat deployment configuration.

   Napcat deployment is decoupled so an external Napcat instance can be reused; users choose whether to deploy it.

   By default Napcat is deployed alongside MaiBot.

7. `sqlite_web`: sqlite-web deployment configuration.

   sqlite-web lets you operate MaiBot's database from a web page, which is handy for debugging. Not deploying it does not affect MaiBot.

   This service is very dangerous if exposed to the public Internet, so it is not deployed by default.

8. `config`: the runtime configuration of each MaiBot component.

   Configuration written here overwrites the actual config files only on first deployment or when explicitly requested, and must strictly follow YAML indentation.

   - `override_*_config`: whether this deployment/upgrade should overwrite the actual config files with the configuration below. Off by default.

   - `adapter_config`: the adapter's `config.toml`.

     The `host` and `port` fields of `napcat_server` and `maibot_server` in this file are overwritten by the settings under `adapter.service` above, so they need not be changed.

   - `core_model_config`: the core's `model_config.toml`.

   - `core_bot_config`: the core's `bot_config.toml`.
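A minimal `maibot.yaml` skeleton covering the sections above might look like the following. Only keys that the chart's templates actually reference are shown, and the values are illustrative; the `EULA`/`PRIVACY` agreement keys are omitted — consult the chart's `values.yaml` for their exact form and the authoritative schema.

```yaml
adapter:
  image:
    pullPolicy: IfNotPresent
  service:
    type: ClusterIP
    port: 8095
core:
  webui:
    enabled: true
    service:
      port: 8001
  maim_message_api_server:
    enabled: false
statistics_dashboard:
  enabled: false   # disabled by default for privacy
sqlite_web:
  enabled: false   # dangerous if exposed publicly
config:
  override_core_bot_config: false
```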
## Deployment notes

Some caveats for using this Helm chart.

### MaiBot configuration

The best way to change MaiBot's configuration is through the WebUI. The configuration in this chart is injected into MaiBot only on first deployment or when overwriting is explicitly requested.

Versions before `0.11.6-beta` stored configuration in Kubernetes ConfigMap resources. As MaiBot's handling of config files grew more complex, the Kubernetes adaptation grew with it, and the WebUI can now edit config files directly; so since `0.11.6-beta` each component's configuration is stored as real files on the storage volumes rather than in ConfigMaps.

For users upgrading from an older version, the old ConfigMap contents are migrated automatically to the config files on the new volumes.

### Settings reset automatically on deployment

The `host` and `port` fields of `napcat_server` and `maibot_server` in the adapter's configuration are reset automatically on every deployment/upgrade of the Helm release.
Some fields of `webui` and `maim_message` in the core's configuration are likewise reset on every deployment/upgrade.

Why they are reset:

- The DNS name of the core's Service is dynamic (derived from the release name) and cannot be known in advance in the adapter's config file.
- To have the adapter listen on all addresses and keep the port numbers configured in the Helm chart, these settings must be overwritten in the adapter's config file.
- Starting/stopping the core's WebUI must be controlled by the Helm chart so that Service and Ingress resources are created correctly.
- The core's maim_message api server can now be exposed as a Kubernetes service. Its listen IP and port must be controlled by the Helm chart so the Service maps correctly.

On first deployment, a preprocessing job resets these settings. This takes some time, so the deployment may be slow and some Pods may fail to start at first; waiting about a minute is enough.

### Cross-node PVC mounting

Some MaiBot components mount the same PVC, mainly to synchronize data or edit configuration.

If the Kubernetes cluster has multiple nodes and the Pods sharing a PVC are not scheduled onto the same node, that PVC needs the `ReadWriteMany` access mode.

Not every storage controller supports `ReadWriteMany`.

If your storage controller does not support `ReadWriteMany`, you can avoid the problem by using `nodeSelector` to schedule the Pods that share a PVC onto the same node.

Components that share PVCs:

- `core` and `adapter`: share `adapter-config`, so the core's WebUI can edit the adapter's config file.
- `core` and `statistics-dashboard`: share `statistics-dashboard`, to synchronize the statistics HTML file.
- `core` and `sqlite-web`: share `maibot-core`, so `sqlite-web` can operate MaiBot's database.
- the deploy-time `preprocessor` job with `adapter` and `core`: share `adapter-config` and `core-config`, to initialize the config files of `core` and `adapter`.
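For instance, to pin the components that share `adapter-config` onto one node, the `nodeSelector` values (which the chart's StatefulSet templates reference) can point at the same node label. The hostname below is a placeholder:

```yaml
adapter:
  nodeSelector:
    kubernetes.io/hostname: worker-1   # placeholder node name
core:
  nodeSelector:
    kubernetes.io/hostname: worker-1
```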
@@ -1,4 +0,0 @@
HOST=0.0.0.0
PORT=8000
WEBUI_HOST=0.0.0.0
WEBUI_PORT=8001
@@ -1,36 +0,0 @@
#!/bin/sh
# This script overrides the core container's default start command to perform some initialization.
# Because k8s mounts volumes differently from docker-compose, it pre-creates symlinks for some files and directories.
# /MaiMBot/data is the actual mount path for MaiBot data.
# /MaiMBot/statistics is the actual mount path for statistics data.

set -e
echo "[K8s Init] Preparing volume..."

# On first start, check for and create key files and directories on the volume
mkdir -p /MaiMBot/data/plugins
mkdir -p /MaiMBot/data/logs
if [ ! -d "/MaiMBot/statistics" ]
then
    echo "[K8s Init] Statistics volume is disabled."
else
    touch /MaiMBot/statistics/index.html
fi

# Remove the default plugin directory to prepare the user plugin symlink
rm -rf /MaiMBot/plugins

# Create symlinks from the volume to the actual locations
ln -s /MaiMBot/data/plugins /MaiMBot/plugins
ln -s /MaiMBot/data/logs /MaiMBot/logs
if [ -f "/MaiMBot/statistics/index.html" ]
then
    ln -s /MaiMBot/statistics/index.html /MaiMBot/maibot_statistics.html
fi

echo "[K8s Init] Volume ready."

# Start MaiBot
echo "[K8s Init] Waking up MaiBot..."
echo
exec python bot.py
@@ -1,12 +0,0 @@
# This image dynamically generates the adapter's config file when the helm chart is deployed
FROM python:3.13-slim

WORKDIR /app

ENV PYTHONUNBUFFERED=1

COPY . /app

RUN pip3 install --no-cache-dir -r requirements.txt

ENTRYPOINT ["python3", "preprocessor.py"]
@@ -1,266 +0,0 @@
#!/bin/python3
# This script is triggered by the helm chart's post-install hook and runs once as a k8s Job after deployment.
# It migrates legacy ConfigMaps into config files, adjusts the service listen/connect fields in the adapter's
# config file, and adjusts the maim_message api server and WebUI settings in the core's config file.
#
# - Legacy ConfigMaps are migrated because helm chart versions before 0.11.6-beta stored each config file
#   in a k8s ConfigMap. As feature complexity grew, configs are stored as files on storage volumes since
#   0.11.6-beta. Users upgrading from older versions get their configuration migrated by this script.
#
# - The adapter's config file needs adjusting because:
#   1. The DNS name of the core's Service is dynamic (derived from the release name) and cannot be known in
#      advance in the adapter's config file. The outbound maibot_server.host and maibot_server.port fields
#      are replaced with the core Service's DNS name and port 8000 (hard-coded; no user configuration needed).
#   2. To have the adapter listen on all addresses and keep the chart-configured port numbers, these settings
#      are overwritten in the adapter's config file. The listening napcat_server.host and napcat_server.port
#      fields are replaced with 0.0.0.0 and port 8095 (the actual mapped Service port is set on the Service).
#
# - The core's config file needs adjusting because:
#   1. Starting/stopping the core's WebUI must be controlled by the helm chart so Service and Ingress
#      resources are created correctly. webui.enabled and webui.allowed_ips are overwritten by this script.
#   2. The core's maim_message api server can now be exposed as a k8s service. The listen IP and port must be
#      controlled by the helm chart so the Service maps correctly. maim_message.enable_api_server,
#      maim_message.api_server_host and maim_message.api_server_port are overwritten by this script.

import os
import toml
import time
import base64
from kubernetes import client, config
from kubernetes.client.exceptions import ApiException
from datetime import datetime, timezone

config.load_incluster_config()
core_api = client.CoreV1Api()
apps_api = client.AppsV1Api()

# Read key deployment information
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", 'r') as f:
    namespace = f.read().strip()
release_name = os.getenv("RELEASE_NAME").strip()
is_webui_enabled = os.getenv("IS_WEBUI_ENABLED").lower() == "true"
is_maim_message_api_server_enabled = os.getenv("IS_MMSG_ENABLED").lower() == "true"
config_adapter_b64 = os.getenv("CONFIG_ADAPTER_B64")
config_core_env_b64 = os.getenv("CONFIG_CORE_ENV_B64")
config_core_bot_b64 = os.getenv("CONFIG_CORE_BOT_B64")
config_core_model_b64 = os.getenv("CONFIG_CORE_MODEL_B64")


def log(func: str, msg: str, level: str = 'INFO'):
    print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] [{level}] [{func}] {msg}')


def migrate_old_config():
    """Migrate legacy configuration."""
    func_name = 'migrate_old_config'
    log(func_name, 'Checking whether there are old configmaps to migrate...')
    old_configmap_version = None
    status_migrating = {  # migration status of the adapter's config.toml and the core's bot_config.toml / model_config.toml
        'adapter_config.toml': False,
        'core_bot_config.toml': False,
        'core_model_config.toml': False
    }

    # Skip migration if config files already exist on the volume
    if os.path.isfile('/app/config/core/bot_config.toml') or os.path.isfile('/app/config/core/model_config.toml') or \
            os.path.isfile('/app/config/adapter/config.toml'):
        log(func_name, 'Found existing config file(s) in PV. Migration will be ignored. Done.')
        return

    def migrate_cm_to_file(cm_name: str, key_name: str, file_path: str) -> bool:
        """If a configmap with the given name exists, back it up to the given config file and delete the configmap; return whether a backup was made."""
        try:
            cm = core_api.read_namespaced_config_map(
                name=cm_name,
                namespace=namespace
            )
            log(func_name, f'\tMigrating `{key_name}` of `{cm_name}`...')
            with open(file_path, 'w', encoding='utf-8') as _f:
                _f.write(cm.data[key_name])
            core_api.delete_namespaced_config_map(
                name=cm_name,
                namespace=namespace
            )
            log(func_name, f'\tSuccessfully migrated `{key_name}` of `{cm_name}`.')
        except ApiException as e:
            if e.status == 404:
                return False
        return True

    # In 0.11.5-beta, the adapter's config.toml and the core's bot_config.toml and model_config.toml each live
    # in a separate ConfigMap; migrate them in turn
    if True not in status_migrating.values():
        status_migrating['adapter_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-adapter-config',
                                                                     'config.toml',
                                                                     '/app/config/adapter/config.toml')
        status_migrating['core_bot_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core-bot-config',
                                                                      'bot_config.toml',
                                                                      '/app/config/core/bot_config.toml')
        status_migrating['core_model_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core-model-config',
                                                                        'model_config.toml',
                                                                        '/app/config/core/model_config.toml')
        if True in status_migrating.values():
            old_configmap_version = '0.11.5-beta'

    # Before 0.11.5-beta, the adapter's single config and the core's three configs live in their own configmaps
    if True not in status_migrating.values():
        status_migrating['adapter_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-adapter',
                                                                     'config.toml',
                                                                     '/app/config/adapter/config.toml')
        status_migrating['core_bot_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core',
                                                                      'bot_config.toml',
                                                                      '/app/config/core/bot_config.toml')
        status_migrating['core_model_config.toml'] = migrate_cm_to_file(f'{release_name}-maibot-core',
                                                                        'model_config.toml',
                                                                        '/app/config/core/model_config.toml')
        if True in status_migrating.values():
            old_configmap_version = 'before 0.11.5-beta'

    if old_configmap_version:
        log(func_name, f'Migrating status for version `{old_configmap_version}`:')
        for k, v in status_migrating.items():
            log(func_name, f'\t{k}: {v}')
        if False in status_migrating.values():
            log(func_name, 'There is/are config(s) that have not been migrated. Please check the config manually.',
                level='WARNING')
        else:
            log(func_name, 'Successfully migrated old configs. Done.')
    else:
        log(func_name, 'Old config not found. Ignoring migration. Done.')


def write_config_files():
    """When config files were injected (typically on first install or a user-requested overwrite), write the helm-chart-injected configuration to the actual files on the volume."""
    func_name = 'write_config_files'
    log(func_name, 'Detecting config files...')
    if config_adapter_b64:
        log(func_name, '\tWriting `config.toml` of adapter...')
        config_str = base64.b64decode(config_adapter_b64).decode("utf-8")
        with open('/app/config/adapter/config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`config.toml` of adapter written.')
    if True:  # .env is always overwritten
        log(func_name, '\tWriting .env file of core...')
        config_str = base64.b64decode(config_core_env_b64).decode("utf-8")
        with open('/app/config/core/.env', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`.env` of core written.')
    if config_core_bot_b64:
        log(func_name, '\tWriting `bot_config.toml` of core...')
        config_str = base64.b64decode(config_core_bot_b64).decode("utf-8")
        with open('/app/config/core/bot_config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`bot_config.toml` of core written.')
    if config_core_model_b64:
        log(func_name, '\tWriting `model_config.toml` of core...')
        config_str = base64.b64decode(config_core_model_b64).decode("utf-8")
        with open('/app/config/core/model_config.toml', 'w', encoding='utf-8') as _f:
            _f.write(config_str)
        log(func_name, '\t`model_config.toml` of core written.')
    log(func_name, 'Detection done.')


def reconfigure_adapter():
    """Adjust the napcat_server and maibot_server fields of the adapter's config file so its Service can be reached by napcat and it can connect to the core's Service."""
    func_name = 'reconfigure_adapter'
    log(func_name, 'Reconfiguring `config.toml` of adapter...')
    with open('/app/config/adapter/config.toml', 'r', encoding='utf-8') as _f:
        config_adapter = toml.load(_f)
    config_adapter.setdefault('napcat_server', {})
    config_adapter['napcat_server']['host'] = '0.0.0.0'
    config_adapter['napcat_server']['port'] = 8095
    config_adapter.setdefault('maibot_server', {})
    config_adapter['maibot_server']['host'] = f'{release_name}-maibot-core'  # the core Service's DNS name, derived from the release name
    config_adapter['maibot_server']['port'] = 8000
    with open('/app/config/adapter/config.toml', 'w', encoding='utf-8') as _f:
        _f.write(toml.dumps(config_adapter))
    log(func_name, 'Reconfiguration done.')


def reconfigure_core():
    """Adjust the webui and maim_message fields of the core's config file so its services map correctly."""
    func_name = 'reconfigure_core'
    log(func_name, 'Reconfiguring `bot_config.toml` of core...')
    with open('/app/config/core/bot_config.toml', 'r', encoding='utf-8') as _f:
        config_core = toml.load(_f)
    config_core.setdefault('webui', {})
    config_core['webui']['enabled'] = is_webui_enabled
    config_core['webui']['allowed_ips'] = '0.0.0.0/0'  # deployed on the k8s internal network, so a permissive policy is used
    config_core.setdefault('maim_message', {})
    config_core['maim_message']['enable_api_server'] = is_maim_message_api_server_enabled
    config_core['maim_message']['api_server_host'] = '0.0.0.0'
    config_core['maim_message']['api_server_port'] = 8090
    with open('/app/config/core/bot_config.toml', 'w', encoding='utf-8') as _f:
        _f.write(toml.dumps(config_core))
    log(func_name, 'Reconfiguration done.')


def _scale_statefulsets(statefulsets: list[str], replicas: int, wait: bool = False, timeout: int = 300):
    """Scale the given statefulsets to the given replica count; `wait` controls whether to block until scaling completes."""
    statefulsets = set(statefulsets)
    for name in statefulsets:
        apps_api.patch_namespaced_stateful_set_scale(
            name=name,
            namespace=namespace,
            body={"spec": {"replicas": replicas}}
        )
    if not wait:
        return

    start_time = time.time()
    while True:
        remaining_pods = []

        pods = core_api.list_namespaced_pod(namespace).items

        for pod in pods:
            owners = pod.metadata.owner_references or []
            for owner in owners:
                if owner.kind == "StatefulSet" and owner.name in statefulsets:
                    remaining_pods.append(pod.metadata.name)

        if not remaining_pods:
            return

        elapsed = time.time() - start_time
        if elapsed > timeout:
            raise TimeoutError(
                f"Timeout waiting for Pods to be deleted. "
                f"Remaining Pods: {remaining_pods}"
            )
        time.sleep(5)


def _restart_statefulset(name: str, ignore_error: bool = False):
    """Restart the given statefulset."""
    now = datetime.now(timezone.utc).isoformat()
    body = {
        "spec": {
            "template": {
                "metadata": {
                    "annotations": {
                        "kubectl.kubernetes.io/restartedAt": now
                    }
                }
            }
        }
    }
    try:
        apps_api.patch_namespaced_stateful_set(
            name=name,
            namespace=namespace,
            body=body
        )
    except ApiException as e:
        if ignore_error:
            pass
        else:
            raise e


if __name__ == '__main__':
    log('main', 'Start to process data before install/upgrade...')
    log('main', 'Scaling adapter and core to 0...')
    _scale_statefulsets([f'{release_name}-maibot-adapter', f'{release_name}-maibot-core'], 0, wait=True)
    migrate_old_config()
    write_config_files()
    reconfigure_adapter()
    reconfigure_core()
    log('main', 'Scaling adapter and core to 1...')
    _scale_statefulsets([f'{release_name}-maibot-adapter', f'{release_name}-maibot-core'], 1)
    log('main', 'Process done.')
@@ -1,2 +0,0 @@
toml~=0.10.2
kubernetes~=34.1.0
@@ -1,3 +0,0 @@
MaiBot has been successfully deployed.

MaiBot on GitHub: https://github.com/Mai-with-u/MaiBot
@@ -1,33 +0,0 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.adapter.persistence.data.accessModes }}
  accessModes:
    {{ toYaml .Values.adapter.persistence.data.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.adapter.persistence.data.size }}
  {{- if .Values.adapter.persistence.data.storageClass }}
  storageClassName: {{ .Values.adapter.persistence.data.storageClass | default nil }}
  {{- end }}
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-adapter-config
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.adapter.persistence.config.accessModes }}
  accessModes:
    {{ toYaml .Values.adapter.persistence.config.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.adapter.persistence.config.size }}
  {{- if .Values.adapter.persistence.config.storageClass }}
  storageClassName: {{ .Values.adapter.persistence.config.storageClass | default nil }}
  {{- end }}
@@ -1,19 +0,0 @@
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-adapter
spec:
  ports:
    - name: napcat-ws
      port: {{ .Values.adapter.service.port }}
      protocol: TCP
      targetPort: 8095
      {{- if eq .Values.adapter.service.type "NodePort" }}
      nodePort: {{ .Values.adapter.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-adapter
  type: {{ .Values.adapter.service.type }}
@@ -1,58 +0,0 @@
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-adapter
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-adapter
spec:
  serviceName: {{ .Release.Name }}-maibot-adapter
  replicas: 0  # scaled up to 1 automatically once the post-install job finishes initialization
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-adapter
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-adapter
    spec:
      containers:
        - name: adapter
          env:
            - name: TZ
              value: Asia/Shanghai
          image: {{ .Values.adapter.image.repository | default "unclas/maimbot-adapter" }}:{{ .Values.adapter.image.tag | default "main-20251211074617" }}
          imagePullPolicy: {{ .Values.adapter.image.pullPolicy }}
          ports:
            - containerPort: 8095
              name: napcat-ws
              protocol: TCP
          {{- if .Values.adapter.resources }}
          resources:
            {{ toYaml .Values.adapter.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /adapters/data
              name: data
            - mountPath: /adapters/config.toml
              name: config
              subPath: config.toml
      {{- if .Values.adapter.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.adapter.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.adapter.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.adapter.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.adapter.tolerations }}
      tolerations:
        {{ toYaml .Values.adapter.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter
        - name: config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter-config
@@ -1,26 +0,0 @@
{{- if and .Values.core.webui.enabled .Values.core.webui.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-webui
  namespace: {{ .Release.Namespace }}
  {{- if .Values.core.webui.ingress.annotations }}
  annotations:
    {{ toYaml .Values.core.webui.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  ingressClassName: {{ .Values.core.webui.ingress.className }}
  rules:
    - host: {{ .Values.core.webui.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-core
                port:
                  number: {{ .Values.core.webui.service.port }}
            path: {{ .Values.core.webui.ingress.path }}
            pathType: {{ .Values.core.webui.ingress.pathType }}
{{- end }}
@@ -1,33 +0,0 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.core.persistence.data.accessModes }}
  accessModes:
    {{ toYaml .Values.core.persistence.data.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.core.persistence.data.size }}
  {{- if .Values.core.persistence.data.storageClass }}
  storageClassName: {{ .Values.core.persistence.data.storageClass | default nil }}
  {{- end }}
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-core-config
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.core.persistence.config.accessModes }}
  accessModes:
    {{ toYaml .Values.core.persistence.config.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.core.persistence.config.size }}
  {{- if .Values.core.persistence.config.storageClass }}
  storageClassName: {{ .Values.core.persistence.config.storageClass | default nil }}
  {{- end }}
@@ -1,34 +0,0 @@
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  ports:
    - name: adapter-ws
      port: 8000
      protocol: TCP
      targetPort: 8000
    {{- if .Values.core.webui.enabled }}
    - name: webui
      port: {{ .Values.core.webui.service.port }}
      protocol: TCP
      targetPort: 8001
      {{- if eq .Values.core.webui.service.type "NodePort" }}
      nodePort: {{ .Values.core.webui.service.nodePort | default nil }}
      {{- end }}
    {{- end }}
    {{- if .Values.core.maim_message_api_server.enabled }}
    - name: maim-message
      port: {{ .Values.core.maim_message_api_server.service.port }}
      protocol: TCP
      targetPort: 8090
      {{- if eq .Values.core.maim_message_api_server.service.type "NodePort" }}
      nodePort: {{ .Values.core.maim_message_api_server.service.nodePort | default nil }}
      {{- end }}
    {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-core
  type: ClusterIP
@@ -1,103 +0,0 @@
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-core
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-core
spec:
  serviceName: {{ .Release.Name }}-maibot-core
  replicas: 0 # automatically scaled up to 1 after the post-install job finishes initialization
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-core
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-core
    spec:
      containers:
        - name: core
          command: # the startup command is replaced with this script for in-cluster (k8s) initialization
            - sh
          args:
            - /MaiMBot/k8s-init.sh
          env:
            - name: TZ
              value: "Asia/Shanghai"
            - name: EULA_AGREE
              value: "1b662741904d7155d1ce1c00b3530d0d"
            - name: PRIVACY_AGREE
              value: "9943b855e72199d0f5016ea39052f1b6"
          image: {{ .Values.core.image.repository | default "sengokucola/maibot" }}:{{ .Values.core.image.tag | default "0.12.0" }}
          imagePullPolicy: {{ .Values.core.image.pullPolicy }}
          ports:
            - containerPort: 8000
              name: adapter-ws
              protocol: TCP
            {{- if .Values.core.webui.enabled }}
            - containerPort: 8001
              name: webui
              protocol: TCP
            {{- end }}
          {{- if .Values.core.resources }}
          resources:
            {{ toYaml .Values.core.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /MaiMBot/data
              name: data
            - mountPath: /MaiMBot/k8s-init.sh
              name: scripts
              readOnly: true
              subPath: k8s-init.sh
            - mountPath: /MaiMBot/.env
              name: config
              subPath: .env
            - mountPath: /MaiMBot/config/model_config.toml
              name: config
              subPath: model_config.toml
            - mountPath: /MaiMBot/config/bot_config.toml
              name: config
              subPath: bot_config.toml
            - mountPath: /MaiMBot/adapters-config/config.toml # used by the WebUI to edit the adapter config
              name: adapter-config
              subPath: config.toml
            {{- if .Values.statistics_dashboard.enabled }}
            - mountPath: /MaiMBot/statistics
              name: statistics
            {{- end }}
      serviceAccountName: {{ .Release.Name }}-maibot-sa
      {{- if .Values.core.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.core.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.core.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.core.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.core.tolerations }}
      tolerations:
        {{ toYaml .Values.core.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core
        - configMap:
            items:
              - key: k8s-init.sh
                path: k8s-init.sh
            name: {{ .Release.Name }}-maibot-scripts
          name: scripts
        - name: config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core-config
        - name: adapter-config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter-config
        {{- if .Values.statistics_dashboard.enabled }}
        - name: statistics
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-statistics-dashboard
        {{- end }}

@@ -1,26 +0,0 @@
{{- if and .Values.napcat.enabled .Values.napcat.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  {{- if .Values.napcat.ingress.annotations }}
  annotations:
    {{ toYaml .Values.napcat.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  ingressClassName: {{ .Values.napcat.ingress.className }}
  rules:
    - host: {{ .Values.napcat.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-napcat
                port:
                  number: {{ .Values.napcat.service.port }}
            path: {{ .Values.napcat.ingress.path }}
            pathType: {{ .Values.napcat.ingress.pathType }}
{{- end }}

@@ -1,18 +0,0 @@
{{- if .Values.napcat.enabled }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.napcat.persistence.accessModes }}
  accessModes:
    {{ toYaml .Values.napcat.persistence.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.napcat.persistence.size }}
  {{- if .Values.napcat.persistence.storageClass }}
  storageClassName: {{ .Values.napcat.persistence.storageClass | default nil }}
  {{- end }}
{{- end }}

@@ -1,21 +0,0 @@
{{- if .Values.napcat.enabled }}
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  ports:
    - name: webui
      port: {{ .Values.napcat.service.port }}
      protocol: TCP
      targetPort: 6099
      {{- if eq .Values.napcat.service.type "NodePort" }}
      nodePort: {{ .Values.napcat.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-napcat
  type: {{ .Values.napcat.service.type }}
{{- end }}

@@ -1,72 +0,0 @@
{{- if .Values.napcat.enabled }}
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-napcat
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-napcat
spec:
  serviceName: {{ .Release.Name }}-maibot-napcat
  replicas: 1
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-napcat
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-napcat
    spec:
      containers:
        - name: napcat
          env:
            - name: NAPCAT_GID
              value: "{{ .Values.napcat.permission.gid }}"
            - name: NAPCAT_UID
              value: "{{ .Values.napcat.permission.uid }}"
            - name: TZ
              value: Asia/Shanghai
          image: {{ .Values.napcat.image.repository | default "mlikiowa/napcat-docker" }}:{{ .Values.napcat.image.tag | default "v4.9.80" }}
          imagePullPolicy: {{ .Values.napcat.image.pullPolicy }}
          livenessProbe:
            failureThreshold: 3
            httpGet:
              path: /
              port: 6099
              scheme: HTTP
            initialDelaySeconds: 60
            periodSeconds: 60
            successThreshold: 1
            timeoutSeconds: 10
          ports:
            - containerPort: 6099
              name: webui
              protocol: TCP
          {{- if .Values.napcat.resources }}
          resources:
            {{ toYaml .Values.napcat.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /app/napcat/config
              name: napcat
              subPath: config
            - mountPath: /app/.config/QQ
              name: napcat
              subPath: data
      {{- if .Values.napcat.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.napcat.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.napcat.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.napcat.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.napcat.tolerations }}
      tolerations:
        {{ toYaml .Values.napcat.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: napcat
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-napcat
{{- end }}

@@ -1,8 +0,0 @@
# Check the EULA and PRIVACY agreements
{{- if not .Values.EULA_AGREE }}
{{ fail "You must accept the EULA by setting 'EULA_AGREE: true'. EULA: https://github.com/Mai-with-u/MaiBot/blob/main/EULA.md" }}
{{- end }}

{{- if not .Values.PRIVACY_AGREE }}
{{ fail "You must accept the Privacy Policy by setting 'PRIVACY_AGREE: true'. Privacy Policy: https://github.com/Mai-with-u/MaiBot/blob/main/PRIVACY.md" }}
{{- end }}

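The agreement gate above aborts `helm template`/`helm install` via Helm's `fail` function unless both flags are set. Its logic can be sketched in plain Python (a hypothetical `check_agreements` helper, not part of the chart):

```python
def check_agreements(values: dict) -> None:
    """Mimic the Helm `fail` gate: abort rendering unless both flags are truthy."""
    if not values.get("EULA_AGREE"):
        raise ValueError("You must accept the EULA by setting 'EULA_AGREE: true'.")
    if not values.get("PRIVACY_AGREE"):
        raise ValueError("You must accept the Privacy Policy by setting 'PRIVACY_AGREE: true'.")

# Passes silently when both agreements are accepted
check_agreements({"EULA_AGREE": True, "PRIVACY_AGREE": True})
```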
@@ -1,9 +0,0 @@
apiVersion: v1
kind: ConfigMap
metadata:
  name: {{ .Release.Name }}-maibot-scripts
  namespace: {{ .Release.Namespace }}
data:
  # core
  k8s-init.sh: |
    {{ .Files.Get "files/k8s-init.sh" | nindent 4 }}

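The `.Files.Get ... | nindent 4` pipeline above re-indents the embedded script so it remains a valid YAML block scalar. A minimal Python sketch of what Sprig's `nindent` does (indent every line, prefixed with a newline):

```python
def nindent(width: int, text: str) -> str:
    """Sketch of Helm/Sprig nindent: indent each line by `width` spaces
    and prepend a newline, so the result can follow a YAML `|` marker."""
    pad = " " * width
    return "\n" + "\n".join(pad + line for line in text.splitlines())

script = "#!/bin/sh\necho init"
print(nindent(4, script))
```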
@@ -1,62 +0,0 @@
# Pre-processing job; runs only once per install/upgrade
apiVersion: batch/v1
kind: Job
metadata:
  name: {{ .Release.Name }}-maibot-preprocessor
  namespace: {{ .Release.Namespace }}
  annotations:
    "helm.sh/hook": post-install,post-upgrade
    "helm.sh/hook-delete-policy": before-hook-creation,hook-succeeded
spec:
  backoffLimit: 2
  template:
    spec:
      serviceAccountName: {{ .Release.Name }}-maibot-sa
      restartPolicy: Never
      containers:
        - name: preprocessor
          image: {{ .Values.pre_processor.image.repository | default "reg.mikumikumi.xyz/maibot/preprocessor" }}:{{ .Values.pre_processor.image.tag | default "0.12.0" }}
          imagePullPolicy: {{ .Values.pre_processor.image.pullPolicy }}
          env:
            - name: RELEASE_NAME
              value: {{ .Release.Name }}
            - name: IS_WEBUI_ENABLED
              value: {{ .Values.core.webui.enabled | quote }}
            - name: IS_MMSG_ENABLED
              value: {{ .Values.core.maim_message_api_server.enabled | quote }}
            {{- if or .Values.config.override_adapter_config .Release.IsInstall }}
            - name: CONFIG_ADAPTER_B64
              value: {{ .Values.config.adapter_config | b64enc | quote }}
            {{- end }}
            - name: CONFIG_CORE_ENV_B64
              value: {{ tpl (.Files.Get "files/.env") . | b64enc | quote }}
            {{- if or .Values.config.override_core_bot_config .Release.IsInstall }}
            - name: CONFIG_CORE_BOT_B64
              value: {{ .Values.config.core_bot_config | b64enc | quote }}
            {{- end }}
            {{- if or .Values.config.override_core_model_config .Release.IsInstall }}
            - name: CONFIG_CORE_MODEL_B64
              value: {{ .Values.config.core_model_config | b64enc | quote }}
            {{- end }}
          volumeMounts:
            - mountPath: /app/config/adapter
              name: adapter-config
            - mountPath: /app/config/core
              name: core-config
      imagePullSecrets:
        {{ toYaml .Values.pre_processor.image.pullSecrets | nindent 8 }}
      {{- if .Values.pre_processor.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.pre_processor.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.pre_processor.tolerations }}
      tolerations:
        {{ toYaml .Values.pre_processor.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: adapter-config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-adapter-config
        - name: core-config
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core-config

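The Job above receives each config file as a base64-encoded environment variable (Helm's `b64enc`), which keeps multi-line TOML/env content safe inside a single env value. A sketch of the round-trip, with a sample payload standing in for the real config:

```python
import base64

# Sample payload; stands in for the real adapter config passed via CONFIG_ADAPTER_B64
adapter_config = '[debug]\nlevel = "INFO"\n'

# What Helm's b64enc emits into the env var
encoded = base64.b64encode(adapter_config.encode("utf-8")).decode("ascii")

# What the preprocessor would decode and write into the config volume
decoded = base64.b64decode(encoded).decode("utf-8")
assert decoded == adapter_config
```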
@@ -1,36 +0,0 @@
# RBAC grants required for initialization and for writing changes back to the ConfigMap
apiVersion: v1
kind: ServiceAccount
metadata:
  name: {{ .Release.Name }}-maibot-sa
  namespace: {{ .Release.Namespace }}
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: {{ .Release.Name }}-maibot-role
  namespace: {{ .Release.Namespace }}
rules:
  - apiGroups: [""]
    resources: ["configmaps", "pods"]
    verbs: ["get", "list", "delete"]
  - apiGroups: ["apps"]
    resources: ["statefulsets"]
    verbs: ["get", "list", "update", "patch"]
  - apiGroups: ["apps"]
    resources: ["statefulsets/scale"]
    verbs: ["get", "patch", "update"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: {{ .Release.Name }}-maibot-rolebinding
  namespace: {{ .Release.Namespace }}
subjects:
  - kind: ServiceAccount
    name: {{ .Release.Name }}-maibot-sa
    namespace: {{ .Release.Namespace }}
roleRef:
  kind: Role
  name: {{ .Release.Name }}-maibot-role
  apiGroup: rbac.authorization.k8s.io

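The Role above grants exactly the verbs the init flow needs, notably patching `statefulsets/scale` so the core can be brought from 0 to 1 replica after the post-install job. A toy matcher (hypothetical `allowed` helper, mirroring the rules, not how Kubernetes evaluates RBAC internally):

```python
# Mirror of the Role's rules above
RULES = [
    {"apiGroups": [""], "resources": ["configmaps", "pods"], "verbs": ["get", "list", "delete"]},
    {"apiGroups": ["apps"], "resources": ["statefulsets"], "verbs": ["get", "list", "update", "patch"]},
    {"apiGroups": ["apps"], "resources": ["statefulsets/scale"], "verbs": ["get", "patch", "update"]},
]

def allowed(group: str, resource: str, verb: str) -> bool:
    """Return True if any rule grants the (group, resource, verb) triple."""
    return any(
        group in r["apiGroups"] and resource in r["resources"] and verb in r["verbs"]
        for r in RULES
    )

assert allowed("apps", "statefulsets/scale", "patch")  # the post-install scale-up
assert not allowed("", "secrets", "get")               # nothing beyond the listed rules
```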
@@ -1,26 +0,0 @@
{{- if and .Values.sqlite_web.enabled .Values.sqlite_web.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-sqlite-web
  namespace: {{ .Release.Namespace }}
  {{- if .Values.sqlite_web.ingress.annotations }}
  annotations:
    {{ toYaml .Values.sqlite_web.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-sqlite-web
spec:
  ingressClassName: {{ .Values.sqlite_web.ingress.className }}
  rules:
    - host: {{ .Values.sqlite_web.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-sqlite-web
                port:
                  number: {{ .Values.sqlite_web.service.port }}
            path: {{ .Values.sqlite_web.ingress.path }}
            pathType: {{ .Values.sqlite_web.ingress.pathType }}
{{- end }}

@@ -1,21 +0,0 @@
{{- if .Values.sqlite_web.enabled }}
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-sqlite-web
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-sqlite-web
spec:
  ports:
    - name: webui
      port: {{ .Values.sqlite_web.service.port }}
      protocol: TCP
      targetPort: 8080
      {{- if eq .Values.sqlite_web.service.type "NodePort" }}
      nodePort: {{ .Values.sqlite_web.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-sqlite-web
  type: {{ .Values.sqlite_web.service.type }}
{{- end }}

@@ -1,64 +0,0 @@
{{- if .Values.sqlite_web.enabled }}
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: {{ .Release.Name }}-maibot-sqlite-web
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-sqlite-web
spec:
  serviceName: {{ .Release.Name }}-maibot-sqlite-web
  replicas: 1
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-sqlite-web
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-sqlite-web
    spec:
      containers:
        - name: sqlite-web
          env:
            - name: SQLITE_DATABASE
              value: /data/MaiMBot/MaiBot.db
          image: {{ .Values.sqlite_web.image.repository | default "coleifer/sqlite-web" }}:{{ .Values.sqlite_web.image.tag | default "latest" }}
          imagePullPolicy: {{ .Values.sqlite_web.image.pullPolicy }}
          livenessProbe:
            failureThreshold: 3
            httpGet:
              path: /
              port: 8080
              scheme: HTTP
            initialDelaySeconds: 60
            periodSeconds: 60
            successThreshold: 1
            timeoutSeconds: 10
          ports:
            - containerPort: 8080
              name: webui
              protocol: TCP
          {{- if .Values.sqlite_web.resources }}
          resources:
            {{ toYaml .Values.sqlite_web.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /data/MaiMBot
              name: data
      {{- if .Values.sqlite_web.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.sqlite_web.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.sqlite_web.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.sqlite_web.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.sqlite_web.tolerations }}
      tolerations:
        {{ toYaml .Values.sqlite_web.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-core
{{- end }}

@@ -1,61 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ .Release.Name }}-maibot-statistics-dashboard
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
  replicas: {{ .Values.statistics_dashboard.replicaCount }}
  selector:
    matchLabels:
      app: {{ .Release.Name }}-maibot-statistics-dashboard
  template:
    metadata:
      labels:
        app: {{ .Release.Name }}-maibot-statistics-dashboard
    spec:
      containers:
        - name: nginx
          image: {{ .Values.statistics_dashboard.image.repository | default "nginx" }}:{{ .Values.statistics_dashboard.image.tag | default "latest" }}
          imagePullPolicy: {{ .Values.statistics_dashboard.image.pullPolicy }}
          livenessProbe:
            failureThreshold: 3
            httpGet:
              path: /
              port: 80
              scheme: HTTP
            initialDelaySeconds: 60
            periodSeconds: 10
            successThreshold: 1
            timeoutSeconds: 1
          ports:
            - containerPort: 80
              name: dashboard
              protocol: TCP
          {{- if .Values.statistics_dashboard.resources }}
          resources:
            {{ toYaml .Values.statistics_dashboard.resources | nindent 12 }}
          {{- end }}
          volumeMounts:
            - mountPath: /usr/share/nginx/html
              name: statistics
              readOnly: true
      {{- if .Values.statistics_dashboard.image.pullSecrets }}
      imagePullSecrets:
        {{ toYaml .Values.statistics_dashboard.image.pullSecrets | nindent 8 }}
      {{- end }}
      {{- if .Values.statistics_dashboard.nodeSelector }}
      nodeSelector:
        {{ toYaml .Values.statistics_dashboard.nodeSelector | nindent 8 }}
      {{- end }}
      {{- if .Values.statistics_dashboard.tolerations }}
      tolerations:
        {{ toYaml .Values.statistics_dashboard.tolerations | nindent 8 }}
      {{- end }}
      volumes:
        - name: statistics
          persistentVolumeClaim:
            claimName: {{ .Release.Name }}-maibot-statistics-dashboard
{{- end }}

@@ -1,26 +0,0 @@
{{- if and .Values.statistics_dashboard.enabled .Values.statistics_dashboard.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ .Release.Name }}-maibot-statistics-dashboard
  namespace: {{ .Release.Namespace }}
  {{- if .Values.statistics_dashboard.ingress.annotations }}
  annotations:
    {{ toYaml .Values.statistics_dashboard.ingress.annotations | nindent 4 }}
  {{- end }}
  labels:
    app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
  ingressClassName: {{ .Values.statistics_dashboard.ingress.className }}
  rules:
    - host: {{ .Values.statistics_dashboard.ingress.host }}
      http:
        paths:
          - backend:
              service:
                name: {{ .Release.Name }}-maibot-statistics-dashboard
                port:
                  number: {{ .Values.statistics_dashboard.service.port }}
            path: {{ .Values.statistics_dashboard.ingress.path }}
            pathType: {{ .Values.statistics_dashboard.ingress.pathType }}
{{- end }}

@@ -1,18 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: {{ .Release.Name }}-maibot-statistics-dashboard
  namespace: {{ .Release.Namespace }}
spec:
  {{- if .Values.statistics_dashboard.persistence.accessModes }}
  accessModes:
    {{ toYaml .Values.statistics_dashboard.persistence.accessModes | nindent 4 }}
  {{- end }}
  resources:
    requests:
      storage: {{ .Values.statistics_dashboard.persistence.size }}
  {{- if .Values.statistics_dashboard.persistence.storageClass }}
  storageClassName: {{ .Values.statistics_dashboard.persistence.storageClass | default nil }}
  {{- end }}
{{- end }}

@@ -1,21 +0,0 @@
{{- if .Values.statistics_dashboard.enabled }}
apiVersion: v1
kind: Service
metadata:
  name: {{ .Release.Name }}-maibot-statistics-dashboard
  namespace: {{ .Release.Namespace }}
  labels:
    app: {{ .Release.Name }}-maibot-statistics-dashboard
spec:
  ports:
    - name: dashboard
      port: {{ .Values.statistics_dashboard.service.port }}
      protocol: TCP
      targetPort: 80
      {{- if eq .Values.statistics_dashboard.service.type "NodePort" }}
      nodePort: {{ .Values.statistics_dashboard.service.nodePort | default nil }}
      {{- end }}
  selector:
    app: {{ .Release.Name }}-maibot-statistics-dashboard
  type: {{ .Values.statistics_dashboard.service.type }}
{{- end }}

@@ -1,772 +0,0 @@
# MaiBot can only be deployed after you agree to the EULA and the Privacy Policy
# Setting the options below to true means you agree to the EULA and Privacy terms
# https://github.com/MaiM-with-u/MaiBot/blob/main/EULA.md
# https://github.com/MaiM-with-u/MaiBot/blob/main/PRIVACY.md
EULA_AGREE: false
PRIVACY_AGREE: false

# Configuration for the pre-processing Job
pre_processor:
  image:
    repository: # default: reg.mikumikumi.xyz/maibot/preprocessor
    tag: # default: 0.12.0
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  nodeSelector: { }
  tolerations: [ ]

# Deployment configuration for the MaiBot adapter
adapter:

  image:
    repository: # default: unclas/maimbot-adapter
    tag: # default: main-20251211074617
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  resources: { }

  nodeSelector: { }
  tolerations: [ ]

  # Configure the adapter's napcat websocket service
  # The adapter starts a websocket server used to communicate with napcat
  # The options here let you customize the service port
  # !!! NodePort is not used by default. Exposing the websocket port to the public internet via NodePort may attract malicious clients; add your own authentication middleware !!!
  service:
    type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster websocket port to a port on the physical node
    port: 8095 # ClusterIP port of the websocket listener
    nodePort: # only effective with the NodePort type; a random port is assigned if unset

  persistence:
    config: # volume for config files
      storageClass:
      accessModes:
        - ReadWriteOnce
      size: 10Mi
    data: # volume for data
      storageClass:
      accessModes:
        - ReadWriteOnce
      size: 1Gi

# Deployment configuration for the MaiBot core
core:

  image:
    repository: # default: sengokucola/maibot
    tag: # default: 0.12.0
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  resources: { }

  nodeSelector: { }
  tolerations: [ ]

  persistence:
    config: # volume for config files
      storageClass:
      accessModes:
        - ReadWriteOnce
      size: 10Mi
    data: # volume for data
      storageClass:
      accessModes:
        - ReadWriteOnce
      size: 10Gi

  webui: # WebUI settings
    enabled: true # enabled by default
    service:
      type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port to a port on the physical node
      port: 8001 # service port
      nodePort: # only effective with the NodePort type; a random port is assigned if unset
    ingress:
      enabled: false
      className: nginx
      annotations: { }
      host: maim.example.com # domain name for accessing the MaiBot WebUI
      path: /
      pathType: Prefix

  maim_message_api_server:
    enabled: false
    service:
      type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port to a port on the physical node
      port: 8090 # service port
      nodePort: # only effective with the NodePort type; a random port is assigned if unset

# Configuration for MaiBot's runtime statistics dashboard
# MaiBot periodically writes an HTML runtime statistics report, which can be served as a static site
# Disabled by default. Enable it only if you consider the report safe to expose publicly (it contains contact/group names, model token spending, etc.)
# If enabled, consider adding authentication middleware to protect private information
statistics_dashboard:

  enabled: false # whether to enable the statistics dashboard

  replicaCount: 1

  image:
    repository: # default: nginx
    tag: # default: latest
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  resources: { }

  nodeSelector: { }
  tolerations: [ ]

  service:
    type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port to a port on the physical node
    port: 80 # service port
    nodePort: # only effective with the NodePort type; a random port is assigned if unset
  ingress:
    enabled: false
    className: nginx
    annotations: { }
    host: maim-statistics.example.com # domain name for accessing the statistics dashboard
    path: /
    pathType: Prefix

  persistence:
    storageClass:
    # If the dashboard should run on a different node than the MaiBot core (multi-node deployment), you need the ReadWriteMany access mode
    # Note: ReadWriteMany requires support from the underlying storage class
    accessModes:
      - ReadWriteOnce
    size: 100Mi

# Deployment configuration for napcat
# !!! Change the default password immediately after deploying napcat !!!
napcat:

  # napcat deployment is decoupled so an external napcat instance can be reused
  # If you already run napcat elsewhere, set enabled to false below and napcat will not be deployed again
  # If not, napcat is bundled and deployed by default; no change needed
  enabled: true

  image:
    repository: # default: mlikiowa/napcat-docker
    tag: # default: v4.9.91
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  resources: { }

  nodeSelector: { }
  tolerations: [ ]

  # Permissions for the napcat process; not a privileged user by default
  permission:
    uid: 1000
    gid: 1000

  # Service for the napcat web panel
  service:
    type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port to a port on the physical node
    port: 6099 # service port
    nodePort: # only effective with the NodePort type; a random port is assigned if unset

  # Ingress for the napcat web panel
  ingress:
    enabled: false # whether to enable
    className: nginx
    annotations: { }
    host: napcat.example.com # domain name used to expose the napcat web panel
    path: /
    pathType: Prefix

  persistence:
    storageClass:
    accessModes:
      - ReadWriteOnce
    size: 5Gi

# Deployment configuration for sqlite-web
sqlite_web:

  # sqlite-web lets you operate MaiBot's database from a web page, which is handy for debugging. Skipping it does not affect MaiBot's operation
  # sqlite-web is not bundled by default; set enabled to true below if you need it
  # !!! sqlite-web has no authentication; exposing it publicly is extremely dangerous. Prefer in-cluster ClusterIP access !!!
  # !!! If you must expose it publicly, add your own authentication middleware !!!
  enabled: false

  image:
    repository: # default: coleifer/sqlite-web
    tag: # default: latest
    pullPolicy: IfNotPresent
    pullSecrets: [ ]

  resources: { }

  nodeSelector: { }
  tolerations: [ ]

  # Service for the sqlite-web panel
  # !!! NodePort is not used by default. If you expose it publicly via NodePort, add your own authentication middleware !!!
  service:
    type: ClusterIP # ClusterIP / NodePort; NodePort maps the in-cluster service port to a port on the physical node
    port: 8080 # service port
    nodePort: # only effective with the NodePort type; a random port is assigned if unset

  # Ingress for the sqlite-web panel
  # !!! Ingress is not used by default. If you expose it publicly via ingress, add your own authentication middleware !!!
  ingress:
    enabled: false # whether to enable
    className: nginx
    annotations: { }
    host: maim-sqlite.example.com # domain name used to expose the sqlite-web panel
    path: /
    pathType: Prefix

# Initial runtime configuration for MaiBot's components
# As config handling grew more complex (and with it the k8s adaptation), and since the WebUI can edit config files directly,
# component configs are, since v0.11.6-beta, no longer stored in a k8s ConfigMap but as real files on the persistent volume
# For users upgrading from older versions, old ConfigMap configs are migrated automatically to the volume-backed files
# The configs here are injected into MaiBot only on first deployment or when an override is requested
config:

  # Whether to overwrite MaiBot's existing config files with the configs below
  override_adapter_config: false
  override_core_bot_config: false
  override_core_model_config: false

  # The adapter's config.toml
  adapter_config: |
    [inner]
    version = "0.1.2" # version number
    # Do not change the version number unless you know what you are doing

    [nickname] # currently unused
    nickname = ""

    [napcat_server] # ws server settings for the Napcat connection
    token = "" # access token configured in Napcat; leave empty if none
    heartbeat_interval = 30 # same as the heartbeat configured in Napcat (in seconds)

    [chat] # whitelist/blacklist feature
    group_list_type = "whitelist" # group list type: whitelist or blacklist
    group_list = [] # group list
    # With whitelist, only groups on the list can chat
    # With blacklist, groups on the list cannot chat
    private_list_type = "whitelist" # private chat list type: whitelist or blacklist
    private_list = [] # private chat list
    # With whitelist, only users on the list can chat privately
    # With blacklist, users on the list cannot chat privately
    ban_user_id = [] # global ban list (banned users cannot chat at all)
    ban_qq_bot = false # whether to block official QQ bots
    enable_poke = true # whether to enable the "poke" feature

    [voice] # voice message settings
    use_tts = false # whether to use TTS voice (make sure TTS and the matching adapter are configured)

    [debug]
    level = "INFO" # log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)

  # The core's model_config.toml
  core_model_config: |
    [inner]
    version = "1.9.1"

    # Version numbering for this file follows the same rules as bot_config.toml

    [[api_providers]] # API providers (multiple entries allowed)
    name = "DeepSeek" # provider name (arbitrary; referenced by api_provider in the models section)
    base_url = "https://api.deepseek.com/v1" # the provider's BaseURL
    api_key = "your-api-key-here" # API key (replace with your real key)
    client_type = "openai" # request client (optional, defaults to "openai"; set to "gemini" for Google-family models such as Gemini)
    max_retry = 2 # maximum retries when a single model API call fails
    timeout = 120 # API request timeout (seconds)
    retry_interval = 10 # interval between retries (seconds)

    [[api_providers]] # Alibaba BaiLian provider
    name = "BaiLian"
    base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    api_key = "your-bailian-key"
    client_type = "openai"
    max_retry = 2
    timeout = 120
    retry_interval = 5

    [[api_providers]] # Special case: Google's Gemini uses its own API, incompatible with the OpenAI format; set client_type to "gemini"
    name = "Google"
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    api_key = "your-google-api-key-1"
    client_type = "gemini"
    max_retry = 2
    timeout = 120
    retry_interval = 10

    [[api_providers]] # SiliconFlow provider
    name = "SiliconFlow"
    base_url = "https://api.siliconflow.cn/v1"
    api_key = "your-siliconflow-api-key"
    client_type = "openai"
    max_retry = 3
    timeout = 120
    retry_interval = 5


    [[models]] # models (multiple entries allowed)
    model_identifier = "deepseek-chat" # model identifier as provided by the API provider
    name = "deepseek-v3" # model name (arbitrary; referenced by name below)
    api_provider = "DeepSeek" # provider name as configured in api_providers
    price_in = 2.0 # input price (for API usage statistics, CNY per M tokens; optional, default 0)
    price_out = 8.0 # output price (for API usage statistics, CNY per M tokens; optional, default 0)
    # force_stream_mode = true # force streaming output (uncomment if the model does not support non-streaming; optional, default false)

    [[models]]
    model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
    name = "siliconflow-deepseek-v3.2"
    api_provider = "SiliconFlow"
    price_in = 2.0
    price_out = 3.0
    # temperature = 0.5 # optional: per-model temperature, overrides the task config
    # max_tokens = 4096 # optional: per-model max tokens, overrides the task config
    [models.extra_params] # optional extra parameters
    enable_thinking = false # thinking disabled


    [[models]]
    model_identifier = "deepseek-ai/DeepSeek-V3.2-Exp"
    name = "siliconflow-deepseek-v3.2-think"
    api_provider = "SiliconFlow"
    price_in = 2.0
    price_out = 3.0
    # temperature = 0.7 # optional: per-model temperature, overrides the task config
    # max_tokens = 4096 # optional: per-model max tokens, overrides the task config
    [models.extra_params] # optional extra parameters
    enable_thinking = true # thinking enabled


    [[models]]
    model_identifier = "Qwen/Qwen3-Next-80B-A3B-Instruct"
    name = "qwen3-next-80b"
    api_provider = "SiliconFlow"
    price_in = 1.0
    price_out = 4.0

    [[models]]
    model_identifier = "zai-org/GLM-4.6"
    name = "siliconflow-glm-4.6"
    api_provider = "SiliconFlow"
    price_in = 3.5
    price_out = 14.0
    [models.extra_params] # optional extra parameters
    enable_thinking = false # thinking disabled

    [[models]]
    model_identifier = "zai-org/GLM-4.6"
    name = "siliconflow-glm-4.6-think"
    api_provider = "SiliconFlow"
    price_in = 3.5
    price_out = 14.0
    [models.extra_params] # optional extra parameters
    enable_thinking = true # thinking enabled

    [[models]]
    model_identifier = "deepseek-ai/DeepSeek-R1"
    name = "siliconflow-deepseek-r1"
    api_provider = "SiliconFlow"
    price_in = 4.0
    price_out = 16.0


    [[models]]
    model_identifier = "Qwen/Qwen3-30B-A3B-Instruct-2507"
    name = "qwen3-30b"
    api_provider = "SiliconFlow"
    price_in = 0.7
    price_out = 2.8

    [[models]]
    model_identifier = "Qwen/Qwen3-VL-30B-A3B-Instruct"
    name = "qwen3-vl-30"
    api_provider = "SiliconFlow"
    price_in = 4.13
    price_out = 4.13

    [[models]]
    model_identifier = "FunAudioLLM/SenseVoiceSmall"
    name = "sensevoice-small"
    api_provider = "SiliconFlow"
    price_in = 0
    price_out = 0

    [[models]]
    model_identifier = "BAAI/bge-m3"
    name = "bge-m3"
    api_provider = "SiliconFlow"
    price_in = 0
    price_out = 0



    [model_task_config.utils] # model used by various MaiBot components (emoji module, naming, relationships, mood changes, etc.); required
    model_list = ["siliconflow-deepseek-v3.2"] # list of models to use; each entry matches a model name above
    temperature = 0.2 # model temperature; 0.1-0.3 recommended for the new V3
    max_tokens = 4096 # maximum output tokens
    slow_threshold = 15.0 # slow-request threshold (seconds); a warning is logged when waiting for a reply longer than this

    [model_task_config.utils_small] # small model used by various components; high volume, a fast small model is recommended
    model_list = ["qwen3-30b","qwen3-next-80b"]
    temperature = 0.7
    max_tokens = 2048
    slow_threshold = 10.0

    [model_task_config.tool_use] # tool-calling model; must support tool calls
    model_list = ["qwen3-30b","qwen3-next-80b"]
    temperature = 0.7
    max_tokens = 800
    slow_threshold = 10.0

    [model_task_config.replyer] # primary reply model; also used for the expresser and expression-style learning
    model_list = ["siliconflow-deepseek-v3.2","siliconflow-deepseek-v3.2-think","siliconflow-glm-4.6","siliconflow-glm-4.6-think"]
    temperature = 0.3 # model temperature; 0.1-0.3 recommended for the new V3
    max_tokens = 2048
    slow_threshold = 25.0

    [model_task_config.planner] # decision model that determines when MaiBot should reply
    model_list = ["siliconflow-deepseek-v3.2"]
    temperature = 0.3
    max_tokens = 800
    slow_threshold = 12.0

    [model_task_config.vlm] # image recognition model
    model_list = ["qwen3-vl-30"]
    max_tokens = 256
    slow_threshold = 15.0

    [model_task_config.voice] # speech recognition model
    model_list = ["sensevoice-small"]
    slow_threshold = 12.0

    # embedding model
    [model_task_config.embedding]
    model_list = ["bge-m3"]
    slow_threshold = 5.0

    # ------------ LPMM knowledge-base models ------------

    [model_task_config.lpmm_entity_extract] # entity extraction model
    model_list = ["siliconflow-deepseek-v3.2"]
    temperature = 0.2
    max_tokens = 800
    slow_threshold = 20.0

    [model_task_config.lpmm_rdf_build] # RDF construction model
    model_list = ["siliconflow-deepseek-v3.2"]
    temperature = 0.2
    max_tokens = 800
    slow_threshold = 20.0

    [model_task_config.lpmm_qa] # question-answering model
    model_list = ["siliconflow-deepseek-v3.2"]
    temperature = 0.7
    max_tokens = 800
    slow_threshold = 20.0

# core的bot_config.toml
core_bot_config: |
[inner]
version = "7.2.5"

#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
# 如果你想要修改配置文件,请递增version的值
# 如果新增项目,请阅读src/config/official_configs.py中的说明
#
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:MMC版本更新
# 次版本号:配置文件内容大更新
# 修订号:配置文件内容小更新
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----

[bot]
platform = "qq"
qq_account = "1145141919810" # 麦麦的QQ账号

platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号

nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名

[personality]
# 建议120字以内,描述人格特质 和 身份特征
personality = "是一个大二在读女大学生,现在正在上网和群友聊天,有时有点攻击性,有时比较温柔"
# アイデンティティがない 生まれないらららら
# 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行
reply_style = "请不要刻意突出自身学科背景。可以参考贴吧,知乎和微博的回复风格。"

# 多种回复风格列表,可选配置:当列表非空且 multiple_probability>0 时,会按概率随机从中选择一个替换 reply_style
multiple_reply_style = [
    # "你的风格平淡但不失讽刺,很简短,很白话。可以参考贴吧,微博的回复风格。",
    # "用1-2个字进行回复",
    # "用1-2个符号进行回复",
    # "言辭凝練古雅,穿插《論語》經句卻不晦澀,以文言短句為基,輔以淺白語意,持長者溫和風範,全用繁體字表達,具先秦儒者談吐韻致。",
    # "带点翻译腔,但不要太长",
]

# 替换概率:每次构建回复时,以该概率从 multiple_reply_style 中随机选择一个替换 reply_style(0.0-1.0)
multiple_probability = 0.3
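multiple_reply_style / multiple_probability 的语义可以用下面的极简 Python 草图来理解(这是假设性的示意实现,函数名为示意,并非源码中的真实接口):

```python
import random

def pick_reply_style(reply_style: str,
                     multiple_reply_style: list[str],
                     multiple_probability: float) -> str:
    """按 multiple_probability 的概率,从 multiple_reply_style 中随机取一条替换 reply_style。"""
    if multiple_reply_style and random.random() < multiple_probability:
        return random.choice(multiple_reply_style)
    return reply_style
```

列表为空或未命中概率时,保持原 reply_style 不变。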

# 麦麦的说话规则,行为风格:
plan_style = """
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
2.如果相同的内容已经被执行,请不要重复执行
3.你对技术相关话题,游戏和动漫相关话题感兴趣,也对日常话题感兴趣,不喜欢太过沉重严肃的话题
4.请控制你的发言频率,不要太过频繁的发言
5.如果有人对你感到厌烦,请减少回复
6.如果有人在追问你,或者话题没有说完,请你继续回复"""

# 麦麦识图规则,不建议修改
visual_style = "请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"

# 麦麦私聊的说话规则,行为风格:
private_plan_style = """
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用
2.如果相同的内容已经被执行,请不要重复执行
3.某句话如果已经被回复过,不要重复回复"""

# 状态,可以理解为人格多样性,会随机替换人格
states = [
    "是一个女大学生,喜欢上网聊天,会刷小红书。",
    "是一个大二心理学生,会刷贴吧和中国知网。",
    "是一个赛博网友,最近很想吐槽人。"
]

# 替换概率,每次构建人格时替换personality的概率(0.0-1.0)
state_probability = 0.3


[expression]
# 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置
    ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
    ["qq:1919810:group", "enable", "enable", "enable"], # 特定群聊配置:使用表达,启用学习,启用jargon学习
    ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
    # 格式说明:
    # 第一位: chat_stream_id,空字符串表示全局配置
    # 第二位: 是否使用学到的表达 ("enable"/"disable")
    # 第三位: 是否学习表达 ("enable"/"disable")
    # 第四位: 是否启用jargon学习 ("enable"/"disable")
]

expression_groups = [
    # ["*"], # 全局共享组:所有chat_id共享学习到的表达方式(取消注释以启用全局共享)
    ["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 特定互通组,相同组的chat_id会共享学习到的表达方式
    # 格式说明:
    # ["*"] - 启用全局共享,所有聊天流共享表达方式
    # ["qq:123456:private","qq:654321:group"] - 特定互通组,组内chat_id共享表达方式
    # 注意:如果为群聊,则需要设置为group,如果设置为私聊,则需要设置为private
]

reflect = false # 是否启用表达反思(Bot主动向管理员询问表达方式是否合适)
reflect_operator_id = "" # 表达反思操作员ID,格式:platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group")
allow_reflect = [] # 允许进行表达反思的聊天流ID列表,格式:["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true)

all_global_jargon = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
enable_jargon_explanation = true # 是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)
jargon_mode = "planner" # 黑话解释来源模式,可选: "context"(使用上下文自动匹配黑话) 或 "planner"(仅使用Planner在reply动作中给出的unknown_words列表)


[chat] # 麦麦的聊天设置
talk_value = 1 # 聊天频率,越小越沉默,范围0-1
mentioned_bot_reply = true # 是否启用提及必回复
max_context_size = 30 # 上下文长度
planner_smooth = 3 # 规划器平滑,增大数值会减小planner负荷,略微降低反应速度,推荐1-5,0为关闭,必须大于等于0
think_mode = "dynamic" # 思考模式,可选:classic(默认浅度思考和回复)、deep(会进行比较长的,深度回复)、dynamic(动态选择两种模式)

enable_talk_value_rules = true # 是否启用动态发言频率规则

# 动态发言频率规则:按时段/按chat_id调整 talk_value(优先匹配具体chat,再匹配全局)
# 推荐格式(对象数组):{ target="platform:id:type" 或 "", time="HH:MM-HH:MM", value=0.5 }
# 说明:
# - target 为空字符串表示全局;type 为 group/private,例如:"qq:1919810:group" 或 "qq:114514:private";
# - 支持跨夜区间,例如 "23:00-02:00";数值范围建议 0-1,如果 value 设置为0会自动转换为0.0001以避免除以零错误。
talk_value_rules = [
    { target = "", time = "00:00-08:59", value = 0.8 },
    { target = "", time = "09:00-22:59", value = 1.0 },
    { target = "qq:1919810:group", time = "20:00-23:59", value = 0.6 },
    { target = "qq:114514:private", time = "00:00-23:59", value = 0.3 },
]
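上面 talk_value_rules 的匹配语义(先匹配具体 chat 再回落到全局、"HH:MM-HH:MM" 支持跨夜区间、value=0 转为 0.0001)可以用下面的 Python 草图来理解(假设性的示意实现,并非源码逻辑):

```python
from datetime import time

def in_range(spec: str, now: time) -> bool:
    """判断 now 是否落在 "HH:MM-HH:MM" 区间内,支持跨夜(起点晚于终点)。"""
    start_s, end_s = spec.split("-")
    h1, m1 = map(int, start_s.split(":"))
    h2, m2 = map(int, end_s.split(":"))
    start, end = time(h1, m1), time(h2, m2)
    if start <= end:
        return start <= now <= end
    # 跨夜:起点到午夜,或午夜到终点
    return now >= start or now <= end

def effective_talk_value(rules: list[dict], target: str, now: time,
                         default: float = 1.0) -> float:
    """先找 target 对应的规则,再找全局("")规则;value=0 按注释折算为 0.0001。"""
    for want in (target, ""):
        for rule in rules:
            if rule["target"] == want and in_range(rule["time"], now):
                return rule["value"] or 0.0001
    return default
```

没有任何规则命中时,沿用默认的 talk_value。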

[memory]
max_agent_iterations = 3 # 记忆思考深度(最低为1)
agent_timeout_seconds = 45.0 # 最长回忆时间(秒)
enable_jargon_detection = true # 记忆检索过程中是否启用黑话识别
global_memory = false # 是否允许记忆检索进行全局查询

[dream]
interval_minutes = 60 # 做梦时间间隔(分钟)
max_iterations = 20 # 做梦最大轮次,默认20轮
first_delay_seconds = 1800 # 程序启动后首次做梦前的延迟时间(秒)

# 做梦结果推送目标,格式为 "platform:user_id"
# 例如: "qq:123456" 表示在做梦结束后,将梦境文本额外发送给该QQ私聊用户。
# 为空字符串时不推送。
dream_send = ""

# 做梦时间段配置,格式:["HH:MM-HH:MM", ...]
# 如果列表为空,则表示全天允许做梦。
# 如果配置了时间段,则只有在这些时间段内才会实际执行做梦流程。
# 时间段外,调度器仍会按间隔检查,但不会进入做梦流程。
# 支持跨夜区间,例如 "23:00-02:00" 表示从23:00到次日02:00。
# 示例:
dream_time_ranges = [
    # "09:00-22:00", # 白天允许做梦
    "23:00-10:00", # 跨夜时间段(23:00到次日10:00)
]
# dream_time_ranges = []

[tool]
enable_tool = true # 是否启用工具

[emoji]
emoji_chance = 0.4 # 麦麦激活表情包动作的概率
max_reg_num = 100 # 表情包最大注册数量
do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
steal_emoji = true # 是否偷取表情包,让麦麦可以将一些表情包据为己有
content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存

[voice]
enable_asr = false # 是否启用语音识别,启用后麦麦可以识别语音消息,启用该功能需要配置语音识别模型[model_task_config.voice]

[message_receive]
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
ban_words = [
    # "403","张三"
]

ban_msgs_regex = [
    # 需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤,若不了解正则表达式请勿修改
    # "https?://[^\\s]+", # 匹配https链接
    # "\\d{4}-\\d{2}-\\d{2}", # 匹配日期
]


[lpmm_knowledge] # lpmm知识库配置
enable = false # 是否启用lpmm知识库
lpmm_mode = "agent" # 可选择classic传统模式/agent模式,结合新的记忆一同使用
rag_synonym_search_top_k = 10 # 同义检索TopK
rag_synonym_threshold = 0.8 # 同义阈值,相似度高于该值的关系会被当作同义词
info_extraction_workers = 3 # 实体抽取同时执行线程数,非Pro模型不要设置超过5
qa_relation_search_top_k = 10 # 关系检索TopK
qa_relation_threshold = 0.5 # 关系阈值,相似度高于该值的关系会被认为是相关关系
qa_paragraph_search_top_k = 1000 # 段落检索TopK(不能过小,可能影响搜索结果)
qa_paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)
qa_ent_filter_top_k = 10 # 实体过滤TopK
qa_ppr_damping = 0.8 # PPR阻尼系数
qa_res_top_k = 3 # 最终提供段落TopK
embedding_dimension = 1024 # 嵌入向量维度,需与模型输出维度一致
# 性能与降级参数(低配机器可下调)
# 低配机器参考:单/双核或内存≤4GB(如轻量云主机/云函数/开发板),建议先关闭PPR并降低并发
max_embedding_workers = 3 # 嵌入/抽取并发线程数
embedding_chunk_size = 4 # 每批嵌入的条数
max_synonym_entities = 2000 # 同义边参与的实体数上限,超限则跳过
enable_ppr = true # 是否启用PPR,低配机器可关闭

[keyword_reaction]
keyword_rules = [
    { keywords = ["人机", "bot", "机器", "入机", "robot", "机器人", "ai", "AI"], reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" },
    { keywords = ["测试关键词回复", "test"], reaction = "回答测试成功" },
    # { keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
    # 在此处添加更多规则,格式同上
]

regex_rules = [
    { regex = ["^(?P<n>\\S{1,20})是这样的$"], reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)" },
    { regex = ["(?:[^,。.\\s]+,\\s*)?我(?:也)?[没沒]要求你\\s*(?P<action>.+?)[.。,,]?$"], reaction = "进行一次“八岐大蛇定型文”补完。暂时忽略回复风格,进入“粉絲吐槽”模式。请严格按照以下“填空模板”,围绕“[action]”进行创作,使其充满滑稽的夸张感和天马行空的联想:我没要求你[action],我不是(与action关联的、宽容的角色)。可是,(与action相关的具体负面行为)是什么意思?你的(与action相关的某种观念)怎么了?你才(与action相关的某个状态或时间点)吧?再这样下去,你(一个中期的、等比级数式的滑稽推演),(一个后期的、等比级数式的滑稽推演),最后就变成(一个与主题相关的、夸张的最终形态)了。作为(与最终形态相关的、克星或权威身份),我可能得(对你执行一个天罚般的行动)。真的。" }
]

[response_post_process]
enable_response_post_process = true # 是否启用回复后处理,包括错别字生成器,回复分割器

[chinese_typo]
enable = true # 是否启用中文错别字生成器
error_rate = 0.01 # 单字替换概率
min_freq = 9 # 最小字频阈值
tone_error_rate = 0.1 # 声调错误概率
word_replace_rate = 0.006 # 整词替换概率

[response_splitter]
enable = true # 是否启用回复分割器
max_length = 512 # 回复允许的最大长度
max_sentence_num = 8 # 回复允许的最大句子数
enable_kaomoji_protection = false # 是否启用颜文字保护
enable_overflow_return_all = false # 是否在句子数量超出回复允许的最大句子数时一次性返回全部内容
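response_splitter 各配置项的含义可以用下面的简化草图来理解(这是假设的示意实现,并非源码中的分割逻辑):

```python
import re

def split_reply(text: str, max_length: int = 512, max_sentence_num: int = 8,
                enable_overflow_return_all: bool = False) -> list[str]:
    """按句末标点切句,并套用 max_length / max_sentence_num 限制(示意)。"""
    sentences = [s for s in re.split(r"(?<=[。!?!?])", text) if s]
    if len(sentences) > max_sentence_num and not enable_overflow_return_all:
        # 超出句数上限且未开启 enable_overflow_return_all 时,截断到上限
        sentences = sentences[:max_sentence_num]
    return [s[:max_length] for s in sentences]
```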

[log]
date_style = "m-d H:i:s" # 日期格式
log_level_style = "lite" # 日志级别样式,可选FULL,compact,lite
color_text = "full" # 日志文本颜色,可选none,title,full
log_level = "INFO" # 全局日志级别(向下兼容,优先级低于下面的分别设置)
console_log_level = "INFO" # 控制台日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL
file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ERROR, CRITICAL

# 第三方库日志控制
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库
library_log_levels = { aiohttp = "WARNING" } # 设置特定库的日志级别

[debug]
show_prompt = false # 是否显示prompt
show_replyer_prompt = false # 是否显示回复器prompt
show_replyer_reasoning = false # 是否显示回复器推理
show_jargon_prompt = false # 是否显示jargon相关提示词
show_memory_prompt = false # 是否显示记忆检索相关提示词
show_planner_prompt = false # 是否显示planner的prompt和原始返回结果
show_lpmm_paragraph = false # 是否显示lpmm找到的相关文段日志

[maim_message]
auth_token = [] # 认证令牌,用于旧版API验证,为空则不启用验证

# 新版API Server配置(额外监听端口)
enable_api_server = false # 是否启用额外的新版API Server
api_server_host = "0.0.0.0" # 新版API Server主机地址
api_server_port = 8090 # 新版API Server端口号
api_server_use_wss = false # 新版API Server是否启用WSS
api_server_cert_file = "" # 新版API Server SSL证书文件路径
api_server_key_file = "" # 新版API Server SSL密钥文件路径
api_server_allowed_api_keys = [] # 新版API Server允许的API Key列表,为空则允许所有连接

[telemetry] # 发送统计信息,主要是看全球有多少只麦麦
enable = true

[webui] # WebUI 独立服务器配置
# 注意: WebUI 的监听地址(host)和端口(port)已移至 .env 文件中的 WEBUI_HOST 和 WEBUI_PORT
enabled = true # 是否启用WebUI
mode = "production" # 模式: development(开发) 或 production(生产)

# 防爬虫配置
anti_crawler_mode = "basic" # 防爬虫模式: false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)
allowed_ips = "127.0.0.1" # IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)
# 示例: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
trusted_proxies = "" # 信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任
# 示例: 127.0.0.1,192.168.1.1,172.17.0.1
trust_xff = false # 是否启用X-Forwarded-For代理解析(默认false)
# 启用后,仍要求直连IP在trusted_proxies中才会信任XFF头
secure_cookie = false # 是否启用安全Cookie(仅通过HTTPS传输,默认false)

[experimental] # 实验性功能
# 为指定聊天添加额外的prompt配置
# 格式: ["platform:id:type:prompt内容", ...]
# 示例:
# chat_prompts = [
#     "qq:114514:group:这是一个摄影群,你精通摄影知识",
#     "qq:19198:group:这是一个二次元交流群",
#     "qq:114514:private:这是你与好朋友的私聊"
# ]
chat_prompts = []


# 此系统暂时移除,无效配置
[relationship]
enable_relationship = true # 是否启用关系系统

@@ -113,8 +113,6 @@ from .core.claude_config import (
)
from .tool_chain import (
    ToolChainDefinition,
    ToolChainStep,
    ChainExecutionResult,
    tool_chain_manager,
)

@@ -1651,7 +1649,7 @@ class MCPStatusCommand(BaseCommand):
        tool_name = f"chain_{name}".replace("-", "_").replace(".", "_")
        if component_registry.get_component_info(tool_name, ComponentType.TOOL):
            registered += 1
-       lines = [f"✅ 已重新加载工具链配置"]
+       lines = ["✅ 已重新加载工具链配置"]
        lines.append(f"📋 配置数: {len(chains)} 个")
        lines.append(f"🔧 已注册: {registered} 个(可被 LLM 调用)")
        if chains:

@@ -1698,7 +1696,7 @@ class MCPStatusCommand(BaseCommand):
            output_preview += "..."
            lines.append(output_preview)
        else:
-           lines.append(f"❌ 工具链执行失败")
+           lines.append("❌ 工具链执行失败")
            lines.append(f"错误: {result.error}")
            if result.step_results:
                lines.append("")

@@ -1777,9 +1775,9 @@ class MCPStatusCommand(BaseCommand):
        cb = info.get("circuit_breaker", {})
        cb_state = cb.get("state", "closed")
        if cb_state == "open":
-           lines.append(f" ⚡ 断路器熔断中")
+           lines.append(" ⚡ 断路器熔断中")
        elif cb_state == "half_open":
-           lines.append(f" ⚡ 断路器试探中")
+           lines.append(" ⚡ 断路器试探中")
        if info["consecutive_failures"] > 0:
            lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次")

@@ -14,12 +14,11 @@ MCP Workflow 模块 v1.9.0
- 与 ReAct 软流程互补,用户可选择合适的执行方式
"""

import asyncio
import json
import re
import time
from dataclasses import dataclass, field
- from typing import Any, Dict, List, Optional, Tuple, Type
+ from typing import Any, Dict, List, Optional, Tuple

try:
    from src.common.logger import get_logger

@@ -1,5 +1,5 @@
import random
- from typing import List, Tuple, Type, Any
+ from typing import List, Tuple, Type, Any, Optional
from src.plugin_system import (
    BasePlugin,
    register_plugin,

@@ -17,6 +17,9 @@ from src.plugin_system import (
    emoji_api,
)
from src.config.config import global_config
from src.common.logger import get_logger

logger = get_logger("hello_world_plugin")


class CompareNumbersTool(BaseTool):

@@ -217,6 +220,39 @@ class RandomEmojis(BaseCommand):
        return (True, "已发送随机表情包", True) if success else (False, "发送随机表情包失败", False)


class TestCommand(BaseCommand):
    """响应/test命令"""

    command_name = "test"
    command_description = "测试命令"
    command_pattern = r"^/test$"

    async def execute(self) -> Tuple[bool, Optional[str], int]:
        """执行测试命令"""
        try:
            from src.plugin_system.apis import generator_api

            reply_reason = "这是一条测试消息。"
            logger.info(f"测试命令:{reply_reason}")
            result_status, data = await generator_api.generate_reply(
                chat_stream=self.message.chat_stream,
                reply_reason=reply_reason,
                enable_chinese_typo=False,
                extra_info=f"{reply_reason}用于测试bot的功能是否正常。请你按设定的人设表达一句\"测试正常\"",
            )
            if result_status:
                # 发送生成的回复
                if data and data.reply_set and data.reply_set.reply_data:
                    for reply_seg in data.reply_set.reply_data:
                        send_data = reply_seg.content
                        await self.send_text(send_data, storage_message=True)
                        logger.info(f"已回复: {send_data}")
            return True, "", 1
        except Exception as e:
            logger.error(f"表达器生成失败:{e}")
            return True, "", 1


# ===== 插件注册 =====


@@ -259,6 +295,7 @@ class HelloWorldPlugin(BasePlugin):
        (PrintMessage.get_handler_info(), PrintMessage),
        (ForwardMessages.get_handler_info(), ForwardMessages),
        (RandomEmojis.get_command_info(), RandomEmojis),
        (TestCommand.get_command_info(), TestCommand),
    ]

@@ -0,0 +1,322 @@
"""
评估结果统计脚本

功能:
1. 扫描temp目录下所有JSON文件
2. 分析每个文件的统计信息
3. 输出详细的统计报告
"""

import json
import os
import sys
import glob
from collections import Counter
from datetime import datetime
from typing import Dict, List, Set, Tuple

# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.common.logger import get_logger

logger = get_logger("evaluation_stats_analyzer")

# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")


def parse_datetime(dt_str: str) -> datetime | None:
    """解析ISO格式的日期时间字符串"""
    try:
        return datetime.fromisoformat(dt_str)
    except Exception:
        return None


def analyze_single_file(file_path: str) -> Dict:
    """
    分析单个JSON文件的统计信息

    Args:
        file_path: JSON文件路径

    Returns:
        统计信息字典
    """
    file_name = os.path.basename(file_path)
    stats = {
        "file_name": file_name,
        "file_path": file_path,
        "file_size": os.path.getsize(file_path),
        "error": None,
        "last_updated": None,
        "total_count": 0,
        "actual_count": 0,
        "suitable_count": 0,
        "unsuitable_count": 0,
        "suitable_rate": 0.0,
        "unique_pairs": 0,
        "evaluators": Counter(),
        "evaluation_dates": [],
        "date_range": None,
        "has_expression_id": False,
        "has_reason": False,
        "reason_count": 0,
    }

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # 基本信息
        stats["last_updated"] = data.get("last_updated")
        stats["total_count"] = data.get("total_count", 0)

        results = data.get("manual_results", [])
        stats["actual_count"] = len(results)

        if not results:
            return stats

        # 统计通过/不通过
        suitable_count = sum(1 for r in results if r.get("suitable") is True)
        unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
        stats["suitable_count"] = suitable_count
        stats["unsuitable_count"] = unsuitable_count
        stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0

        # 统计唯一的(situation, style)对
        pairs: Set[Tuple[str, str]] = set()
        for r in results:
            if "situation" in r and "style" in r:
                pairs.add((r["situation"], r["style"]))
        stats["unique_pairs"] = len(pairs)

        # 统计评估者
        for r in results:
            evaluator = r.get("evaluator", "unknown")
            stats["evaluators"][evaluator] += 1

        # 统计评估时间
        evaluation_dates = []
        for r in results:
            evaluated_at = r.get("evaluated_at")
            if evaluated_at:
                dt = parse_datetime(evaluated_at)
                if dt:
                    evaluation_dates.append(dt)

        stats["evaluation_dates"] = evaluation_dates
        if evaluation_dates:
            min_date = min(evaluation_dates)
            max_date = max(evaluation_dates)
            stats["date_range"] = {
                "start": min_date.isoformat(),
                "end": max_date.isoformat(),
                "duration_days": (max_date - min_date).days + 1,
            }

        # 检查字段存在性
        stats["has_expression_id"] = any("expression_id" in r for r in results)
        stats["has_reason"] = any(r.get("reason") for r in results)
        stats["reason_count"] = sum(1 for r in results if r.get("reason"))

    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"分析文件 {file_name} 时出错: {e}")

    return stats


def print_file_stats(stats: Dict, index: int | None = None):
    """打印单个文件的统计信息"""
    prefix = f"[{index}] " if index is not None else ""
    print(f"\n{'=' * 80}")
    print(f"{prefix}文件: {stats['file_name']}")
    print(f"{'=' * 80}")

    if stats["error"]:
        print(f"✗ 错误: {stats['error']}")
        return

    print(f"文件路径: {stats['file_path']}")
    print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")

    if stats["last_updated"]:
        print(f"最后更新: {stats['last_updated']}")

    print("\n【记录统计】")
    print(f"  文件中的 total_count: {stats['total_count']}")
    print(f"  实际记录数: {stats['actual_count']}")

    if stats["total_count"] != stats["actual_count"]:
        diff = stats["total_count"] - stats["actual_count"]
        print(f"  ⚠️ 数量不一致,差值: {diff:+d}")

    print("\n【评估结果统计】")
    print(f"  通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
    print(f"  不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")

    print("\n【唯一性统计】")
    print(f"  唯一 (situation, style) 对: {stats['unique_pairs']} 条")
    if stats["actual_count"] > 0:
        duplicate_count = stats["actual_count"] - stats["unique_pairs"]
        duplicate_rate = duplicate_count / stats["actual_count"] * 100
        print(f"  重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")

    print("\n【评估者统计】")
    if stats["evaluators"]:
        for evaluator, count in stats["evaluators"].most_common():
            rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
            print(f"  {evaluator}: {count} 条 ({rate:.2f}%)")
    else:
        print("  无评估者信息")

    print("\n【时间统计】")
    if stats["date_range"]:
        print(f"  最早评估时间: {stats['date_range']['start']}")
        print(f"  最晚评估时间: {stats['date_range']['end']}")
        print(f"  评估时间跨度: {stats['date_range']['duration_days']} 天")
    else:
        print("  无时间信息")

    print("\n【字段统计】")
    print(f"  包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
    print(f"  包含 reason: {'是' if stats['has_reason'] else '否'}")
    if stats["has_reason"]:
        rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
        print(f"  有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")


def print_summary(all_stats: List[Dict]):
    """打印汇总统计信息"""
    print(f"\n{'=' * 80}")
    print("汇总统计")
    print(f"{'=' * 80}")

    total_files = len(all_stats)
    valid_files = [s for s in all_stats if not s.get("error")]
    error_files = [s for s in all_stats if s.get("error")]

    print("\n【文件统计】")
    print(f"  总文件数: {total_files}")
    print(f"  成功解析: {len(valid_files)}")
    print(f"  解析失败: {len(error_files)}")

    if error_files:
        print("\n  失败文件列表:")
        for stats in error_files:
            print(f"    - {stats['file_name']}: {stats['error']}")

    if not valid_files:
        print("\n没有成功解析的文件")
        return

    # 汇总记录统计
    total_records = sum(s["actual_count"] for s in valid_files)
    total_suitable = sum(s["suitable_count"] for s in valid_files)
    total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
    total_unique_pairs = set()

    # 收集所有唯一的(situation, style)对
    for stats in valid_files:
        try:
            with open(stats["file_path"], "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])
            for r in results:
                if "situation" in r and "style" in r:
                    total_unique_pairs.add((r["situation"], r["style"]))
        except Exception:
            pass

    print("\n【记录汇总】")
    print(f"  总记录数: {total_records:,} 条")
    print(f"  通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else "  通过: 0 条")
    print(f"  不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else "  不通过: 0 条")
    print(f"  唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")

    if total_records > 0:
        duplicate_count = total_records - len(total_unique_pairs)
        duplicate_rate = duplicate_count / total_records * 100
        print(f"  重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")

    # 汇总评估者统计
    all_evaluators = Counter()
    for stats in valid_files:
        all_evaluators.update(stats["evaluators"])

    print("\n【评估者汇总】")
    if all_evaluators:
        for evaluator, count in all_evaluators.most_common():
            rate = (count / total_records * 100) if total_records > 0 else 0
            print(f"  {evaluator}: {count:,} 条 ({rate:.2f}%)")
    else:
        print("  无评估者信息")

    # 汇总时间范围
    all_dates = []
    for stats in valid_files:
        all_dates.extend(stats["evaluation_dates"])

    if all_dates:
        min_date = min(all_dates)
        max_date = max(all_dates)
        print("\n【时间汇总】")
        print(f"  最早评估时间: {min_date.isoformat()}")
        print(f"  最晚评估时间: {max_date.isoformat()}")
        print(f"  总时间跨度: {(max_date - min_date).days + 1} 天")

    # 文件大小汇总
    total_size = sum(s["file_size"] for s in valid_files)
    avg_size = total_size / len(valid_files) if valid_files else 0
    print("\n【文件大小汇总】")
    print(f"  总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
    print(f"  平均大小: {avg_size:,.0f} 字节 ({avg_size / 1024:.2f} KB)")


def main():
    """主函数"""
    logger.info("=" * 80)
    logger.info("开始分析评估结果统计信息")
    logger.info("=" * 80)

    if not os.path.exists(TEMP_DIR):
        print(f"\n✗ 错误:未找到temp目录: {TEMP_DIR}")
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        return

    # 查找所有JSON文件
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))

    if not json_files:
        print(f"\n✗ 错误:temp目录下未找到JSON文件: {TEMP_DIR}")
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        return

    json_files.sort()  # 按文件名排序

    print(f"\n找到 {len(json_files)} 个JSON文件")
    print("=" * 80)

    # 分析每个文件
    all_stats = []
    for i, json_file in enumerate(json_files, 1):
        stats = analyze_single_file(json_file)
        all_stats.append(stats)
        print_file_stats(stats, index=i)

    # 打印汇总统计
    print_summary(all_stats)

    print(f"\n{'=' * 80}")
    print("分析完成")
    print(f"{'=' * 80}")


if __name__ == "__main__":
    main()

@ -0,0 +1,507 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
|
||||
logger = get_logger("compare_finish_search_token")
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
"""获取从指定时间开始的token使用情况
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
|
||||
Returns:
|
||||
包含token使用统计的字典
|
||||
"""
|
||||
try:
|
||||
start_datetime = datetime.fromtimestamp(start_time)
|
||||
|
||||
# 查询从开始时间到现在的所有memory相关的token使用记录
|
||||
records = (
|
||||
LLMUsage.select()
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_datetime)
|
||||
& (
|
||||
(LLMUsage.request_type.like("%memory%"))
|
||||
| (LLMUsage.request_type == "memory.question")
|
||||
| (LLMUsage.request_type == "memory.react")
|
||||
| (LLMUsage.request_type == "memory.react.final")
|
||||
)
|
||||
)
|
||||
.order_by(LLMUsage.timestamp.asc())
|
||||
)
|
||||
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
request_count = 0
|
||||
model_usage = {} # 按模型统计
|
||||
|
||||
for record in records:
|
||||
total_prompt_tokens += record.prompt_tokens or 0
|
||||
total_completion_tokens += record.completion_tokens or 0
|
||||
total_tokens += record.total_tokens or 0
|
||||
total_cost += record.cost or 0.0
|
||||
request_count += 1
|
||||
|
||||
# 按模型统计
|
||||
model_name = record.model_name or "unknown"
|
||||
if model_name not in model_usage:
|
||||
model_usage[model_name] = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": 0.0,
|
||||
"request_count": 0,
|
||||
}
|
||||
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
|
||||
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
|
||||
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
|
||||
model_usage[model_name]["cost"] += record.cost or 0.0
|
||||
model_usage[model_name]["request_count"] += 1
|
||||
|
||||
return {
|
||||
"total_prompt_tokens": total_prompt_tokens,
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
"request_count": request_count,
|
||||
"model_usage": model_usage,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取token使用情况失败: {e}")
|
||||
return {
|
||||
"total_prompt_tokens": 0,
|
||||
"total_completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost": 0.0,
|
||||
"request_count": 0,
|
||||
"model_usage": {},
|
||||
}
|
||||


def _import_memory_retrieval():
    """Dynamically import the memory_retrieval module via importlib to avoid circular imports."""
    try:
        # Import prompt_builder first and check whether the prompts are already initialized
        from src.chat.utils.prompt_builder import global_prompt_manager

        # Check whether the memory_retrieval prompts are already registered.
        # If they are, the module has probably been initialized via another import path.
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts

        module_name = "src.memory_system.memory_retrieval"

        # If the prompts are initialized, try to reuse the already-loaded module
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, "init_memory_retrieval_prompt"):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                )

        # If the module is already in sys.modules but only partially initialized, remove it first
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, "init_memory_retrieval_prompt"):
                # Partially initialized module: drop it and re-import
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Clean up any related partially initialized modules as well
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith("src.memory_system.") and key != "src.memory_system":
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass

        # Before importing memory_retrieval, make sure every module that could trigger the
        # circular import is fully loaded; importing them here lets them finish initializing.
        try:
            import src.config.config  # noqa: F401
            import src.chat.utils.prompt_builder  # noqa: F401

            # These modules may import memory_retrieval at module level; if they are
            # already imported, this simply ensures they are fully initialized.
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # Now import memory_retrieval. If this still triggers a circular import, some
        # other module imports memory_retrieval at module level.
        memory_retrieval_module = importlib.import_module(module_name)

        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise


def _init_tools_without_finish_search():
    """Initialize the retrieval tools without registering finish_search."""
    from src.memory_system.retrieval_tools import (
        register_query_chat_history,
        register_query_person_info,
        register_query_words,
    )
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.config.config import global_config

    # Clear the tool registry
    tool_registry = get_tool_registry()
    tool_registry.tools.clear()

    # Register every tool except finish_search
    register_query_chat_history()
    register_query_person_info()
    register_query_words()

    # If LPMM agent mode is enabled, register the LPMM tool as well
    if global_config.lpmm_knowledge.lpmm_mode == "agent":
        from src.memory_system.retrieval_tools.query_lpmm_knowledge import register_tool as register_lpmm_knowledge

        register_lpmm_knowledge()

    logger.info("已初始化工具(不包含 finish_search)")


def _init_tools_with_finish_search():
    """Initialize the retrieval tools, including finish_search."""
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
    from src.memory_system.retrieval_tools import init_all_tools

    # Clear the tool registry
    tool_registry = get_tool_registry()
    tool_registry.tools.clear()

    # Initialize all tools (including finish_search)
    init_all_tools()
    logger.info("已初始化工具(包含 finish_search)")


async def get_prompt_tokens_for_tools(
    question: str,
    chat_id: str,
    use_finish_search: bool,
) -> Dict[str, Any]:
    """Measure the prompt token cost of the first LLM call for a given tool configuration.

    Args:
        question: the question to query
        chat_id: the chat ID
        use_finish_search: whether to register the finish_search tool

    Returns:
        A dict describing the prompt token usage.
    """
    # Initialize the prompts first (if not already done).
    # Note: init_memory_retrieval_prompt calls init_all_tools, so the tools must be
    # re-registered afterwards.
    from src.chat.utils.prompt_builder import global_prompt_manager

    if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
        init_memory_retrieval_prompt, _ = _import_memory_retrieval()
        init_memory_retrieval_prompt()

    # Initialize the tools (with or without finish_search, depending on the flag).
    # This must run after init_memory_retrieval_prompt, which calls init_all_tools.
    if use_finish_search:
        _init_tools_with_finish_search()
    else:
        _init_tools_without_finish_search()

    # Fetch the tool registry
    from src.memory_system.retrieval_tools.tool_registry import get_tool_registry

    tool_registry = get_tool_registry()
    tool_definitions = tool_registry.get_tool_definitions()

    # Sanity-check the tool list (for debugging)
    tool_names = [tool["name"] for tool in tool_definitions]
    if use_finish_search:
        if "finish_search" not in tool_names:
            logger.warning("期望包含 finish_search 工具,但工具列表中未找到")
    else:
        if "finish_search" in tool_names:
            logger.warning("期望不包含 finish_search 工具,但工具列表中找到了,将移除")
            # Remove the finish_search tool
            tool_registry.tools.pop("finish_search", None)
            tool_definitions = tool_registry.get_tool_definitions()
            tool_names = [tool["name"] for tool in tool_definitions]

    # Build the prompt for the first call (mirroring the first call in _react_agent_solve_question)
    from src.config.config import global_config

    bot_name = global_config.bot.nickname
    time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # Build the head prompt
    head_prompt = await global_prompt_manager.format_prompt(
        "memory_retrieval_react_prompt_head",
        bot_name=bot_name,
        time_now=time_now,
        question=question,
        collected_info="",
        current_iteration=1,
        remaining_iterations=global_config.memory.max_agent_iterations - 1,
        max_iterations=global_config.memory.max_agent_iterations,
    )

    # Build the message list (system message only, mimicking the first call)
    from src.llm_models.payload_content.message import MessageBuilder, RoleType

    messages = []
    system_builder = MessageBuilder()
    system_builder.set_role(RoleType.System)
    system_builder.add_text_content(head_prompt)
    messages.append(system_builder.build())

    # Call the LLM API once to count tokens (a single call, no full agent run)
    from src.llm_models.utils_model import LLMRequest, RequestType
    from src.config.config import model_config

    # Create the LLM request object
    llm_request = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="memory.react.compare")

    # Build the tool options
    tool_built = llm_request._build_tool_options(tool_definitions)

    # Call _execute_request directly to get the full response object (including usage)
    response, model_info = await llm_request._execute_request(
        request_type=RequestType.RESPONSE,
        message_factory=lambda _client, *, _messages=messages: _messages,
        temperature=None,
        max_tokens=None,
        tool_options=tool_built,
    )

    # Read the token usage from the response
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0

    if response and hasattr(response, "usage") and response.usage:
        prompt_tokens = response.usage.prompt_tokens or 0
        completion_tokens = response.usage.completion_tokens or 0
        total_tokens = response.usage.total_tokens or 0

    return {
        "use_finish_search": use_finish_search,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "tool_count": len(tool_definitions),
        "tool_names": [tool["name"] for tool in tool_definitions],
    }


async def compare_prompt_tokens(
    question: str,
    chat_id: str = "compare_finish_search",
) -> Dict[str, Any]:
    """Compare the input token usage with and without the finish_search tool.

    Runs a single call per configuration and compares only the input tokens,
    keeping everything except the tool definitions identical.

    Args:
        question: the question to query
        chat_id: the chat ID

    Returns:
        A dict with the comparison results.
    """
    print("\n" + "=" * 80)
    print("finish_search 工具 输入 Token 消耗对比测试")
    print("=" * 80)
    print(f"\n[测试问题] {question}")
    print(f"[聊天ID] {chat_id}")
    print("\n注意: 只对比第一次LLM调用的输入token差异,不运行完整迭代流程")

    # First run: without finish_search
    print("\n" + "-" * 80)
    print("[测试 1/2] 不使用 finish_search 工具")
    print("-" * 80)
    result_without = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_without",
        use_finish_search=False,
    )

    print("\n[结果]")
    print(f"  工具数量: {result_without['tool_count']}")
    print(f"  工具列表: {', '.join(result_without['tool_names'])}")
    print(f"  输入 Prompt Tokens: {result_without['prompt_tokens']:,}")

    # Wait briefly so the database record is written
    await asyncio.sleep(1)

    # Second run: with finish_search
    print("\n" + "-" * 80)
    print("[测试 2/2] 使用 finish_search 工具")
    print("-" * 80)
    result_with = await get_prompt_tokens_for_tools(
        question=question,
        chat_id=f"{chat_id}_with",
        use_finish_search=True,
    )

    print("\n[结果]")
    print(f"  工具数量: {result_with['tool_count']}")
    print(f"  工具列表: {', '.join(result_with['tool_names'])}")
    print(f"  输入 Prompt Tokens: {result_with['prompt_tokens']:,}")

    # Compare the results
    print("\n" + "=" * 80)
    print("[对比结果]")
    print("=" * 80)

    prompt_token_diff = result_with["prompt_tokens"] - result_without["prompt_tokens"]
    prompt_token_diff_percent = (
        (prompt_token_diff / result_without["prompt_tokens"] * 100) if result_without["prompt_tokens"] > 0 else 0
    )

    tool_count_diff = result_with["tool_count"] - result_without["tool_count"]

    print("\n[输入 Prompt Token 对比]")
    print(f"  不使用 finish_search: {result_without['prompt_tokens']:,} tokens")
    print(f"  使用 finish_search: {result_with['prompt_tokens']:,} tokens")
    print(f"  差异: {prompt_token_diff:+,} tokens ({prompt_token_diff_percent:+.2f}%)")

    print("\n[工具数量对比]")
    print(f"  不使用 finish_search: {result_without['tool_count']} 个工具")
    print(f"  使用 finish_search: {result_with['tool_count']} 个工具")
    print(f"  差异: {tool_count_diff:+d} 个工具")

    print("\n[工具列表对比]")
    without_tools = set(result_without["tool_names"])
    with_tools = set(result_with["tool_names"])
    only_with = with_tools - without_tools
    only_without = without_tools - with_tools

    if only_with:
        print(f"  仅在 '使用 finish_search' 中的工具: {', '.join(only_with)}")
    if only_without:
        print(f"  仅在 '不使用 finish_search' 中的工具: {', '.join(only_without)}")
    if not only_with and not only_without:
        print("  工具列表相同(除了 finish_search)")

    # Print the remaining token information
    print("\n[其他 Token 信息]")
    print(f"  Completion Tokens (不使用 finish_search): {result_without.get('completion_tokens', 0):,}")
    print(f"  Completion Tokens (使用 finish_search): {result_with.get('completion_tokens', 0):,}")
    print(f"  总 Tokens (不使用 finish_search): {result_without.get('total_tokens', 0):,}")
    print(f"  总 Tokens (使用 finish_search): {result_with.get('total_tokens', 0):,}")

    print("\n" + "=" * 80)

    return {
        "question": question,
        "without_finish_search": result_without,
        "with_finish_search": result_with,
        "comparison": {
            "prompt_token_diff": prompt_token_diff,
            "prompt_token_diff_percent": prompt_token_diff_percent,
            "tool_count_diff": tool_count_diff,
        },
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="对比使用 finish_search 工具与否的 token 消耗差异")
    parser.add_argument(
        "--chat-id",
        default="compare_finish_search",
        help="测试用的聊天ID(默认: compare_finish_search)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件(可选)",
    )

    args = parser.parse_args()

    # Initialize logging (low verbosity, to keep the output readable)
    initialize_logging(verbose=False)

    # Prompt for the question interactively
    print("\n" + "=" * 80)
    print("finish_search 工具 Token 消耗对比测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    # Connect to the database
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    # Run the comparison
    try:
        result = asyncio.run(
            compare_prompt_tokens(
                question=question,
                chat_id=args.chat_id,
            )
        )

        # Save the result if an output file was given
        if args.output:
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")

    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()
@@ -0,0 +1,556 @@
"""
LLM evaluation and statistical analysis of expressions, grouped by count.

What it does:
1. Selects expressions for LLM evaluation: every item with count > 1, plus twice
   as many randomly sampled items with count == 1.
2. Checks whether the LLM pass rate differs significantly between counts:
   - first comparing each count value separately,
   - then comparing count == 1 against count > 1.
"""

import asyncio
import random
import json
import sys
import os
import re
from typing import List, Dict, Set, Tuple
from datetime import datetime
from collections import defaultdict

# Add the project root to the path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config

logger = get_logger("expression_evaluator_count_analysis_llm")

# Paths for the evaluation result files
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.json")


def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load previously saved evaluation results.

    Returns:
        (existing result list, set of already-evaluated (situation, style) pairs)
    """
    if not os.path.exists(COUNT_ANALYSIS_FILE):
        return [], set()

    try:
        with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        results = data.get("evaluation_results", [])
        # Use (situation, style) as the unique key
        evaluated_pairs = {(r["situation"], r["style"]) for r in results if "situation" in r and "style" in r}
        logger.info(f"已加载 {len(results)} 条已有评估结果")
        return results, evaluated_pairs
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()


def save_results(evaluation_results: List[Dict]):
    """Save the evaluation results to a file.

    Args:
        evaluation_results: list of evaluation results
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)

        data = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(evaluation_results),
            "evaluation_results": evaluation_results,
        }

        with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")


def select_expressions_for_evaluation(
    evaluated_pairs: Set[Tuple[str, str]] | None = None,
) -> List[Expression]:
    """Select expressions for evaluation.

    Selects every item with count > 1, then twice as many items with count == 1.

    Args:
        evaluated_pairs: set of already-evaluated pairs, used to avoid duplicates

    Returns:
        The selected expressions.
    """
    if evaluated_pairs is None:
        evaluated_pairs = set()

    try:
        # Query all expressions
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Filter out already-evaluated items
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]

        if not unevaluated:
            logger.warning("所有项目都已评估完成")
            return []

        # Group by count
        count_eq1 = [expr for expr in unevaluated if expr.count == 1]
        count_gt1 = [expr for expr in unevaluated if expr.count > 1]

        logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条")

        # Take every item with count > 1
        selected_count_gt1 = count_gt1.copy()

        # Take twice as many count == 1 items as there are count > 1 items
        count_gt1_count = len(selected_count_gt1)
        count_eq1_needed = count_gt1_count * 2

        if len(count_eq1) < count_eq1_needed:
            logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条")
            count_eq1_needed = len(count_eq1)

        # Randomly sample the count == 1 items
        selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []

        selected = selected_count_gt1 + selected_count_eq1
        random.shuffle(selected)  # shuffle the order

        logger.info(
            f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
        )

        return selected

    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
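The selection rule implemented above (keep every count > 1 item, sample twice as many count == 1 items, fall back to all of them if too few exist) can be sketched independently of the ORM. `stratified_select` is a hypothetical name and the items are plain `(key, count)` pairs:

```python
import random


def stratified_select(items, rng=random):
    """Keep all (key, count) pairs with count > 1, plus twice as many sampled with count == 1."""
    gt1 = [item for item in items if item[1] > 1]
    eq1 = [item for item in items if item[1] == 1]
    needed = min(len(eq1), 2 * len(gt1))  # if there are too few count==1 items, take them all
    selected = gt1 + rng.sample(eq1, needed)
    rng.shuffle(selected)  # evaluate in random order
    return selected
```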


def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the evaluation prompt.

    Args:
        situation: the usage situation
        style: the expression style

    Returns:
        The evaluation prompt.
    """
    prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}

请从以下方面进行评估:
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指,需要具有泛用性
4. 一般不涉及具体的人名或名称

请以JSON格式输出评估结果:
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
请严格按照JSON格式输出,不要包含其他内容。"""

    return prompt


async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run a single LLM evaluation.

    Args:
        situation: the usage situation
        style: the expression style
        llm: the LLM request instance

    Returns:
        A (suitable, reason, error) tuple; on failure, suitable is False and
        error carries the error message.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024,
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON response
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
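The two-stage JSON parsing in `_single_llm_evaluation` (strict `json.loads` first, then a regex fallback for responses that wrap the JSON object in extra prose) can be isolated into a small synchronous helper for testing. The function name is illustrative, not part of the codebase:

```python
import json
import re


def parse_evaluation_response(response: str) -> dict:
    """Parse an LLM evaluation response, tolerating extra text around the JSON object."""
    try:
        return json.loads(response)
    except json.JSONDecodeError as e:
        # Fall back to the first {...} object that contains a "suitable" key
        match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
        if match:
            return json.loads(match.group())
        raise ValueError("无法从响应中提取JSON格式的评估结果") from e
```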


async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
    """Evaluate a single expression with the LLM.

    Args:
        expression: the expression object
        llm: the LLM request instance

    Returns:
        The evaluation result dict.
    """
    logger.info(
        f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
    )

    suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)

    if error:
        suitable = False

    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")

    return {
        "situation": expression.situation,
        "style": expression.style,
        "count": expression.count,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
        "evaluated_at": datetime.now().isoformat(),
    }


def perform_statistical_analysis(evaluation_results: List[Dict]):
    """Run the statistical analysis over the evaluation results.

    Args:
        evaluation_results: list of evaluation results
    """
    if not evaluation_results:
        print("\n没有评估结果可供分析")
        return

    print("\n" + "=" * 60)
    print("统计分析结果")
    print("=" * 60)

    # Group statistics by count
    count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})

    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        count_groups[count]["total"] += 1
        if suitable:
            count_groups[count]["suitable"] += 1
        else:
            count_groups[count]["unsuitable"] += 1

    # Print the per-count statistics
    print("\n【按count分组统计】")
    print("-" * 60)
    for count in sorted(count_groups.keys()):
        group = count_groups[count]
        total = group["total"]
        suitable = group["suitable"]
        unsuitable = group["unsuitable"]
        pass_rate = (suitable / total * 100) if total > 0 else 0

        print(f"Count = {count}:")
        print(f"  总数: {total}")
        print(f"  通过: {suitable} ({pass_rate:.2f}%)")
        print(f"  不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
        print()

    # Compare count == 1 against count > 1
    count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}

    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)

        if count == 1:
            count_eq1_group["total"] += 1
            if suitable:
                count_eq1_group["suitable"] += 1
            else:
                count_eq1_group["unsuitable"] += 1
        else:
            count_gt1_group["total"] += 1
            if suitable:
                count_gt1_group["suitable"] += 1
            else:
                count_gt1_group["unsuitable"] += 1

    print("\n【Count=1 vs Count>1 对比】")
    print("-" * 60)

    eq1_total = count_eq1_group["total"]
    eq1_suitable = count_eq1_group["suitable"]
    eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0

    gt1_total = count_gt1_group["total"]
    gt1_suitable = count_gt1_group["suitable"]
    gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0

    print("Count = 1:")
    print(f"  总数: {eq1_total}")
    print(f"  通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
    print(f"  不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
    print()
    print("Count > 1:")
    print(f"  总数: {gt1_total}")
    print(f"  通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
    print(f"  不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
    print()

    # Chi-square test (simplified, on a 2x2 contingency table)
    if eq1_total > 0 and gt1_total > 0:
        print("【统计显著性检验】")
        print("-" * 60)

        # Build the 2x2 contingency table:
        #              pass   fail
        #   count=1      a      b
        #   count>1      c      d
        a = eq1_suitable
        b = eq1_total - eq1_suitable
        c = gt1_suitable
        d = gt1_total - gt1_suitable

        # Compute the Pearson chi-square statistic
        n = eq1_total + gt1_total
        if n > 0:
            # Expected frequencies
            e_a = (eq1_total * (a + c)) / n
            e_b = (eq1_total * (b + d)) / n
            e_c = (gt1_total * (a + c)) / n
            e_d = (gt1_total * (b + d)) / n

            # The chi-square test expects every expected frequency to be >= 5
            min_expected = min(e_a, e_b, e_c, e_d)
            if min_expected < 5:
                print("警告:期望频数小于5,卡方检验可能不准确")
                print("建议使用Fisher精确检验")

            # Chi-square value
            chi_square = 0
            if e_a > 0:
                chi_square += ((a - e_a) ** 2) / e_a
            if e_b > 0:
                chi_square += ((b - e_b) ** 2) / e_b
            if e_c > 0:
                chi_square += ((c - e_c) ** 2) / e_c
            if e_d > 0:
                chi_square += ((d - e_d) ** 2) / e_d

            # Degrees of freedom = (rows - 1) * (cols - 1) = 1
            df = 1

            # Critical values
            chi_square_critical_005 = 3.841
            chi_square_critical_001 = 6.635

            print(f"卡方统计量: {chi_square:.4f}")
            print(f"自由度: {df}")
            print(f"临界值 (α=0.05): {chi_square_critical_005}")
            print(f"临界值 (α=0.01): {chi_square_critical_001}")

            if chi_square >= chi_square_critical_001:
                print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)")
            elif chi_square >= chi_square_critical_005:
                print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)")
            else:
                print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)")

            # Size of the difference
            diff = abs(eq1_pass_rate - gt1_pass_rate)
            print(f"\n合格率差异: {diff:.2f}%")
            if diff > 10:
                print("差异较大(>10%)")
            elif diff > 5:
                print("差异中等(5-10%)")
            else:
                print("差异较小(<5%)")
        else:
            print("数据不足,无法进行统计检验")
    else:
        print("数据不足,无法进行count=1和count>1的对比分析")

    # Save the statistical analysis result
    analysis_result = {
        "analysis_time": datetime.now().isoformat(),
        "count_groups": {str(k): v for k, v in count_groups.items()},
        "count_eq1": count_eq1_group,
        "count_gt1": count_gt1_group,
        "total_evaluated": len(evaluation_results),
    }

    try:
        analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
        with open(analysis_file, "w", encoding="utf-8") as f:
            json.dump(analysis_result, f, ensure_ascii=False, indent=2)
        print(f"\n✓ 统计分析结果已保存到: {analysis_file}")
    except Exception as e:
        logger.error(f"保存统计分析结果失败: {e}")
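Instead of comparing the statistic against the two hard-coded critical values (3.841, 6.635), an exact p-value for one degree of freedom can be computed with only the standard library, since the χ²(1) survival function reduces to erfc(√(x/2)). A self-contained sketch that is not wired into the analysis flow above:

```python
import math


def chi_square_2x2(a: int, b: int, c: int, d: int) -> tuple:
    """Return (chi2, p_value) for the 2x2 contingency table [[a, b], [c, d]]."""
    n = a + b + c + d
    row1, row2, col1, col2 = a + b, c + d, a + c, b + d
    if 0 in (row1, row2, col1, col2):
        return 0.0, 1.0  # degenerate table: no evidence of association
    chi2 = 0.0
    for observed, row, col in ((a, row1, col1), (b, row1, col2), (c, row2, col1), (d, row2, col2)):
        expected = row * col / n
        chi2 += (observed - expected) ** 2 / expected
    # For df = 1: P(X >= chi2) = erfc(sqrt(chi2 / 2))
    return chi2, math.erfc(math.sqrt(chi2 / 2.0))
```

As a sanity check, a statistic of 3.841 (the α=0.05 critical value used above) yields p ≈ 0.05 under this formula.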


async def main():
    """Entry point."""
    logger.info("=" * 60)
    logger.info("开始表达方式按count分组的LLM评估和统计分析")
    logger.info("=" * 60)

    # Connect to the database
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    # Load existing evaluation results
    existing_results, evaluated_pairs = load_existing_results()
    evaluation_results = existing_results.copy()

    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")

    # Check whether evaluation should continue (are there unevaluated count > 1 items left?)
    try:
        all_expressions = list(Expression.select())
        unevaluated_count_gt1 = [
            expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
        ]
        has_unevaluated = len(unevaluated_count_gt1) > 0
    except Exception as e:
        logger.error(f"查询未评估项目失败: {e}")
        has_unevaluated = False

    if has_unevaluated:
        print("\n" + "=" * 60)
        print("开始LLM评估")
        print("=" * 60)
        print("评估结果会自动保存到文件\n")

        # Create the LLM instance
        print("创建LLM实例...")
        try:
            llm = LLMRequest(
                model_set=model_config.model_task_config.tool_use,
                request_type="expression_evaluator_count_analysis_llm",
            )
            print("✓ LLM实例创建成功\n")
        except Exception as e:
            logger.error(f"创建LLM实例失败: {e}")
            import traceback

            logger.error(traceback.format_exc())
            print(f"\n✗ 创建LLM实例失败: {e}")
            db.close()
            return

        # Select the expressions to evaluate (all count > 1 items plus twice as many count == 1 items)
        expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)

        if not expressions:
            print("\n没有可评估的项目")
        else:
            print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
            print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
            print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")

            batch_results = []
            for i, expression in enumerate(expressions, 1):
                print(f"LLM评估进度: {i}/{len(expressions)}")
                print(f"  Situation: {expression.situation}")
                print(f"  Style: {expression.style}")
                print(f"  Count: {expression.count}")

                llm_result = await llm_evaluate_expression(expression, llm)

                print(f"  结果: {'通过' if llm_result['suitable'] else '不通过'}")
                if llm_result.get("error"):
                    print(f"  错误: {llm_result['error']}")
                print()

                batch_results.append(llm_result)
                # Use (situation, style) as the unique key
                evaluated_pairs.add((llm_result["situation"], llm_result["style"]))

                # Small delay to avoid API rate limits
                await asyncio.sleep(0.3)

            # Merge this batch into the overall results
            evaluation_results.extend(batch_results)

            # Save the results
            save_results(evaluation_results)
    else:
        print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果")

    # Run the statistical analysis
    if len(evaluation_results) > 0:
        perform_statistical_analysis(evaluation_results)
    else:
        print("\n没有评估结果可供分析")

    # Close the database connection
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,523 @@
"""
LLM evaluation of expressions.

What it does:
1. Reads the saved manual evaluation results (used as the reference criterion).
2. Evaluates the same items with an LLM.
3. Compares the manual and LLM evaluations and prints an analysis report.
"""

import asyncio
import argparse
import json
import random
import sys
import os
import glob
from typing import List, Dict, Set, Tuple

# Add the project root to the path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.logger import get_logger

logger = get_logger("expression_evaluator_llm")

# Path for the evaluation result files
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")


def load_manual_results() -> List[Dict]:
    """Load the manual evaluation results (reads and merges every JSON file under temp/).

    Returns:
        The deduplicated list of manual evaluation results.
    """
    if not os.path.exists(TEMP_DIR):
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        print("\n✗ 错误:未找到temp目录")
        print("  请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []

    # Find all JSON files
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))

    if not json_files:
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        print("\n✗ 错误:temp目录下未找到JSON文件")
        print("  请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []

    logger.info(f"找到 {len(json_files)} 个JSON文件")
    print(f"\n找到 {len(json_files)} 个JSON文件:")
    for json_file in json_files:
        print(f"  - {os.path.basename(json_file)}")

    # Read and merge every JSON file
    all_results = []
    seen_pairs: Set[Tuple[str, str]] = set()  # for deduplication

    for json_file in json_files:
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])

            # Deduplicate using (situation, style) as the unique key
            for result in results:
                if "situation" not in result or "style" not in result:
                    logger.warning(f"跳过无效数据(缺少必要字段): {result}")
                    continue

                pair = (result["situation"], result["style"])
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    all_results.append(result)

            logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
        except Exception as e:
            logger.error(f"加载文件 {json_file} 失败: {e}")
            print(f"  警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
            continue

    logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
    print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")

    return all_results
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
|
||||
使用条件或使用情景:{situation}
|
||||
表达方式或言语风格:{style}
|
||||
|
||||
请从以下方面进行评估:
|
||||
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
|
||||
2. 允许部分语法错误或口头化或缺省出现
|
||||
3. 表达方式不能太过特指,需要具有泛用性
|
||||
4. 一般不涉及具体的人名或名称
|
||||
|
||||
请以JSON格式输出评估结果:
|
||||
{{
|
||||
"suitable": true/false,
|
||||
"reason": "评估理由(如果不合适,请说明原因)"
|
||||
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
llm: LLM请求实例
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
try:
|
||||
prompt = create_evaluation_prompt(situation, style)
|
||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError as e:
|
||||
import re
|
||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
|
||||
return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
||||
|
||||
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
|
||||
"""
|
||||
使用LLM评估单个表达方式
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
llm: LLM请求实例
|
||||
|
||||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(situation, style, llm)
|
||||
|
||||
if error:
|
||||
suitable = False
|
||||
|
||||
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
|
||||
|
||||
return {
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"suitable": suitable,
|
||||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm"
|
||||
}
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
|
||||
"""
|
||||
对比人工评估和LLM评估的结果
|
||||
|
||||
Args:
|
||||
manual_results: 人工评估结果列表
|
||||
llm_results: LLM评估结果列表
|
||||
method_name: 评估方法名称(用于标识)
|
||||
|
||||
Returns:
|
||||
对比分析结果字典
|
||||
"""
|
||||
# 按(situation, style)建立映射
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
total = len(manual_results)
|
||||
matched = 0
|
||||
true_positives = 0
|
||||
true_negatives = 0
|
||||
false_positives = 0
|
||||
false_negatives = 0
|
||||
|
||||
for manual_result in manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
manual_suitable = manual_result["suitable"]
|
||||
llm_suitable = llm_result["suitable"]
|
||||
|
||||
if manual_suitable == llm_suitable:
|
||||
matched += 1
|
||||
|
||||
if manual_suitable and llm_suitable:
|
||||
true_positives += 1
|
||||
elif not manual_suitable and not llm_suitable:
|
||||
true_negatives += 1
|
||||
elif not manual_suitable and llm_suitable:
|
||||
false_positives += 1
|
||||
elif manual_suitable and not llm_suitable:
|
||||
false_negatives += 1
|
||||
|
||||
accuracy = (matched / total * 100) if total > 0 else 0
|
||||
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
||||
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
|
||||
# 计算人工效标的不合适率
|
||||
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
||||
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
|
||||
|
||||
# 计算经过LLM删除后剩余项目中的不合适率
|
||||
# 在所有项目中,移除LLM判定为不合适的项目后,剩下的项目 = TP + FP(LLM判定为合适的项目)
|
||||
# 在这些剩下的项目中,按人工评定的不合适项目 = FP(人工认为不合适,但LLM认为合适)
|
||||
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数(保留的项目)
|
||||
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
|
||||
|
||||
# 两者百分比相减(评估LLM评定修正后的不合适率是否有降低)
|
||||
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
|
||||
|
||||
random_baseline = 50.0
|
||||
accuracy_above_random = accuracy - random_baseline
|
||||
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
|
||||
|
||||
return {
|
||||
"method": method_name,
|
||||
"total": total,
|
||||
"matched": matched,
|
||||
"accuracy": accuracy,
|
||||
"accuracy_above_random": accuracy_above_random,
|
||||
"accuracy_improvement_ratio": accuracy_improvement_ratio,
|
||||
"true_positives": true_positives,
|
||||
"true_negatives": true_negatives,
|
||||
"false_positives": false_positives,
|
||||
"false_negatives": false_negatives,
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1_score,
|
||||
"specificity": specificity,
|
||||
"manual_unsuitable_rate": manual_unsuitable_rate,
|
||||
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
||||
"rate_difference": rate_difference
|
||||
}
|
||||
|
||||
|
||||
async def main(count: int | None = None):
|
||||
"""
|
||||
主函数
|
||||
|
||||
Args:
|
||||
count: 随机选取的数据条数,如果为None则使用全部数据
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式LLM评估")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 1. 加载人工评估结果
|
||||
print("\n步骤1: 加载人工评估结果")
|
||||
manual_results = load_manual_results()
|
||||
if not manual_results:
|
||||
return
|
||||
|
||||
print(f"成功加载 {len(manual_results)} 条人工评估结果")
|
||||
|
||||
# 如果指定了数量,随机选择指定数量的数据
|
||||
if count is not None:
|
||||
if count <= 0:
|
||||
print(f"\n✗ 错误:指定的数量必须大于0,当前值: {count}")
|
||||
return
|
||||
if count > len(manual_results):
|
||||
print(f"\n⚠ 警告:指定的数量 ({count}) 大于可用数据量 ({len(manual_results)}),将使用全部数据")
|
||||
else:
|
||||
random.seed() # 使用系统时间作为随机种子
|
||||
manual_results = random.sample(manual_results, count)
|
||||
print(f"随机选取 {len(manual_results)} 条数据进行评估")
|
||||
|
||||
# 验证数据完整性
|
||||
valid_manual_results = []
|
||||
for r in manual_results:
|
||||
if "situation" in r and "style" in r:
|
||||
valid_manual_results.append(r)
|
||||
else:
|
||||
logger.warning(f"跳过无效数据: {r}")
|
||||
|
||||
if len(valid_manual_results) != len(manual_results):
|
||||
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
|
||||
|
||||
print(f"有效数据: {len(valid_manual_results)} 条")
|
||||
|
||||
# 2. 创建LLM实例并评估
|
||||
print("\n步骤2: 创建LLM实例")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_llm"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
print("\n步骤3: 开始LLM评估")
|
||||
llm_results = []
|
||||
for i, manual_result in enumerate(valid_manual_results, 1):
|
||||
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
||||
llm_results.append(await evaluate_expression_llm(
|
||||
manual_result["situation"],
|
||||
manual_result["style"],
|
||||
llm
|
||||
))
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# 5. 输出FP和FN项目(在评估结果之前)
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
# 5.1 输出FP项目(人工评估不通过但LLM误判为通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
|
||||
print("=" * 60)
|
||||
|
||||
fp_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
# 人工评估不通过,但LLM评估通过(FP情况)
|
||||
if not manual_result["suitable"] and llm_result["suitable"]:
|
||||
fp_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
if fp_items:
|
||||
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
||||
for idx, item in enumerate(fp_items, 1):
|
||||
print(f"--- [{idx}] ---")
|
||||
print(f"Situation: {item['situation']}")
|
||||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 不通过 ❌")
|
||||
print("LLM评估: 通过 ✅ (误判)")
|
||||
if item.get('llm_error'):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")
|
||||
|
||||
# 5.2 输出FN项目(人工评估通过但LLM误判为不通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估通过但LLM误判为不通过的项目(FN - False Negative)")
|
||||
print("=" * 60)
|
||||
|
||||
fn_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
# 人工评估通过,但LLM评估不通过(FN情况)
|
||||
if manual_result["suitable"] and not llm_result["suitable"]:
|
||||
fn_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
if fn_items:
|
||||
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
||||
for idx, item in enumerate(fn_items, 1):
|
||||
print(f"--- [{idx}] ---")
|
||||
print(f"Situation: {item['situation']}")
|
||||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 通过 ✅")
|
||||
print("LLM评估: 不通过 ❌ (误删)")
|
||||
if item.get('llm_error'):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误删项目(所有人工评估通过的项目都被LLM正确识别为通过)")
|
||||
|
||||
# 6. 对比分析并输出结果
|
||||
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估结果(以人工评估为标准)")
|
||||
print("=" * 60)
|
||||
|
||||
# 详细评估结果(核心指标优先)
|
||||
print(f"\n--- {comparison['method']} ---")
|
||||
print(f" 总数: {comparison['total']} 条")
|
||||
print()
|
||||
# print(" 【核心能力指标】")
|
||||
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
||||
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
||||
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
|
||||
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(" 【其他指标】")
|
||||
print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
|
||||
print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
|
||||
print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
|
||||
print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
|
||||
print()
|
||||
print(" 【不合适率分析】")
|
||||
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
|
||||
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
||||
print()
|
||||
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print()
|
||||
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
||||
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
# print(f" - 含义: {'LLM删除后剩余项目中的不合适率降低了' if comparison['rate_difference'] > 0 else 'LLM删除后剩余项目中的不合适率反而升高了' if comparison['rate_difference'] < 0 else '两者相等'} ({'✓ LLM删除有效' if comparison['rate_difference'] > 0 else '✗ LLM删除效果不佳' if comparison['rate_difference'] < 0 else '效果相同'})")
|
||||
# print()
|
||||
print(" 【分类统计】")
|
||||
print(f" TP (正确识别为合适): {comparison['true_positives']}")
|
||||
print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
|
||||
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
|
||||
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
|
||||
|
||||
# 7. 保存结果到JSON文件
|
||||
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"manual_results": valid_manual_results,
|
||||
"llm_results": llm_results,
|
||||
"comparison": comparison
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"\n评估结果已保存到: {output_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存结果到文件失败: {e}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="表达方式LLM评估脚本",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
示例:
|
||||
python evaluate_expressions_llm_v6.py # 使用全部数据
|
||||
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
||||
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="随机选取的数据条数(默认:使用全部数据)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(count=args.count))
|
||||
|
||||
|
|
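The confusion-matrix arithmetic in `compare_evaluations` above can be checked on a toy example. This standalone sketch (not from the repo) reproduces the same percentage-valued metrics, taking the manual verdict as the criterion:

```python
def confusion_metrics(pairs):
    """pairs: list of (manual_suitable, llm_suitable) booleans."""
    tp = sum(1 for human, llm in pairs if human and llm)          # both pass
    tn = sum(1 for human, llm in pairs if not human and not llm)  # both reject
    fp = sum(1 for human, llm in pairs if not human and llm)      # LLM wrongly passes
    fn = sum(1 for human, llm in pairs if human and not llm)      # LLM wrongly rejects
    precision = tp / (tp + fp) * 100 if (tp + fp) else 0
    recall = tp / (tp + fn) * 100 if (tp + fn) else 0
    specificity = tn / (tn + fp) * 100 if (tn + fp) else 0
    accuracy = (tp + tn) / len(pairs) * 100 if pairs else 0
    return {"tp": tp, "tn": tn, "fp": fp, "fn": fn,
            "precision": precision, "recall": recall,
            "specificity": specificity, "accuracy": accuracy}


# 4 items: one of each cell -> every rate lands at 50%
m = confusion_metrics([(True, True), (False, False), (False, True), (True, False)])
print(m)  # accuracy 50.0, precision 50.0, recall 50.0, specificity 50.0
```

Specificity is what the script calls 特定负类召回率: the share of genuinely unsuitable items the LLM manages to reject.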
@@ -0,0 +1,278 @@
"""
Expression manual evaluation script.

Features:
1. Keep drawing random items (without repetition) for manual evaluation.
2. Save the results to a JSON file under temp as the criterion (ground truth).
3. Resuming is supported (already-evaluated items are read from the file and skipped).
"""

import random
import json
import sys
import os
from typing import List, Dict, Set, Tuple
from datetime import datetime

# Add the project root to the import path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression
from src.common.database.database import db
from src.common.logger import get_logger

logger = get_logger("expression_evaluator_manual")

# Evaluation result file paths
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")


def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """
    Load existing evaluation results.

    Returns:
        (existing result list, set of already-evaluated (situation, style) pairs)
    """
    if not os.path.exists(MANUAL_EVAL_FILE):
        return [], set()

    try:
        with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        results = data.get("manual_results", [])
        # Use (situation, style) as the unique key
        evaluated_pairs = {(r["situation"], r["style"]) for r in results if "situation" in r and "style" in r}
        logger.info(f"已加载 {len(results)} 条已有评估结果")
        return results, evaluated_pairs
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()


def save_results(manual_results: List[Dict]):
    """
    Save evaluation results to the file.

    Args:
        manual_results: evaluation result list
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)

        data = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(manual_results),
            "manual_results": manual_results,
        }

        with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")


def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """
    Fetch expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: set of already-evaluated (situation, style) pairs
        batch_size: number of items to fetch per batch

    Returns:
        List of unevaluated expressions.
    """
    try:
        # Query all expressions
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Keep only items whose (situation, style) pair has not been evaluated
        unevaluated = [
            expr for expr in all_expressions
            if (expr.situation, expr.style) not in evaluated_pairs
        ]

        if not unevaluated:
            logger.info("所有项目都已评估完成")
            return []

        # If fewer items remain than requested, return them all
        if len(unevaluated) <= batch_size:
            logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回")
            return unevaluated

        # Randomly pick the requested number
        selected = random.sample(unevaluated, batch_size)
        logger.info(f"从 {len(unevaluated)} 条未评估项目中随机选择了 {len(selected)} 条")
        return selected

    except Exception as e:
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []


def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict | str | None:
    """
    Manually evaluate a single expression.

    Args:
        expression: expression object
        index: current index (1-based)
        total: total count

    Returns:
        The evaluation result dict; "skip" if the user skips the item;
        None if the user quits.
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y' 或 'yes' 或 '1' 表示合适(通过)")
    print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
    print(" 输入 'q' 或 'quit' 退出评估")
    print(" 输入 's' 或 'skip' 跳过当前项目")

    while True:
        user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()

        if user_input in ['q', 'quit']:
            print("退出评估")
            return None

        if user_input in ['s', 'skip']:
            print("跳过当前项目")
            return "skip"

        if user_input in ['y', 'yes', '1', '是', '通过']:
            suitable = True
            break
        elif user_input in ['n', 'no', '0', '否', '不通过']:
            suitable = False
            break
        else:
            print("输入无效,请重新输入 (y/n/q/s)")

    result = {
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": None,
        "evaluator": "manual",
        "evaluated_at": datetime.now().isoformat(),
    }

    print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")

    return result


def main():
    """Entry point."""
    logger.info("=" * 60)
    logger.info("开始表达方式人工评估")
    logger.info("=" * 60)

    # Initialize the database connection
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    # Load existing evaluation results
    existing_results, evaluated_pairs = load_existing_results()
    manual_results = existing_results.copy()

    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")

    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
    print("评估结果会自动保存到文件\n")

    batch_size = 10
    batch_count = 0

    while True:
        # Fetch a batch of unevaluated items
        expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)

        if not expressions:
            print("\n" + "=" * 60)
            print("所有项目都已评估完成!")
            print("=" * 60)
            break

        batch_count += 1
        print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")

        batch_results = []
        for i, expression in enumerate(expressions, 1):
            manual_result = manual_evaluate_expression(expression, i, len(expressions))

            if manual_result is None:
                # User quit: save what we have so far and exit
                print("\n评估已中断")
                if batch_results:
                    manual_results.extend(batch_results)
                    save_results(manual_results)
                return

            if manual_result == "skip":
                # Skip the current item
                continue

            batch_results.append(manual_result)
            # Use (situation, style) as the unique key
            evaluated_pairs.add((manual_result["situation"], manual_result["style"]))

        # Append this batch to the overall results and save
        manual_results.extend(batch_results)
        save_results(manual_results)

        print(f"\n当前批次完成,已评估总数: {len(manual_results)} 条")

        # Ask whether to continue with the next batch
        while True:
            continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
            if continue_input in ['y', 'yes', '1', '是', '继续']:
                break
            elif continue_input in ['n', 'no', '0', '否', '退出']:
                print("\n评估结束")
                return
            else:
                print("输入无效,请重新输入 (y/n)")

    # Close the database connection
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")


if __name__ == "__main__":
    main()
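The manual script above treats `(situation, style)` as the identity of an item, both for resuming and for merging result files while keeping the earliest verdict. That merge can be sketched as a standalone helper (name is illustrative, not from the repo):

```python
def merge_results(*result_lists):
    """Merge result lists, keeping the first entry per (situation, style)."""
    seen = set()
    merged = []
    for results in result_lists:
        for r in results:
            if "situation" not in r or "style" not in r:
                continue  # skip malformed rows
            key = (r["situation"], r["style"])
            if key not in seen:
                seen.add(key)
                merged.append(r)
    return merged


a = [{"situation": "被夸", "style": "害羞", "suitable": True}]
b = [{"situation": "被夸", "style": "害羞", "suitable": False},  # duplicate key, dropped
     {"situation": "道歉", "style": "诚恳", "suitable": True}]
print(len(merge_results(a, b)))  # → 2
```

First-wins ordering means an earlier (presumably trusted) file silently overrides later duplicates, which matches how the LLM script merges everything under `temp`.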
@ -0,0 +1,476 @@
|
|||
"""
|
||||
表达方式评估脚本
|
||||
|
||||
功能:
|
||||
1. 随机读取指定数量的表达方式,获取其situation和style
|
||||
2. 先进行人工评估(逐条手动评估)
|
||||
3. 然后使用LLM进行评估
|
||||
4. 对比人工评估和LLM评估的正确率、精确率、召回率、F1分数等指标(以人工评估为标准)
|
||||
5. 不真正修改数据库,只是做评估
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import db
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_evaluator_comparison")
|
||||
|
||||
|
||||
def get_random_expressions(count: int = 10) -> List[Expression]:
|
||||
"""
|
||||
随机读取指定数量的表达方式
|
||||
|
||||
Args:
|
||||
count: 要读取的数量,默认10条
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 查询所有表达方式
|
||||
all_expressions = list(Expression.select())
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("数据库中没有表达方式记录")
|
||||
return []
|
||||
|
||||
# 如果总数少于请求数量,返回所有
|
||||
if len(all_expressions) <= count:
|
||||
logger.info(f"数据库中共有 {len(all_expressions)} 条表达方式,全部返回")
|
||||
return all_expressions
|
||||
|
||||
# 随机选择指定数量
|
||||
selected = random.sample(all_expressions, count)
|
||||
logger.info(f"从 {len(all_expressions)} 条表达方式中随机选择了 {len(selected)} 条")
|
||||
return selected
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机读取表达方式失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
|
||||
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
|
||||
"""
|
||||
人工评估单个表达方式
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
index: 当前索引(从1开始)
|
||||
total: 总数
|
||||
|
||||
Returns:
|
||||
评估结果字典,包含:
|
||||
- expression_id: 表达方式ID
|
||||
- situation: 情境
|
||||
- style: 风格
|
||||
- suitable: 是否合适(人工评估)
|
||||
- reason: 评估理由(始终为None)
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print(f"人工评估 [{index}/{total}]")
|
||||
print("=" * 60)
|
||||
print(f"Situation: {expression.situation}")
|
||||
print(f"Style: {expression.style}")
|
||||
print("\n请评估该表达方式是否合适:")
|
||||
print(" 输入 'y' 或 'yes' 或 '1' 表示合适(通过)")
|
||||
print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
|
||||
print(" 输入 'q' 或 'quit' 退出评估")
|
||||
|
||||
while True:
|
||||
user_input = input("\n您的评估 (y/n/q): ").strip().lower()
|
||||
|
||||
if user_input in ['q', 'quit']:
|
||||
print("退出评估")
|
||||
return None
|
||||
|
||||
if user_input in ['y', 'yes', '1', '是', '通过']:
|
||||
suitable = True
|
||||
break
|
||||
elif user_input in ['n', 'no', '0', '否', '不通过']:
|
||||
suitable = False
|
||||
break
|
||||
else:
|
||||
print("输入无效,请重新输入 (y/n/q)")
|
||||
|
||||
result = {
|
||||
"expression_id": expression.id,
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"suitable": suitable,
|
||||
"reason": None,
|
||||
"evaluator": "manual"
|
||||
}
|
||||
|
||||
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
prompt = f"""请评估以下表达方式是否合适:
|
||||
|
||||
情境(situation):{situation}
|
||||
风格(style):{style}
|
||||
|
||||
请从以下方面进行评估:
|
||||
1. 情境描述是否清晰、准确
|
||||
2. 风格表达是否合理、自然
|
||||
3. 情境和风格是否匹配
|
||||
4. 允许部分语法错误出现
|
||||
5. 允许口头化或缺省表达
|
||||
6. 允许部分上下文缺失
|
||||
|
||||
请以JSON格式输出评估结果:
|
||||
{{
|
||||
"suitable": true/false,
|
||||
"reason": "评估理由(如果不合适,请说明原因)"
|
||||
}}
|
||||
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(expression: Expression, llm: LLMRequest) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
llm: LLM请求实例
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
try:
|
||||
prompt = create_evaluation_prompt(expression.situation, expression.style)
|
||||
logger.debug(f"正在评估表达方式 ID: {expression.id}")
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
import re
|
||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果")
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"评估表达方式 ID: {expression.id} 时出错: {e}")
|
||||
return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
||||
|
||||
async def evaluate_expression_llm(expression: Expression, llm: LLMRequest) -> Dict:
|
||||
"""
|
||||
使用LLM评估单个表达方式
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
llm: LLM请求实例
|
||||
|
||||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式 ID: {expression.id}")
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(expression, llm)
|
||||
|
||||
if error:
|
||||
suitable = False
|
||||
|
||||
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
|
||||
|
||||
return {
|
||||
"expression_id": expression.id,
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"suitable": suitable,
|
||||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm"
|
||||
}
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """
    Compare manual evaluation results against LLM evaluation results.

    Args:
        manual_results: list of manual evaluation results
        llm_results: list of LLM evaluation results
        method_name: name of the evaluation method (used as a label)

    Returns:
        dict with the comparison statistics
    """
    # Index LLM results by expression_id
    llm_dict = {r["expression_id"]: r for r in llm_results}

    total = len(manual_results)
    matched = 0
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    for manual_result in manual_results:
        llm_result = llm_dict.get(manual_result["expression_id"])
        if llm_result is None:
            continue

        manual_suitable = manual_result["suitable"]
        llm_suitable = llm_result["suitable"]

        if manual_suitable == llm_suitable:
            matched += 1

        if manual_suitable and llm_suitable:
            true_positives += 1
        elif not manual_suitable and not llm_suitable:
            true_negatives += 1
        elif not manual_suitable and llm_suitable:
            false_positives += 1
        elif manual_suitable and not llm_suitable:
            false_negatives += 1

    accuracy = (matched / total * 100) if total > 0 else 0
    precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
    recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
    specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0

    random_baseline = 50.0
    accuracy_above_random = accuracy - random_baseline
    accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0

    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy_above_random,
        "accuracy_improvement_ratio": accuracy_improvement_ratio,
        "true_positives": true_positives,
        "true_negatives": true_negatives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity
    }


async def main():
    """Main entry point."""
    logger.info("=" * 60)
    logger.info("开始表达方式评估")
    logger.info("=" * 60)

    # Initialize the database connection
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    # 1. Load a random sample of expressions
    logger.info("\n步骤1: 随机读取表达方式")
    expressions = get_random_expressions(10)
    if not expressions:
        logger.error("没有可用的表达方式,退出")
        return
    logger.info(f"成功读取 {len(expressions)} 条表达方式")

    # 2. Manual evaluation
    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print(f"共需要评估 {len(expressions)} 条表达方式")
    print("请逐条进行评估...\n")

    manual_results = []
    for i, expression in enumerate(expressions, 1):
        manual_result = manual_evaluate_expression(expression, i, len(expressions))
        if manual_result is None:
            print("\n评估已中断")
            return
        manual_results.append(manual_result)

    print("\n" + "=" * 60)
    print("人工评估完成")
    print("=" * 60)

    # 3. Create the LLM instance and run the LLM evaluation
    logger.info("\n步骤3: 创建LLM实例")
    try:
        llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="expression_evaluator_comparison"
        )
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return

    logger.info("\n步骤4: 开始LLM评估")
    llm_results = []
    for i, expression in enumerate(expressions, 1):
        logger.info(f"LLM评估进度: {i}/{len(expressions)}")
        llm_results.append(await evaluate_expression_llm(expression, llm))
        await asyncio.sleep(0.3)

    # 4. Compare the two result sets and report
    comparison = compare_evaluations(manual_results, llm_results, "LLM评估")

    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)
    print("\n评估目标:")
    print(" 1. 核心能力:将不合适的项目正确提取出来(特定负类召回率)")
    print(" 2. 次要能力:尽可能少的误删合适的项目(召回率)")

    # Detailed results (core metrics first)
    print("\n【详细对比】")
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']} 条")
    print()
    print(" 【核心能力指标】")
    print(f" ⭐ 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
    print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" ⭐ 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
    print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
    print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")

    # 5. List items rejected by the manual evaluation but incorrectly passed by the LLM
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
    print("=" * 60)

    # Index LLM results by expression_id
    llm_dict = {r["expression_id"]: r for r in llm_results}

    fp_items = []
    for manual_result in manual_results:
        llm_result = llm_dict.get(manual_result["expression_id"])
        if llm_result is None:
            continue

        # Rejected manually but passed by the LLM (a false positive)
        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append({
                "expression_id": manual_result["expression_id"],
                "situation": manual_result["situation"],
                "style": manual_result["style"],
                "manual_suitable": manual_result["suitable"],
                "llm_suitable": llm_result["suitable"],
                "llm_reason": llm_result.get("reason", "未提供理由"),
                "llm_error": llm_result.get("error")
            })

    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] 项目 ID: {item['expression_id']} ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get('llm_error'):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")

    # 6. Save the results to a JSON file
    output_file = os.path.join(project_root, "data", "expression_evaluation_comparison.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump({
                "manual_results": manual_results,
                "llm_results": llm_results,
                "comparison": comparison
            }, f, ensure_ascii=False, indent=2)
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        logger.warning(f"保存结果到文件失败: {e}")

    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)

    # Close the database connection
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")


if __name__ == "__main__":
    asyncio.run(main())
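As a sanity check on the arithmetic above, the percentage formulas that `compare_evaluations` reports can be reproduced in a tiny standalone sketch; the sample counts below are made up purely for illustration:

```python
def metrics(tp, tn, fp, fn):
    """Same percentage formulas as compare_evaluations, from raw confusion-matrix counts."""
    precision = tp / (tp + fp) * 100 if (tp + fp) else 0
    recall = tp / (tp + fn) * 100 if (tp + fn) else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
    specificity = tn / (tn + fp) * 100 if (tn + fp) else 0
    return precision, recall, f1, specificity

# Hypothetical run: 6 TP, 2 TN, 1 FP, 1 FN out of 10 manually labeled expressions
precision, recall, f1, specificity = metrics(6, 2, 1, 1)
print(f"precision={precision:.1f}% recall={recall:.1f}% f1={f1:.1f} specificity={specificity:.1f}%")
```

With these counts, precision and recall are both 6/7 ≈ 85.7% while specificity is 2/3 ≈ 66.7%, matching the TN / (TN + FP) breakdown the script prints as its core metric.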

@@ -1,567 +0,0 @@
"""
Simulate the Expression merge process.

Usage:
    python scripts/expression_merge_simulation.py
or with a specific chat_id:
    python scripts/expression_merge_simulation.py --chat-id <chat_id>
or with a specific similarity threshold:
    python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""

import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity  # noqa: E402
from src.llm_models.utils_model import LLMRequest  # noqa: E402
from src.config.config import model_config  # noqa: E402


def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable chat name from a chat_id."""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"

        if chat_stream.group_name:
            return f"{chat_stream.group_name}"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊"
        else:
            return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        return f"查询失败 ({chat_id[:8]}...)"


def parse_content_list(stored_list: Optional[str]) -> List[str]:
    """Parse a content_list JSON string into a list of strings."""
    if not stored_list:
        return []
    try:
        data = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []


def parse_style_list(stored_list: Optional[str]) -> List[str]:
    """Parse a style_list JSON string into a list of strings."""
    if not stored_list:
        return []
    try:
        data = json.loads(stored_list)
    except json.JSONDecodeError:
        return []
    return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []


def find_exact_style_match(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    exclude_ids: set
) -> Optional[Expression]:
    """
    Find an Expression record whose style matches target_style exactly.
    Checks both the style field and every entry in style_list.
    """
    for expr in expressions:
        if expr.chat_id != chat_id or expr.id in exclude_ids:
            continue

        # Check the style field
        if expr.style == target_style:
            return expr

        # Check every entry in style_list
        style_list = parse_style_list(expr.style_list)
        if target_style in style_list:
            return expr

    return None


def find_similar_style_expression(
    expressions: List[Expression],
    target_style: str,
    chat_id: str,
    similarity_threshold: float,
    exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
    """
    Find an Expression record with a similar style.
    Checks both the style field and every entry in style_list.

    Returns:
        (Expression, similarity) or None
    """
    best_match = None
    best_similarity = 0.0

    for expr in expressions:
        if expr.chat_id != chat_id or expr.id in exclude_ids:
            continue

        # Check the style field
        similarity = calculate_style_similarity(target_style, expr.style)
        if similarity >= similarity_threshold and similarity > best_similarity:
            best_similarity = similarity
            best_match = expr

        # Check every entry in style_list
        style_list = parse_style_list(expr.style_list)
        for existing_style in style_list:
            similarity = calculate_style_similarity(target_style, existing_style)
            if similarity >= similarity_threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = expr

    if best_match:
        return (best_match, best_similarity)
    return None


async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
    """Compose the situation text, summarizing with the LLM when possible."""
    sanitized = [c.strip() for c in content_list if c.strip()]
    if not sanitized:
        return ""

    if len(sanitized) == 1:
        return sanitized[0]

    # Try an LLM summary
    prompt = (
        "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
    )

    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 situation 失败: {e}")

    # Fall back to joining with "/" when the summary fails
    return "/".join(sanitized)


async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
    """Compose the style text, summarizing with the LLM when possible."""
    sanitized = [s.strip() for s in style_list if s.strip()]
    if not sanitized:
        return ""

    if len(sanitized) == 1:
        return sanitized[0]

    # Try an LLM summary
    prompt = (
        "请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
        "长度不超过20个字,保留共同特点:\n"
        f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
    )

    try:
        summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)

        print(f"Prompt:{prompt} Summary:{summary}")

        summary = summary.strip()
        if summary:
            return summary
    except Exception as e:
        print(f" ⚠️ LLM 总结 style 失败: {e}")

    # Fall back to the first entry when the summary fails
    return sanitized[0]


async def simulate_merge(
    expressions: List[Expression],
    similarity_threshold: float = 0.75,
    use_llm: bool = False,
    max_samples: int = 10,
) -> Dict:
    """
    Simulate the merge process.

    Args:
        expressions: list of Expression records as read from the database
        similarity_threshold: style similarity threshold
        use_llm: whether to run actual LLM summaries
        max_samples: maximum number of randomly sampled Expressions (0 or None means no limit)

    Returns:
        dict with the merge statistics
    """
    # If there are too many records, randomly sample a subset to keep runtime reasonable
    if max_samples and len(expressions) > max_samples:
        expressions = random.sample(expressions, max_samples)

    # Group by chat_id
    expressions_by_chat = defaultdict(list)
    for expr in expressions:
        expressions_by_chat[expr.chat_id].append(expr)

    # Initialize the LLM model if requested
    summary_model = None
    if use_llm:
        try:
            summary_model = LLMRequest(
                model_set=model_config.model_task_config.utils_small,
                request_type="expression.summary"
            )
            print("✅ LLM 模型已初始化,将进行实际总结")
        except Exception as e:
            print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
            use_llm = False

    merge_stats = {
        "total_expressions": len(expressions),
        "total_chats": len(expressions_by_chat),
        "exact_matches": 0,
        "similar_matches": 0,
        "new_records": 0,
        "merge_details": [],
        "chat_stats": {},
        "use_llm": use_llm
    }

    # Simulate the merge for each chat_id
    for chat_id, chat_expressions in expressions_by_chat.items():
        chat_name = get_chat_name(chat_id)
        chat_stat = {
            "chat_id": chat_id,
            "chat_name": chat_name,
            "total": len(chat_expressions),
            "exact_matches": 0,
            "similar_matches": 0,
            "new_records": 0,
            "merges": []
        }

        processed_ids = set()

        for expr in chat_expressions:
            if expr.id in processed_ids:
                continue

            target_style = expr.style
            target_situation = expr.situation

            # Tier 1: check for an exact match
            exact_match = find_exact_style_match(
                chat_expressions,
                target_style,
                chat_id,
                {expr.id}
            )

            if exact_match:
                # Exact match (no LLM summary needed)
                # Simulate the merged content_list and style_list
                target_content_list = parse_content_list(exact_match.content_list)
                target_content_list.append(target_situation)

                target_style_list = parse_style_list(exact_match.style_list)
                if exact_match.style and exact_match.style not in target_style_list:
                    target_style_list.append(exact_match.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)

                merge_info = {
                    "type": "exact",
                    "source_id": expr.id,
                    "target_id": exact_match.id,
                    "source_style": target_style,
                    "target_style": exact_match.style,
                    "source_situation": target_situation,
                    "target_situation": exact_match.situation,
                    "similarity": 1.0,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": exact_match.situation,  # exact matches keep the original situation
                    "merged_style": exact_match.style  # exact matches keep the original style
                }
                chat_stat["exact_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["exact_matches"] += 1
                processed_ids.add(expr.id)
                continue

            # Tier 2: check for a similarity match
            similar_match = find_similar_style_expression(
                chat_expressions,
                target_style,
                chat_id,
                similarity_threshold,
                {expr.id}
            )

            if similar_match:
                match_expr, similarity = similar_match
                # Similarity match (may use an LLM summary)
                # Simulate the merged content_list and style_list
                target_content_list = parse_content_list(match_expr.content_list)
                target_content_list.append(target_situation)

                target_style_list = parse_style_list(match_expr.style_list)
                if match_expr.style and match_expr.style not in target_style_list:
                    target_style_list.append(match_expr.style)
                if target_style not in target_style_list:
                    target_style_list.append(target_style)

                # Summarize with the LLM if enabled
                merged_situation = match_expr.situation
                merged_style = match_expr.style or target_style

                if use_llm and summary_model:
                    try:
                        merged_situation = await compose_situation_text(target_content_list, summary_model)
                        merged_style = await compose_style_text(target_style_list, summary_model)
                    except Exception as e:
                        print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
                        # Fall back when the summary fails
                        merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                        merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
                else:
                    # Without the LLM, use simple concatenation
                    merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
                    merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)

                merge_info = {
                    "type": "similar",
                    "source_id": expr.id,
                    "target_id": match_expr.id,
                    "source_style": target_style,
                    "target_style": match_expr.style,
                    "source_situation": target_situation,
                    "target_situation": match_expr.situation,
                    "similarity": similarity,
                    "merged_content_list": target_content_list,
                    "merged_style_list": target_style_list,
                    "merged_situation": merged_situation,
                    "merged_style": merged_style,
                    "llm_used": use_llm and summary_model is not None
                }
                chat_stat["similar_matches"] += 1
                chat_stat["merges"].append(merge_info)
                merge_stats["similar_matches"] += 1
                processed_ids.add(expr.id)
                continue

            # No match: count it as a new record
            chat_stat["new_records"] += 1
            merge_stats["new_records"] += 1
            processed_ids.add(expr.id)

        merge_stats["chat_stats"][chat_id] = chat_stat
        merge_stats["merge_details"].extend(chat_stat["merges"])

    return merge_stats


def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
    """Print the merge results."""
    print("\n" + "=" * 80)
    print("Expression 合并模拟结果")
    print("=" * 80)

    print("\n📊 总体统计:")
    print(f" 总 Expression 数: {stats['total_expressions']}")
    print(f" 总聊天数: {stats['total_chats']}")
    print(f" 完全匹配合并: {stats['exact_matches']}")
    print(f" 相似匹配合并: {stats['similar_matches']}")
    print(f" 新记录(无匹配): {stats['new_records']}")
    if stats.get('use_llm'):
        print(" LLM 总结: 已启用")
    else:
        print(" LLM 总结: 未启用(仅模拟)")

    total_merges = stats['exact_matches'] + stats['similar_matches']
    if stats['total_expressions'] > 0:
        merge_ratio = (total_merges / stats['total_expressions']) * 100
        print(f" 合并比例: {merge_ratio:.1f}%")

    # Per-chat breakdown
    print("\n📋 按聊天分组统计:")
    for chat_id, chat_stat in stats['chat_stats'].items():
        print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
        print(f" 总数: {chat_stat['total']}")
        print(f" 完全匹配: {chat_stat['exact_matches']}")
        print(f" 相似匹配: {chat_stat['similar_matches']}")
        print(f" 新记录: {chat_stat['new_records']}")

    # Merge details
    if show_details and stats['merge_details']:
        print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
        print()

        for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
            merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
            print(f" {idx}. {merge_type}")
            print(f" 源记录 ID: {merge['source_id']}")
            print(f" 目标记录 ID: {merge['target_id']}")
            print(f" 源 Style: {merge['source_style'][:50]}")
            print(f" 目标 Style: {merge['target_style'][:50]}")
            print(f" 源 Situation: {merge['source_situation'][:50]}")
            print(f" 目标 Situation: {merge['target_situation'][:50]}")

            # Show the merged result
            if 'merged_situation' in merge:
                print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
            if 'merged_style' in merge:
                print(f" → 合并后 Style: {merge['merged_style'][:50]}")
            if merge.get('llm_used'):
                print(" → LLM 总结: 已使用")
            elif merge['type'] == 'similar':
                print(" → LLM 总结: 未使用(模拟模式)")

            # Show the merged lists
            if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
                print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
                if len(merge['merged_content_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_content_list']) - 3} 项")
            if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
                print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
                if len(merge['merged_style_list']) > 3:
                    print(f" ... 还有 {len(merge['merged_style_list']) - 3} 项")
            print()

        if len(stats['merge_details']) > max_details:
            print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
    parser.add_argument(
        "--chat-id",
        type=str,
        default=None,
        help="指定要分析的 chat_id(不指定则分析所有)"
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=0.75,
        help="相似度阈值 (0-1, 默认: 0.75)"
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计"
    )
    parser.add_argument(
        "--max-details",
        type=int,
        default=50,
        help="最多显示的合并详情数 (默认: 50)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)"
    )
    parser.add_argument(
        "--use-llm",
        action="store_true",
        help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM)"
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=10,
        help="最多随机抽取的 Expression 数量 (默认: 10,设置为 0 表示不限制)"
    )

    args = parser.parse_args()

    # Validate the threshold
    if not 0 <= args.similarity_threshold <= 1:
        print("错误: similarity-threshold 必须在 0-1 之间")
        return

    # Decide on the output file path
    if args.output:
        output_file = args.output
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")

    # Query the Expression records
    print("正在从数据库加载Expression数据...")
    try:
        if args.chat_id:
            expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
        else:
            expressions = list(Expression.select())
            print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return

    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return

    # Run the merge simulation
    print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples})...")
    if args.use_llm:
        print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
    else:
        print("ℹ️ 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")

    stats = asyncio.run(
        simulate_merge(
            expressions,
            similarity_threshold=args.similarity_threshold,
            use_llm=args.use_llm,
            max_samples=args.max_samples,
        )
    )

    # Write the results
    original_stdout = sys.stdout
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            sys.stdout = f
            print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
        sys.stdout = original_stdout

        # Also print to the console
        print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)

    except Exception as e:
        sys.stdout = original_stdout
        print(f"❌ 写入文件失败: {e}")
        return

    print(f"\n✅ 模拟结果已保存到: {output_file}")


if __name__ == "__main__":
    main()
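The two-tier lookup used above (an exact style match wins outright, otherwise the best fuzzy match at or above the threshold) can be sketched on plain strings. Here `difflib.SequenceMatcher` stands in for `calculate_style_similarity`, which is only an assumption about that helper's behavior:

```python
from difflib import SequenceMatcher

def pick_merge_target(target_style, candidate_styles, threshold=0.75):
    """Return (style, similarity) for the merge target, or None if nothing qualifies."""
    # Tier 1: an exact match wins immediately
    for style in candidate_styles:
        if style == target_style:
            return style, 1.0
    # Tier 2: keep the best fuzzy match at or above the threshold
    best, best_sim = None, 0.0
    for style in candidate_styles:
        sim = SequenceMatcher(None, target_style, style).ratio()
        if sim >= threshold and sim > best_sim:
            best, best_sim = style, sim
    return (best, best_sim) if best is not None else None

print(pick_merge_target("abcd", ["abcd", "zzzz"]))  # exact tier
print(pick_merge_target("abcd", ["abce", "zzzz"]))  # fuzzy tier
```

Candidates below the threshold yield `None`, which corresponds to the "new record" branch of the simulation.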

@@ -1,342 +0,0 @@
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
from typing import List, Tuple
import numpy as np

# Add project root to Python path before importing project modules
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import Expression, ChatStreams  # noqa: E402

# Configure fonts with CJK support
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying ChatStreams table directly"""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id})"

        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
        else:
            return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"


def get_expression_data() -> List[Tuple[float, float, str, str]]:
    """Fetch Expression rows as a list of (create_date, count, chat_id, expression_type) tuples."""
    expressions = Expression.select()
    data = []

    for expr in expressions:
        # Skip records without a create_date
        if expr.create_date is None:
            continue

        data.append((expr.create_date, expr.count, expr.chat_id, expr.type))

    return data


def create_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
    """Draw a scatter plot of expression usage counts over time."""
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Split out the columns
    create_dates = [item[0] for item in data]
    counts = [item[1] for item in data]
    _chat_ids = [item[2] for item in data]
    _expression_types = [item[3] for item in data]

    # Convert timestamps to datetime objects
    dates = [datetime.fromtimestamp(ts) for ts in create_dates]

    # Choose the display format based on the time span
    time_span = max(dates) - min(dates)
    if time_span.days > 30:  # more than 30 days: month ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: day ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hour ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    # Create the figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Draw the scatter plot
    scatter = ax.scatter(dates, counts, alpha=0.6, s=30, c=range(len(dates)), cmap="viridis")

    # Labels and title
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("表达式使用次数随时间分布散点图", fontsize=14, fontweight="bold")

    # Apply the x-axis date format chosen above
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    # Grid
    ax.grid(True, alpha=0.3)

    # Colorbar
    cbar = plt.colorbar(scatter)
    cbar.set_label("数据点顺序", fontsize=10)

    # Layout
    plt.tight_layout()

    # Print summary statistics
    print("\n=== 数据统计 ===")
    print(f"总数据点数量: {len(data)}")
    print(f"时间范围: {min(dates).strftime('%Y-%m-%d %H:%M:%S')} 到 {max(dates).strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"使用次数范围: {min(counts):.1f} 到 {max(counts):.1f}")
    print(f"平均使用次数: {np.mean(counts):.2f}")
    print(f"中位数使用次数: {np.median(counts):.2f}")

    # Save the figure
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n散点图已保存到: {save_path}")

    # Show the figure
    plt.show()


def create_grouped_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
    """Draw a scatter plot grouped by chat."""
    if not data:
        print("没有找到有效的表达式数据")
        return

    # Group by chat_id
    chat_groups = {}
    for item in data:
        chat_id = item[2]
        if chat_id not in chat_groups:
            chat_groups[chat_id] = []
        chat_groups[chat_id].append(item)

    # Choose the display format based on the time span
    all_dates = [datetime.fromtimestamp(item[0]) for item in data]
    time_span = max(all_dates) - min(all_dates)
    if time_span.days > 30:  # more than 30 days: month ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.MonthLocator()
        minor_locator = mdates.DayLocator(interval=7)
    elif time_span.days > 7:  # more than 7 days: day ticks
        date_format = "%Y-%m-%d"
        major_locator = mdates.DayLocator(interval=1)
        minor_locator = mdates.HourLocator(interval=12)
    else:  # within 7 days: hour ticks
        date_format = "%Y-%m-%d %H:%M"
        major_locator = mdates.HourLocator(interval=6)
        minor_locator = mdates.HourLocator(interval=1)

    # Create the figure
    fig, ax = plt.subplots(figsize=(14, 10))

    # Assign a distinct color to each chat
    colors = plt.cm.Set3(np.linspace(0, 1, len(chat_groups)))

    for i, (chat_id, chat_data) in enumerate(chat_groups.items()):
        create_dates = [item[0] for item in chat_data]
        counts = [item[1] for item in chat_data]
        dates = [datetime.fromtimestamp(ts) for ts in create_dates]

        chat_name = get_chat_name(chat_id)
        # Truncate overly long chat names
        display_name = chat_name[:20] + "..." if len(chat_name) > 20 else chat_name

        ax.scatter(
            dates,
            counts,
            alpha=0.7,
            s=40,
            c=[colors[i]],
            label=f"{display_name} ({len(chat_data)}个)",
            edgecolors="black",
            linewidth=0.5,
        )

    # Labels and title
    ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
    ax.set_ylabel("使用次数 (Count)", fontsize=12)
    ax.set_title("按聊天分组的表达式使用次数散点图", fontsize=14, fontweight="bold")

    # Apply the x-axis date format chosen above
    ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_minor_locator(minor_locator)
    plt.xticks(rotation=45)

    # Legend
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)

    # Grid
    ax.grid(True, alpha=0.3)

    # Layout
    plt.tight_layout()

    # Print summary statistics
    print("\n=== 分组统计 ===")
    print(f"总聊天数量: {len(chat_groups)}")
    for chat_id, chat_data in chat_groups.items():
        chat_name = get_chat_name(chat_id)
        counts = [item[1] for item in chat_data]
        print(f"{chat_name}: {len(chat_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")

    # Save the figure
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"\n分组散点图已保存到: {save_path}")

    # Show the figure
    plt.show()


def create_type_scatter_plot(data: List[Tuple[float, float, str, str]], save_path: str = None):
|
||||
"""创建按表达式类型分组的散点图"""
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据")
|
||||
return
|
||||
|
||||
# 按type分组
|
||||
type_groups = {}
|
||||
for item in data:
|
||||
expr_type = item[3]
|
||||
if expr_type not in type_groups:
|
||||
type_groups[expr_type] = []
|
||||
type_groups[expr_type].append(item)
|
||||
|
||||
# 计算时间跨度,自动调整显示格式
|
||||
all_dates = [datetime.fromtimestamp(item[0]) for item in data]
|
||||
time_span = max(all_dates) - min(all_dates)
|
||||
if time_span.days > 30: # 超过30天,按月显示
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.MonthLocator()
|
||||
minor_locator = mdates.DayLocator(interval=7)
|
||||
elif time_span.days > 7: # 超过7天,按天显示
|
||||
date_format = "%Y-%m-%d"
|
||||
major_locator = mdates.DayLocator(interval=1)
|
||||
minor_locator = mdates.HourLocator(interval=12)
|
||||
else: # 7天内,按小时显示
|
||||
date_format = "%Y-%m-%d %H:%M"
|
||||
major_locator = mdates.HourLocator(interval=6)
|
||||
minor_locator = mdates.HourLocator(interval=1)
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 为每个类型分配不同颜色
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(type_groups)))
|
||||
|
||||
for i, (expr_type, type_data) in enumerate(type_groups.items()):
|
||||
create_dates = [item[0] for item in type_data]
|
||||
counts = [item[1] for item in type_data]
|
||||
dates = [datetime.fromtimestamp(ts) for ts in create_dates]
|
||||
|
||||
ax.scatter(
|
||||
dates,
|
||||
counts,
|
||||
alpha=0.7,
|
||||
s=40,
|
||||
c=[colors[i]],
|
||||
label=f"{expr_type} ({len(type_data)}个)",
|
||||
edgecolors="black",
|
||||
linewidth=0.5,
|
||||
)
|
||||
|
||||
# 设置标签和标题
|
||||
ax.set_xlabel("创建日期 (Create Date)", fontsize=12)
|
||||
ax.set_ylabel("使用次数 (Count)", fontsize=12)
|
||||
ax.set_title("按表达式类型分组的散点图", fontsize=14, fontweight="bold")
|
||||
|
||||
# 设置x轴日期格式 - 根据时间跨度自动调整
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter(date_format))
|
||||
ax.xaxis.set_major_locator(major_locator)
|
||||
ax.xaxis.set_minor_locator(minor_locator)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加图例
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示统计信息
|
||||
print("\n=== 类型统计 ===")
|
||||
for expr_type, type_data in type_groups.items():
|
||||
counts = [item[1] for item in type_data]
|
||||
print(f"{expr_type}: {len(type_data)}个表达式, 平均使用次数: {np.mean(counts):.2f}")
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"\n类型散点图已保存到: {save_path}")
|
||||
|
||||
# 显示图片
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("开始分析表达式数据...")
|
||||
|
||||
# 获取数据
|
||||
data = get_expression_data()
|
||||
|
||||
if not data:
|
||||
print("没有找到有效的表达式数据(create_date不为空的数据)")
|
||||
return
|
||||
|
||||
print(f"找到 {len(data)} 条有效数据")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.join(project_root, "data", "temp")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 生成时间戳用于文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 1. 创建基础散点图
|
||||
print("\n1. 创建基础散点图...")
|
||||
create_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_{timestamp}.png"))
|
||||
|
||||
# 2. 创建按聊天分组的散点图
|
||||
print("\n2. 创建按聊天分组的散点图...")
|
||||
create_grouped_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_chat_{timestamp}.png"))
|
||||
|
||||
# 3. 创建按类型分组的散点图
|
||||
print("\n3. 创建按类型分组的散点图...")
|
||||
create_type_scatter_plot(data, os.path.join(output_dir, f"expression_scatter_by_type_{timestamp}.png"))
|
||||
|
||||
print("\n分析完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
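每个绘图函数都按整体时间跨度选择 x 轴日期格式和刻度密度。下面是该分支逻辑抽成纯函数的一个最小示意(`pick_date_format` 是为演示取的名字,并非脚本中的真实函数):

```python
from datetime import datetime, timedelta


def pick_date_format(dates: list) -> str:
    """按时间跨度选择日期格式,与脚本中的 if/elif 分支一致。"""
    span = max(dates) - min(dates)
    if span.days > 30:  # 超过30天:只显示日期,按月刻度
        return "%Y-%m-%d"
    elif span.days > 7:  # 超过7天:只显示日期,按天刻度
        return "%Y-%m-%d"
    else:  # 7天内:显示到小时分钟
        return "%Y-%m-%d %H:%M"


base = datetime(2025, 1, 1)
short = [base, base + timedelta(days=2)]
long_span = [base, base + timedelta(days=90)]
print(pick_date_format(short))      # "%Y-%m-%d %H:%M"
print(pick_date_format(long_span))  # "%Y-%m-%d"
```

把刻度密度(MonthLocator/DayLocator/HourLocator)也一并由同一个分支返回,可以保证格式和刻度始终匹配。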
@@ -1,559 +0,0 @@
"""
分析expression库中situation和style的相似度

用法:
    python scripts/expression_similarity_analysis.py
或指定阈值:
    python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
"""

import sys
import os
import argparse
from typing import List, Tuple
from collections import defaultdict
from difflib import SequenceMatcher
from datetime import datetime

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
from src.config.config import global_config  # noqa: E402
from src.chat.message_receive.chat_stream import get_chat_manager  # noqa: E402


class TeeOutput:
    """同时输出到控制台和文件的类"""

    def __init__(self, file_path: str):
        self.file = open(file_path, "w", encoding="utf-8")
        self.console = sys.stdout

    def write(self, text: str):
        """写入文本到控制台和文件"""
        self.console.write(text)
        self.file.write(text)
        self.file.flush()  # 立即刷新到文件

    def flush(self):
        """刷新输出"""
        self.console.flush()
        self.file.flush()

    def close(self):
        """关闭文件"""
        if self.file:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False
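TeeOutput 把每次写入同时转发到控制台和日志文件。下面用内存缓冲区代替真实的控制台和文件,演示同一模式(`Tee` 为示意用的简化类名,不是脚本中的真实实现):

```python
import io


class Tee:
    """把写入同时转发到多个流,与 TeeOutput 的思路相同。"""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, text: str):
        for s in self.streams:
            s.write(text)

    def flush(self):
        for s in self.streams:
            s.flush()


console, logfile = io.StringIO(), io.StringIO()
tee = Tee(console, logfile)
print("hello", file=tee)  # print 只要求对象有 write 方法
assert console.getvalue() == logfile.getvalue() == "hello\n"
```

脚本中额外把 `sys.stdout` 临时替换为 Tee 对象,这样所有裸 `print` 都会被同步记录。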
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
    """
    解析'platform:id:type'为chat_id,直接复用 ChatManager 的逻辑
    """
    try:
        parts = stream_config_str.split(":")
        if len(parts) != 3:
            return None
        platform = parts[0]
        id_str = parts[1]
        stream_type = parts[2]
        is_group = stream_type == "group"
        return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
    except Exception:
        return None


def build_chat_id_groups() -> dict[str, set[str]]:
    """
    根据expression_groups配置,构建chat_id到相关chat_id集合的映射

    Returns:
        dict: {chat_id: set of related chat_ids (including itself)}
    """
    groups = global_config.expression.expression_groups
    chat_id_groups: dict[str, set[str]] = {}

    # 检查是否存在全局共享组(包含"*"的组)
    global_group_exists = any("*" in group for group in groups)

    if global_group_exists:
        # 如果存在全局共享组,收集所有配置中的chat_id
        all_chat_ids = set()
        for group in groups:
            for stream_config_str in group:
                if stream_config_str == "*":
                    continue
                if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
                    all_chat_ids.add(chat_id_candidate)

        # 所有chat_id都互相相关
        for chat_id in all_chat_ids:
            chat_id_groups[chat_id] = all_chat_ids.copy()
    else:
        # 处理普通组
        for group in groups:
            group_chat_ids = set()
            for stream_config_str in group:
                if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.add(chat_id_candidate)

            # 组内的所有chat_id都互相相关
            for chat_id in group_chat_ids:
                if chat_id not in chat_id_groups:
                    chat_id_groups[chat_id] = set()
                chat_id_groups[chat_id].update(group_chat_ids)

    # 确保每个chat_id至少包含自身
    for chat_id in chat_id_groups:
        chat_id_groups[chat_id].add(chat_id)

    return chat_id_groups


def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
    """
    判断两个chat_id是否相关(相同或同组)

    Args:
        chat_id1: 第一个chat_id
        chat_id2: 第二个chat_id
        chat_id_groups: chat_id到相关chat_id集合的映射

    Returns:
        bool: 如果两个chat_id相同或同组,返回True
    """
    if chat_id1 == chat_id2:
        return True

    # 如果chat_id1在映射中,检查chat_id2是否在其相关集合中
    if chat_id1 in chat_id_groups:
        return chat_id2 in chat_id_groups[chat_id1]

    # 如果chat_id1不在映射中,说明它不在任何组中,只与自己相关
    return False


def get_chat_name(chat_id: str) -> str:
    """根据 chat_id 获取聊天名称"""
    try:
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id[:8]}...)"

        if chat_stream.group_name:
            return f"{chat_stream.group_name}"
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊"
        else:
            return f"未知聊天 ({chat_id[:8]}...)"
    except Exception:
        return f"查询失败 ({chat_id[:8]}...)"


def text_similarity(text1: str, text2: str) -> float:
    """
    计算两个文本的相似度
    使用SequenceMatcher计算相似度,返回0-1之间的值
    在计算前会移除"使用"和"句式"这两个词
    """
    if not text1 or not text2:
        return 0.0

    # 移除"使用"和"句式"这两个词
    def remove_ignored_words(text: str) -> str:
        """移除需要忽略的词"""
        text = text.replace("使用", "")
        text = text.replace("句式", "")
        return text.strip()

    cleaned_text1 = remove_ignored_words(text1)
    cleaned_text2 = remove_ignored_words(text2)

    # 如果清理后文本为空,返回0
    if not cleaned_text1 or not cleaned_text2:
        return 0.0

    return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio()


def find_similar_pairs(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    max_pairs: int = None,
) -> List[Tuple[int, int, float, str, str]]:
    """
    找出相似的expression对

    Args:
        expressions: Expression对象列表
        field_name: 要比较的字段名 ('situation' 或 'style')
        threshold: 相似度阈值 (0-1)
        max_pairs: 最多返回的对数,None表示返回所有

    Returns:
        List of (index1, index2, similarity, text1, text2) tuples
    """
    similar_pairs = []
    n = len(expressions)

    print(f"正在分析 {field_name} 字段的相似度...")
    print(f"总共需要比较 {n * (n - 1) // 2} 对...")

    for i in range(n):
        if (i + 1) % 100 == 0:
            print(f"  已处理 {i + 1}/{n} 个项目...")

        expr1 = expressions[i]
        text1 = getattr(expr1, field_name, "")

        for j in range(i + 1, n):
            expr2 = expressions[j]
            text2 = getattr(expr2, field_name, "")

            similarity = text_similarity(text1, text2)

            if similarity >= threshold:
                similar_pairs.append((i, j, similarity, text1, text2))

    # 按相似度降序排序
    similar_pairs.sort(key=lambda x: x[2], reverse=True)

    if max_pairs:
        similar_pairs = similar_pairs[:max_pairs]

    return similar_pairs


def group_similar_items(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    chat_id_groups: dict[str, set[str]],
) -> List[List[int]]:
    """
    将相似的expression分组(仅比较相同chat_id或同组的项目)

    Args:
        expressions: Expression对象列表
        field_name: 要比较的字段名 ('situation' 或 'style')
        threshold: 相似度阈值 (0-1)
        chat_id_groups: chat_id到相关chat_id集合的映射

    Returns:
        List of groups, each group is a list of indices
    """
    n = len(expressions)
    # 使用并查集的思想来分组
    parent = list(range(n))

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    print(f"正在对 {field_name} 字段进行分组(仅比较相同chat_id或同组的项目)...")

    # 统计需要比较的对数
    total_pairs = 0
    for i in range(n):
        for j in range(i + 1, n):
            if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
                total_pairs += 1

    print(f"总共需要比较 {total_pairs} 对(已过滤不同chat_id且不同组的项目)...")

    compared_pairs = 0
    for i in range(n):
        if (i + 1) % 100 == 0:
            print(f"  已处理 {i + 1}/{n} 个项目...")

        expr1 = expressions[i]
        text1 = getattr(expr1, field_name, "")

        for j in range(i + 1, n):
            expr2 = expressions[j]

            # 只比较相同chat_id或同组的项目
            if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
                continue

            compared_pairs += 1
            text2 = getattr(expr2, field_name, "")

            similarity = text_similarity(text1, text2)

            if similarity >= threshold:
                union(i, j)

    # 收集分组
    groups = defaultdict(list)
    for i in range(n):
        root = find(i)
        groups[root].append(i)

    # 只返回包含多个项目的组
    result = [group for group in groups.values() if len(group) > 1]
    result.sort(key=len, reverse=True)

    return result


def print_similarity_analysis(
    expressions: List[Expression],
    field_name: str,
    threshold: float,
    chat_id_groups: dict[str, set[str]],
    show_details: bool = True,
    max_groups: int = 20,
):
    """打印相似度分析结果"""
    print("\n" + "=" * 80)
    print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
    print("=" * 80)

    # 分组分析
    groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)

    total_items = len(expressions)
    similar_items_count = sum(len(group) for group in groups)
    unique_groups = len(groups)

    print("\n📊 统计信息:")
    print(f"  总项目数: {total_items}")
    print(f"  相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
    print(f"  相似组数: {unique_groups}")
    print(f"  平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else "  平均每组项目数: 0")

    if not groups:
        print(f"\n未找到相似度 >= {threshold} 的项目组")
        return

    print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
    print()

    for group_idx, group in enumerate(groups[:max_groups], 1):
        print(f"组 {group_idx} (共 {len(group)} 个项目):")

        if show_details:
            # 显示组内所有项目的详细信息
            for idx in group:
                expr = expressions[idx]
                text = getattr(expr, field_name, "")
                chat_name = get_chat_name(expr.chat_id)

                # 截断过长的文本
                display_text = text[:60] + "..." if len(text) > 60 else text

                print(f"  [{expr.id}] {display_text}")
                print(f"    聊天: {chat_name}, Count: {expr.count}")

            # 计算组内平均相似度
            if len(group) > 1:
                similarities = []
                above_threshold_pairs = []  # 存储满足阈值的相似对
                above_threshold_count = 0
                for i in range(len(group)):
                    for j in range(i + 1, len(group)):
                        text1 = getattr(expressions[group[i]], field_name, "")
                        text2 = getattr(expressions[group[j]], field_name, "")
                        sim = text_similarity(text1, text2)
                        similarities.append(sim)
                        if sim >= threshold:
                            above_threshold_count += 1
                            # 存储满足阈值的对的信息
                            expr1 = expressions[group[i]]
                            expr2 = expressions[group[j]]
                            display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
                            display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
                            above_threshold_pairs.append((
                                expr1.id, display_text1,
                                expr2.id, display_text2,
                                sim,
                            ))

                if similarities:
                    avg_sim = sum(similarities) / len(similarities)
                    min_sim = min(similarities)
                    max_sim = max(similarities)
                    above_threshold_ratio = above_threshold_count / len(similarities) * 100
                    print(f"    平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
                    print(f"    满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")

                    # 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组)
                    if above_threshold_pairs:
                        print("    ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
                        # 按相似度降序排序
                        above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
                        for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]:  # 最多显示10对
                            print(f"      [{idx1}] ↔ [{idx2}]: {sim:.3f}")
                            print(f"        \"{text1}\" ↔ \"{text2}\"")
                        if len(above_threshold_pairs) > 10:
                            print(f"      ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
                    else:
                        print(f"    ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
        else:
            # 只显示组内第一个项目作为示例
            expr = expressions[group[0]]
            text = getattr(expr, field_name, "")
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f"  示例: {display_text}")
            print(f"  ... 还有 {len(group) - 1} 个相似项目")

        print()

    if len(groups) > max_groups:
        print(f"... 还有 {len(groups) - max_groups} 组未显示")


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
    parser.add_argument(
        "--situation-threshold",
        type=float,
        default=0.7,
        help="situation相似度阈值 (0-1, 默认: 0.7)",
    )
    parser.add_argument(
        "--style-threshold",
        type=float,
        default=0.7,
        help="style相似度阈值 (0-1, 默认: 0.7)",
    )
    parser.add_argument(
        "--no-details",
        action="store_true",
        help="不显示详细信息,只显示统计",
    )
    parser.add_argument(
        "--max-groups",
        type=int,
        default=20,
        help="最多显示的组数 (默认: 20)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (默认: 自动生成带时间戳的文件)",
    )

    args = parser.parse_args()

    # 验证阈值
    if not 0 <= args.situation_threshold <= 1:
        print("错误: situation-threshold 必须在 0-1 之间")
        return
    if not 0 <= args.style_threshold <= 1:
        print("错误: style-threshold 必须在 0-1 之间")
        return

    # 确定输出文件路径
    if args.output:
        output_file = args.output
    else:
        # 自动生成带时间戳的输出文件
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(project_root, "data", "temp")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")

    # 使用TeeOutput同时输出到控制台和文件
    with TeeOutput(output_file) as tee:
        # 临时替换sys.stdout
        original_stdout = sys.stdout
        sys.stdout = tee

        try:
            print("=" * 80)
            print("Expression 相似度分析工具")
            print("=" * 80)
            print(f"输出文件: {output_file}")
            print()

            _run_analysis(args)

        finally:
            # 恢复原始stdout
            sys.stdout = original_stdout

    print(f"\n✅ 分析结果已保存到: {output_file}")


def _run_analysis(args):
    """执行分析的主逻辑"""

    # 查询所有Expression记录
    print("正在从数据库加载Expression数据...")
    try:
        expressions = list(Expression.select())
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        return

    if not expressions:
        print("❌ 数据库中没有找到Expression记录")
        return

    print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
    print()

    # 构建chat_id分组映射
    print("正在构建chat_id分组映射(根据expression_groups配置)...")
    try:
        chat_id_groups = build_chat_id_groups()
        print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
        if chat_id_groups:
            # 统计分组信息
            total_related = sum(len(related) for related in chat_id_groups.values())
            avg_related = total_related / len(chat_id_groups)
            print(f"  平均每个chat_id与 {avg_related:.1f} 个chat_id相关(包括自身)")
        print()
    except Exception as e:
        print(f"⚠️ 构建chat_id分组映射失败: {e}")
        print("  将使用默认行为:只比较相同chat_id的项目")
        chat_id_groups = {}

    # 分析situation相似度
    print_similarity_analysis(
        expressions,
        "situation",
        args.situation_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups,
    )

    # 分析style相似度
    print_similarity_analysis(
        expressions,
        "style",
        args.style_threshold,
        chat_id_groups,
        show_details=not args.no_details,
        max_groups=args.max_groups,
    )

    print("\n" + "=" * 80)
    print("分析完成!")
    print("=" * 80)


if __name__ == "__main__":
    main()
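group_similar_items 用一个小型并查集做聚类:任意相似度达到阈值的一对就做 union,因此组也可能通过传递性形成(A~B 且 B~C 会把 A、B、C 归为一组,即使 A 和 C 本身低于阈值)——这正是脚本末尾"可能是通过传递性连接"警告的来源。下面是对纯字符串做同样聚类的自包含示意(`cluster` 为演示取的名字):

```python
from collections import defaultdict
from difflib import SequenceMatcher


def cluster(texts: list, threshold: float) -> list:
    """用并查集把两两相似度达到阈值的文本归为一组。"""
    parent = list(range(len(texts)))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # 路径压缩
            x = parent[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    for i in range(len(texts)):
        for j in range(i + 1, len(texts)):
            if SequenceMatcher(None, texts[i], texts[j]).ratio() >= threshold:
                union(i, j)

    groups = defaultdict(list)
    for i in range(len(texts)):
        groups[find(i)].append(i)
    # 只保留含多个成员的组,与脚本的行为一致
    return [g for g in groups.values() if len(g) > 1]


texts = ["用疑问句结尾", "用疑问句收尾", "发一个表情包"]
print(cluster(texts, 0.7))  # 前两条被归为一组
```

SequenceMatcher.ratio() 对前两条文本约为 0.83(6 个字符中有 5 个匹配),超过 0.7,因此索引 0 和 1 落入同一组。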
@@ -1,196 +0,0 @@
import time
import sys
import os
from typing import Dict, List

# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# 必须在设置好 sys.path 之后再导入项目模块
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying ChatStreams table directly"""
    try:
        # 直接从数据库查询ChatStreams表
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id})"

        # 如果有群组信息,显示群组名称
        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        # 如果是私聊,显示用户昵称
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
        else:
            return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"


def calculate_time_distribution(expressions) -> Dict[str, int]:
    """Calculate distribution of last active time in days"""
    now = time.time()
    distribution = {
        "0-1天": 0,
        "1-3天": 0,
        "3-7天": 0,
        "7-14天": 0,
        "14-30天": 0,
        "30-60天": 0,
        "60-90天": 0,
        "90+天": 0,
    }
    for expr in expressions:
        diff_days = (now - expr.last_active_time) / (24 * 3600)
        if diff_days < 1:
            distribution["0-1天"] += 1
        elif diff_days < 3:
            distribution["1-3天"] += 1
        elif diff_days < 7:
            distribution["3-7天"] += 1
        elif diff_days < 14:
            distribution["7-14天"] += 1
        elif diff_days < 30:
            distribution["14-30天"] += 1
        elif diff_days < 60:
            distribution["30-60天"] += 1
        elif diff_days < 90:
            distribution["60-90天"] += 1
        else:
            distribution["90+天"] += 1
    return distribution


def calculate_count_distribution(expressions) -> Dict[str, int]:
    """Calculate distribution of count values"""
    distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
    for expr in expressions:
        cnt = expr.count
        if cnt < 1:
            distribution["0-1"] += 1
        elif cnt < 2:
            distribution["1-2"] += 1
        elif cnt < 3:
            distribution["2-3"] += 1
        elif cnt < 4:
            distribution["3-4"] += 1
        elif cnt < 5:
            distribution["4-5"] += 1
        elif cnt < 10:
            distribution["5-10"] += 1
        else:
            distribution["10+"] += 1
    return distribution


def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
    """Get top N most used expressions for a specific chat_id"""
    return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)


def show_overall_statistics(expressions, total: int) -> None:
    """Show overall statistics"""
    time_dist = calculate_time_distribution(expressions)
    count_dist = calculate_count_distribution(expressions)

    print("\n=== 总体统计 ===")
    print(f"总表达式数量: {total}")

    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        print(f"{period}: {count} ({count / total * 100:.2f}%)")

    print("\ncount分布:")
    for range_, count in count_dist.items():
        print(f"{range_}: {count} ({count / total * 100:.2f}%)")


def show_chat_statistics(chat_id: str, chat_name: str) -> None:
    """Show statistics for a specific chat"""
    chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
    chat_total = len(chat_exprs)

    print(f"\n=== {chat_name} ===")
    print(f"表达式数量: {chat_total}")

    if chat_total == 0:
        print("该聊天没有表达式数据")
        return

    # Time distribution for this chat
    time_dist = calculate_time_distribution(chat_exprs)
    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        if count > 0:
            print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")

    # Count distribution for this chat
    count_dist = calculate_count_distribution(chat_exprs)
    print("\ncount分布:")
    for range_, count in count_dist.items():
        if count > 0:
            print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")

    # Top expressions
    print("\nTop 10使用最多的表达式:")
    top_exprs = get_top_expressions_by_chat(chat_id, 10)
    for i, expr in enumerate(top_exprs, 1):
        print(f"{i}. [{expr.type}] Count: {expr.count}")
        print(f"   Situation: {expr.situation}")
        print(f"   Style: {expr.style}")
        print()


def interactive_menu() -> None:
    """Interactive menu for expression statistics"""
    # Get all expressions
    expressions = list(Expression.select())
    if not expressions:
        print("数据库中没有找到表达式")
        return

    total = len(expressions)

    # Get unique chat_ids and their names
    chat_ids = list(set(expr.chat_id for expr in expressions))
    chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
    chat_info.sort(key=lambda x: x[1])  # Sort by chat name

    while True:
        print("\n" + "=" * 50)
        print("表达式统计分析")
        print("=" * 50)
        print("0. 显示总体统计")

        for i, (chat_id, chat_name) in enumerate(chat_info, 1):
            chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
            print(f"{i}. {chat_name} ({chat_count}个表达式)")

        print("q. 退出")

        choice = input("\n请选择要查看的统计 (输入序号): ").strip()

        if choice.lower() == "q":
            print("再见!")
            break

        try:
            choice_num = int(choice)
            if choice_num == 0:
                show_overall_statistics(expressions, total)
            elif 1 <= choice_num <= len(chat_info):
                chat_id, chat_name = chat_info[choice_num - 1]
                show_chat_statistics(chat_id, chat_name)
            else:
                print("无效的选择,请重新输入")
        except ValueError:
            print("请输入有效的数字")

        input("\n按回车键继续...")


if __name__ == "__main__":
    interactive_menu()
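calculate_count_distribution 本质上是一个固定边界的直方图。同样的分桶也可以用 bisect 泛化,省掉 elif 链(下面的 `bucket`、`BOUNDS`、`LABELS` 均为演示用命名,并非脚本中的实现):

```python
from bisect import bisect_right

# 区间上界与标签,语义与脚本中的 elif 链一致
BOUNDS = [1, 2, 3, 4, 5, 10]
LABELS = ["0-1", "1-2", "2-3", "3-4", "4-5", "5-10", "10+"]


def bucket(count: float) -> str:
    """返回 count 所属的区间标签。bisect_right 保证边界值落入右侧区间,
    例如 count == 5 归入 "5-10",与原 elif 链的行为相同。"""
    return LABELS[bisect_right(BOUNDS, count)]


print(bucket(0.5))  # 0-1
print(bucket(5))    # 5-10
print(bucket(42))   # 10+
```

这样新增或调整区间只需要改两张表,不必重写分支逻辑。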
@@ -470,7 +470,7 @@ def _run_embedding_helper() -> None:
             test_path.rename(archive_path)
         except Exception as exc:  # pragma: no cover - 防御性兜底
             logger.error("归档 embedding_model_test.json 失败: %s", exc)
-            print(f"[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
+            print("[ERROR] 归档 embedding_model_test.json 失败,请检查文件权限与路径。错误详情已写入日志。")
             return

     print(
File diff suppressed because it is too large
@@ -0,0 +1,447 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
|
||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||
try:
|
||||
# 先导入 prompt_builder,检查 prompt 是否已经初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
# 检查 memory_retrieval 相关的 prompt 是否已经注册
|
||||
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
|
||||
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
|
||||
|
||||
module_name = "src.memory_system.memory_retrieval"
|
||||
|
||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, 'init_memory_retrieval_prompt'):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                    existing_module._process_single_question,
                )

        # 如果模块已经在 sys.modules 中但部分初始化,先移除它
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
                # 模块部分初始化,移除它
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # 清理可能相关的部分初始化模块
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith('src.memory_system.') and key != 'src.memory_system':
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass

        # 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
        # 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
        try:
            # 先导入可能触发循环导入的模块,让它们完成初始化
            import src.config.config
            import src.chat.utils.prompt_builder
            # 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
            # 如果它们已经导入,就确保它们完全初始化
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # 如果导入失败,继续
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # 如果导入失败,继续
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # 现在尝试导入 memory_retrieval
        # 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
        memory_retrieval_module = importlib.import_module(module_name)

        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise


def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
    """创建一个测试用的 ChatStream 对象"""
    user_info = UserInfo(
        platform="test",
        user_id="test_user",
        user_nickname="测试用户",
    )
    group_info = GroupInfo(
        platform="test",
        group_id="test_group",
        group_name="测试群组",
    )
    return ChatStream(
        stream_id=chat_id,
        platform="test",
        user_info=user_info,
        group_info=group_info,
    )


def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """获取从指定时间开始的token使用情况

    Args:
        start_time: 开始时间戳

    Returns:
        包含token使用统计的字典
    """
    try:
        start_datetime = datetime.fromtimestamp(start_time)

        # 查询从开始时间到现在的所有memory相关的token使用记录
        records = (
            LLMUsage.select()
            .where(
                (LLMUsage.timestamp >= start_datetime)
                & (
                    (LLMUsage.request_type.like("%memory%"))
                    | (LLMUsage.request_type == "memory.question")
                    | (LLMUsage.request_type == "memory.react")
                    | (LLMUsage.request_type == "memory.react.final")
                )
            )
            .order_by(LLMUsage.timestamp.asc())
        )

        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0
        total_cost = 0.0
        request_count = 0
        model_usage = {}  # 按模型统计

        for record in records:
            total_prompt_tokens += record.prompt_tokens or 0
            total_completion_tokens += record.completion_tokens or 0
            total_tokens += record.total_tokens or 0
            total_cost += record.cost or 0.0
            request_count += 1

            # 按模型统计
            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                }
            model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
            model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
            model_usage[model_name]["total_tokens"] += record.total_tokens or 0
            model_usage[model_name]["cost"] += record.cost or 0.0
            model_usage[model_name]["request_count"] += 1

        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "request_count": request_count,
            "model_usage": model_usage,
        }
    except Exception as e:
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }


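The per-model accumulation inside `get_token_usage_since` can be sketched standalone. This is an illustrative reduction of the loop above, not the project's real `LLMUsage` model — the records here are plain dicts with made-up values, and `or 0` mirrors the null-safe accumulation used in the function:

```python
# Hypothetical records standing in for LLMUsage rows; values are illustrative.
records = [
    {"model_name": "model-a", "prompt_tokens": 10, "completion_tokens": 5, "cost": 0.01},
    {"model_name": "model-a", "prompt_tokens": 20, "completion_tokens": None, "cost": 0.02},
    {"model_name": None, "prompt_tokens": 3, "completion_tokens": 1, "cost": 0.0},
]

model_usage = {}
for r in records:
    # Missing model names fall back to "unknown", as in the function above
    name = r["model_name"] or "unknown"
    bucket = model_usage.setdefault(
        name, {"prompt_tokens": 0, "completion_tokens": 0, "cost": 0.0, "request_count": 0}
    )
    # `or 0` keeps NULL columns from breaking the sums
    bucket["prompt_tokens"] += r["prompt_tokens"] or 0
    bucket["completion_tokens"] += r["completion_tokens"] or 0
    bucket["cost"] += r["cost"] or 0.0
    bucket["request_count"] += 1
```

`setdefault` replaces the explicit `if model_name not in model_usage` initialization but is otherwise equivalent.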
def format_thinking_steps(thinking_steps: list) -> str:
    """格式化思考步骤为可读字符串"""
    if not thinking_steps:
        return "无思考步骤"

    lines = []
    for step in thinking_steps:
        iteration = step.get("iteration", "?")
        thought = step.get("thought", "")
        actions = step.get("actions", [])
        observations = step.get("observations", [])

        lines.append(f"\n--- 迭代 {iteration} ---")
        if thought:
            lines.append(f"思考: {thought[:200]}...")

        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                lines.append(f"  - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")

        if observations:
            lines.append("观察:")
            for obs in observations:
                obs_str = str(obs)[:200]
                if len(str(obs)) > 200:
                    obs_str += "..."
                lines.append(f"  - {obs_str}")

    return "\n".join(lines)


async def test_memory_retrieval(
    question: str,
    chat_id: str = "test_memory_retrieval",
    context: str = "",
    max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
    """测试记忆检索功能

    Args:
        question: 要查询的问题
        chat_id: 聊天ID
        context: 上下文信息
        max_iterations: 最大迭代次数

    Returns:
        包含测试结果的字典
    """
    print("\n" + "=" * 80)
    print("[测试] 记忆检索测试")
    print(f"[问题] {question}")
    print("=" * 80)

    # 记录开始时间
    start_time = time.time()

    # 延迟导入并初始化记忆检索prompt(这会自动加载 global_config)
    # 注意:必须在函数内部调用,避免在模块级别触发循环导入
    try:
        init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()

        # 检查 prompt 是否已经初始化,避免重复初始化
        from src.chat.utils.prompt_builder import global_prompt_manager
        if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
            init_memory_retrieval_prompt()
        else:
            logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
    except Exception as e:
        logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
        raise

    # 获取 global_config(此时应该已经加载)
    from src.config.config import global_config

    # 直接调用 _react_agent_solve_question 来获取详细的迭代信息
    if max_iterations is None:
        max_iterations = global_config.memory.max_agent_iterations

    timeout = global_config.memory.agent_timeout_seconds

    print("\n[配置]")
    print(f"  最大迭代次数: {max_iterations}")
    print(f"  超时时间: {timeout}秒")
    print(f"  聊天ID: {chat_id}")

    # 执行检索
    print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")

    found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=timeout,
        initial_info="",
    )

    # 记录结束时间
    end_time = time.time()
    elapsed_time = end_time - start_time

    # 获取token使用情况
    token_usage = get_token_usage_since(start_time)

    # 构建结果
    result = {
        "question": question,
        "found_answer": found_answer,
        "answer": answer,
        "is_timeout": is_timeout,
        "elapsed_time": elapsed_time,
        "thinking_steps": thinking_steps,
        "iteration_count": len(thinking_steps),
        "token_usage": token_usage,
    }

    # 输出结果
    print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print("\n[结果]")
    print(f"  是否找到答案: {'是' if found_answer else '否'}")
    if found_answer and answer:
        print(f"  答案: {answer}")
    else:
        print("  答案: (未找到答案)")
    print(f"  是否超时: {'是' if is_timeout else '否'}")
    print(f"  迭代次数: {len(thinking_steps)}")
    print(f"  总耗时: {elapsed_time:.2f}秒")

    print("\n[Token使用情况]")
    print(f"  总请求数: {token_usage['request_count']}")
    print(f"  总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
    print(f"  总Completion Tokens: {token_usage['total_completion_tokens']:,}")
    print(f"  总Tokens: {token_usage['total_tokens']:,}")
    print(f"  总成本: ${token_usage['total_cost']:.6f}")

    if token_usage['model_usage']:
        print("\n[按模型统计]")
        for model_name, usage in token_usage['model_usage'].items():
            print(f"  {model_name}:")
            print(f"    请求数: {usage['request_count']}")
            print(f"    Prompt Tokens: {usage['prompt_tokens']:,}")
            print(f"    Completion Tokens: {usage['completion_tokens']:,}")
            print(f"    总Tokens: {usage['total_tokens']:,}")
            print(f"    成本: ${usage['cost']:.6f}")

    print("\n[迭代详情]")
    print(format_thinking_steps(thinking_steps))

    print("\n" + "=" * 80)

    return result


def main() -> None:
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID(默认: test_memory_retrieval)",
    )
    parser.add_argument(
        "--context",
        default="",
        help="上下文信息(可选)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件(可选)",
    )

    args = parser.parse_args()

    # 初始化日志(使用较低的详细程度,避免输出过多日志)
    initialize_logging(verbose=False)

    # 交互式输入问题
    print("\n" + "=" * 80)
    print("记忆检索测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    # 交互式输入最大迭代次数
    max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0,将使用配置默认值")
                max_iterations = None
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
            max_iterations = None

    # 连接数据库
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    # 运行测试
    try:
        result = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )

        # 如果指定了输出文件,保存结果
        if args.output:
            # 将thinking_steps转换为可序列化的格式
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")

    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()

@@ -0,0 +1,252 @@
"""
表达方式自动检查定时任务

功能:
1. 定期随机选取指定数量的表达方式
2. 使用LLM进行评估
3. 通过评估的:rejected=0, checked=1
4. 未通过评估的:rejected=1, checked=1
"""

import asyncio
import json
import random
from typing import List

from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask

logger = get_logger("expression_auto_check_task")


def create_evaluation_prompt(situation: str, style: str) -> str:
    """
    创建评估提示词

    Args:
        situation: 情境
        style: 风格

    Returns:
        评估提示词
    """
    # 基础评估标准
    base_criteria = [
        "表达方式或言语风格 是否与使用条件或使用情景 匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]

    # 从配置中获取额外的自定义标准
    custom_criteria = global_config.expression.expression_auto_check_custom_criteria

    # 合并所有评估标准
    all_criteria = base_criteria.copy()
    if custom_criteria:
        all_criteria.extend(custom_criteria)

    # 构建评估标准列表字符串
    criteria_list = "\n".join([f"{i+1}. {criterion}" for i, criterion in enumerate(all_criteria)])

    prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}

请从以下方面进行评估:
{criteria_list}

请以JSON格式输出评估结果:
{{
    "suitable": true/false,
    "reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
请严格按照JSON格式输出,不要包含其他内容。"""

    return prompt


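The criteria assembly in `create_evaluation_prompt` — base criteria plus optional config-supplied extras, joined into a numbered list — can be sketched without the project's config object. The criterion strings here are illustrative placeholders, not the real configuration:

```python
# Illustrative stand-ins for the base and config-supplied criteria above.
base_criteria = ["匹配使用情景", "具有泛用性"]
custom_criteria = ["不包含链接"]  # would come from global_config in the real task

all_criteria = base_criteria.copy()
if custom_criteria:
    all_criteria.extend(custom_criteria)

# Numbering starts at 1, matching the prompt template above
criteria_list = "\n".join(f"{i + 1}. {c}" for i, c in enumerate(all_criteria))
```

An empty `custom_criteria` list is falsy, so the `extend` is skipped and only the base criteria are numbered.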
judge_llm = LLMRequest(
    model_set=model_config.model_task_config.tool_use,
    request_type="expression_check",
)


async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str]:
    """
    执行单次LLM评估

    Args:
        situation: 情境
        style: 风格

    Returns:
        (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024,
        )

        logger.debug(f"LLM响应: {response}")

        # 解析JSON响应
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re

            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)


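The fallback parse in `single_expression_check` — try strict `json.loads` first, then regex-extract the `{"suitable": ...}` object from a noisy response — can be exercised standalone. The sample response below is made up for illustration:

```python
import json
import re

# A hypothetical noisy LLM reply: prose wrapped around the JSON verdict.
response = '评估如下:{"suitable": false, "reason": "太特指"} 以上。'

try:
    # Strict parse fails here because of the surrounding prose
    evaluation = json.loads(response)
except json.JSONDecodeError:
    # Fall back to extracting the first flat object containing "suitable",
    # same pattern as used in single_expression_check above
    match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
    if match is None:
        raise ValueError("无法从响应中提取JSON格式的评估结果")
    evaluation = json.loads(match.group())

suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
```

Note the `[^{}]*` pattern only matches a flat object; a reply with nested braces inside the verdict would defeat this fallback.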
class ExpressionAutoCheckTask(AsyncTask):
    """表达方式自动检查定时任务"""

    def __init__(self):
        # 从配置中获取检查间隔和一次检查数量
        check_interval = global_config.expression.expression_auto_check_interval
        super().__init__(
            task_name="Expression Auto Check Task",
            wait_before_start=60,  # 启动后等待60秒再开始第一次检查
            run_interval=check_interval,
        )

    async def _select_expressions(self, count: int) -> List[Expression]:
        """
        随机选择指定数量的未检查表达方式

        Args:
            count: 需要选择的数量

        Returns:
            选中的表达方式列表
        """
        try:
            # 查询所有未检查的表达方式(checked=False)
            unevaluated_expressions = list(
                Expression.select().where(~Expression.checked)
            )

            if not unevaluated_expressions:
                logger.info("没有未检查的表达方式")
                return []

            # 随机选择指定数量
            selected_count = min(count, len(unevaluated_expressions))
            selected = random.sample(unevaluated_expressions, selected_count)

            logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条")
            return selected

        except Exception as e:
            logger.error(f"选择表达方式时出错: {e}")
            return []

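The clamped sampling in `_select_expressions` is the standard way to take "up to N" items with `random.sample`, which raises if asked for more items than exist. A minimal sketch with placeholder strings instead of `Expression` rows:

```python
import random

# Placeholder candidates standing in for unevaluated Expression rows.
unevaluated = ["expr-1", "expr-2", "expr-3"]
count = 5  # configured batch size, larger than the pool here

# Clamp before sampling; random.sample(pop, k) raises ValueError when k > len(pop)
selected_count = min(count, len(unevaluated))
selected = random.sample(unevaluated, selected_count)
```

With a pool smaller than `count`, the whole pool is returned in random order.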
    async def _evaluate_expression(self, expression: Expression) -> bool:
        """
        评估单个表达方式

        Args:
            expression: 要评估的表达方式

        Returns:
            True表示通过,False表示不通过
        """
        suitable, reason, error = await single_expression_check(
            expression.situation,
            expression.style,
        )

        # 更新数据库
        try:
            expression.checked = True
            expression.rejected = not suitable  # 通过则rejected=0,不通过则rejected=1
            expression.modified_by = 'ai'  # 标记为AI检查
            expression.save()

            status = "通过" if suitable else "不通过"
            logger.info(
                f"表达方式评估完成 [ID: {expression.id}] - {status} | "
                f"Situation: {expression.situation}... | "
                f"Style: {expression.style}... | "
                f"Reason: {reason[:50]}..."
            )

            if error:
                logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")

            return suitable

        except Exception as e:
            logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}")
            return False

    async def run(self):
        """执行检查任务"""
        try:
            # 检查是否启用自动检查
            if not global_config.expression.expression_self_reflect:
                logger.debug("表达方式自动检查未启用,跳过本次执行")
                return

            check_count = global_config.expression.expression_auto_check_count
            if check_count <= 0:
                logger.warning(f"检查数量配置无效: {check_count},跳过本次执行")
                return

            logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")

            # 选择要检查的表达方式
            expressions = await self._select_expressions(check_count)

            if not expressions:
                logger.info("没有需要检查的表达方式")
                return

            # 逐个评估
            passed_count = 0
            failed_count = 0

            for i, expression in enumerate(expressions, 1):
                logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")

                if await self._evaluate_expression(expression):
                    passed_count += 1
                else:
                    failed_count += 1

                # 避免请求过快
                await asyncio.sleep(0.3)

            logger.info(
                f"表达方式自动检查完成: 总计 {len(expressions)} 条,"
                f"通过 {passed_count} 条,不通过 {failed_count} 条"
            )

        except Exception as e:
            logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)

@@ -18,10 +18,13 @@ from src.bw_learner.learner_utils import (
    is_bot_message,
    build_context_paragraph,
    contains_bot_self_name,
    calculate_style_similarity,
    calculate_similarity,
    parse_expression_response,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
from src.bw_learner.expression_auto_check_task import (
    single_expression_check,
)


# MAX_EXPRESSION_COUNT = 300
@@ -89,8 +92,9 @@ class ExpressionLearner:
            model_set=model_config.model_task_config.utils, request_type="expression.learner"
        )
        self.summary_model: LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
            model_set=model_config.model_task_config.tool_use, request_type="expression.summary"
        )
        self.check_model: Optional[LLMRequest] = None  # 检查用的 LLM 实例,延迟初始化
        self.chat_id = chat_id
        self.chat_stream = get_chat_manager().get_stream(chat_id)
        self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
@@ -136,11 +140,21 @@ class ExpressionLearner:
        # 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
        expressions: List[Tuple[str, str, str]]
        jargon_entries: List[Tuple[str, str]]  # (content, source_id)
        expressions, jargon_entries = self.parse_expression_response(response)
        expressions = self._filter_self_reference_styles(expressions)
        expressions, jargon_entries = parse_expression_response(response)

        # 从缓存中检查 jargon 是否出现在 messages 中
        cached_jargon_entries = self._check_cached_jargons_in_messages(random_msg)
        if cached_jargon_entries:
            # 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
            existing_contents = {content for content, _ in jargon_entries}
            for content, source_id in cached_jargon_entries:
                if content not in existing_contents:
                    jargon_entries.append((content, source_id))
                    existing_contents.add(content)
                    logger.info(f"从缓存中检查到黑话: {content}")

        # 检查表达方式数量,如果超过10个则放弃本次表达学习
        if len(expressions) > 10:
        if len(expressions) > 20:
            logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
            expressions = []

@@ -155,7 +169,7 @@ class ExpressionLearner:

        # 如果没有表达方式,直接返回
        if not expressions:
            logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
            logger.info("解析后没有可用的表达方式")
            return []

        logger.info(f"学习的prompt: {prompt}")
@@ -163,9 +177,60 @@ class ExpressionLearner:
        logger.info(f"学习的jargon_entries: {jargon_entries}")
        logger.info(f"学习的response: {response}")

        # 直接根据 source_id 在 random_msg 中溯源,获取 context
        # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
        learnt_expressions = self._filter_expressions(expressions, random_msg)

        if learnt_expressions is None:
            logger.info("没有学习到表达风格")
            return []

        # 展示学到的表达方式
        learnt_expressions_str = ""
        for (situation, style) in learnt_expressions:
            learnt_expressions_str += f"{situation}->{style}\n"
        logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")

        current_time = time.time()

        # 存储到数据库 Expression 表
        for (situation, style) in learnt_expressions:
            await self._upsert_expression_record(
                situation=situation,
                style=style,
                current_time=current_time,
            )

        return learnt_expressions

    def _filter_expressions(
        self,
        expressions: List[Tuple[str, str, str]],
        messages: List[Any],
    ) -> List[Tuple[str, str, str]]:
        """
        过滤表达方式,移除不符合条件的条目

        Args:
            expressions: 表达方式列表,每个元素是 (situation, style, source_id)
            messages: 原始消息列表,用于溯源和验证

        Returns:
            过滤后的表达方式列表,每个元素是 (situation, style, context)
        """
        filtered_expressions: List[Tuple[str, str, str]] = []  # (situation, style, context)

        # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
        banned_names = set()
        bot_nickname = (global_config.bot.nickname or "").strip()
        if bot_nickname:
            banned_names.add(bot_nickname)
        alias_names = global_config.bot.alias_names or []
        for alias in alias_names:
            alias = alias.strip()
            if alias:
                banned_names.add(alias)
        banned_casefold = {name.casefold() for name in banned_names if name}

        for situation, style, source_id in expressions:
            source_id_str = (source_id or "").strip()
            if not source_id_str.isdigit():
@@ -173,12 +238,12 @@ class ExpressionLearner:
                continue

            line_index = int(source_id_str) - 1  # build_anonymous_messages 的编号从 1 开始
            if line_index < 0 or line_index >= len(random_msg):
            if line_index < 0 or line_index >= len(messages):
                # 超出范围,跳过
                continue

            # 当前行的原始内容
            current_msg = random_msg[line_index]
            current_msg = messages[line_index]

            # 过滤掉从bot自己发言中提取到的表达方式
            if is_bot_message(current_msg):
@@ -195,251 +260,53 @@ class ExpressionLearner:
            )
            continue

        filtered_expressions.append((situation, style, context))

        learnt_expressions = filtered_expressions

        if learnt_expressions is None:
            logger.info("没有学习到表达风格")
            return []

        # 展示学到的表达方式
        learnt_expressions_str = ""
        for (
            situation,
            style,
            _context,
        ) in learnt_expressions:
            learnt_expressions_str += f"{situation}->{style}\n"
        logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")

        current_time = time.time()

        # 存储到数据库 Expression 表
        for (
            situation,
            style,
            context,
        ) in learnt_expressions:
            await self._upsert_expression_record(
                situation=situation,
                style=style,
                context=context,
                current_time=current_time,
            )

        return learnt_expressions

    def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
        """
        解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。

        期望的 JSON 结构:
        [
            {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  // 表达方式
            {"content": "词条", "source_id": "12"},  // 黑话
            ...
        ]

        Returns:
            Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
                第一个列表是表达方式 (situation, style, source_id)
                第二个列表是黑话 (content, source_id)
        """
        if not response:
            return [], []

        raw = response.strip()

        # 尝试提取 ```json 代码块
        json_block_pattern = r"```json\s*(.*?)\s*```"
        match = re.search(json_block_pattern, raw, re.DOTALL)
        if match:
            raw = match.group(1).strip()
        else:
            # 去掉可能存在的通用 ``` 包裹
            raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
            raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
            raw = raw.strip()

        parsed = None
        expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
        jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)

        try:
            # 优先尝试直接解析
            if raw.startswith("[") and raw.endswith("]"):
                parsed = json.loads(raw)
            else:
                repaired = repair_json(raw)
                if isinstance(repaired, str):
                    parsed = json.loads(repaired)
                else:
                    parsed = repaired
        except Exception as parse_error:
            # 如果解析失败,尝试修复中文引号问题
            # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
            try:

                def fix_chinese_quotes_in_json(text):
                    """使用状态机修复 JSON 字符串值中的中文引号"""
                    result = []
                    i = 0
                    in_string = False
                    escape_next = False

                    while i < len(text):
                        char = text[i]

                        if escape_next:
                            # 当前字符是转义字符后的字符,直接添加
                            result.append(char)
                            escape_next = False
                            i += 1
                            continue

                        if char == "\\":
                            # 转义字符
                            result.append(char)
                            escape_next = True
                            i += 1
                            continue

                        if char == '"' and not escape_next:
                            # 遇到英文引号,切换字符串状态
                            in_string = not in_string
                            result.append(char)
                            i += 1
                            continue

                        if in_string:
                            # 在字符串值内部,将中文引号替换为转义的英文引号
                            if char == '“':  # 中文左引号 U+201C
                                result.append('\\"')
                            elif char == '”':  # 中文右引号 U+201D
                                result.append('\\"')
                            else:
                                result.append(char)
                        else:
                            # 不在字符串内,直接添加
                            result.append(char)

                        i += 1

                    return "".join(result)

                fixed_raw = fix_chinese_quotes_in_json(raw)

                # 再次尝试解析
                if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
                    parsed = json.loads(fixed_raw)
                else:
                    repaired = repair_json(fixed_raw)
                    if isinstance(repaired, str):
                        parsed = json.loads(repaired)
                    else:
                        parsed = repaired
            except Exception as fix_error:
                logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
                logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
                logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
                logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
                return [], []

        if isinstance(parsed, dict):
            parsed_list = [parsed]
        elif isinstance(parsed, list):
            parsed_list = parsed
        else:
            logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
            return [], []

        for item in parsed_list:
            if not isinstance(item, dict):
            # 过滤掉 style 与机器人名称/昵称重复的表达
            normalized_style = (style or "").strip()
            if normalized_style and normalized_style.casefold() in banned_casefold:
                logger.debug(
                    f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}"
                )
                continue

            # 检查是否是表达方式条目(有 situation 和 style)
            situation = str(item.get("situation", "")).strip()
            style = str(item.get("style", "")).strip()
            source_id = str(item.get("source_id", "")).strip()
            # 过滤掉包含 "表情:" 或 "表情:" 的内容
            if "表情:" in (situation or "") or "表情:" in (situation or "") or \
               "表情:" in (style or "") or "表情:" in (style or "") or \
               "表情:" in context or "表情:" in context:
                logger.info(
                    f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
                )
                continue

            if situation and style and source_id:
                # 表达方式条目
                expressions.append((situation, style, source_id))
            elif item.get("content"):
                # 黑话条目(有 content 字段)
                content = str(item.get("content", "")).strip()
                source_id = str(item.get("source_id", "")).strip()
                if content and source_id:
                    jargon_entries.append((content, source_id))
            # 过滤掉包含 "[图片" 的内容
            if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context:
                logger.info(
                    f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
                )
                continue

        return expressions, jargon_entries
            filtered_expressions.append((situation, style))

    def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
        """
        过滤掉style与机器人名称/昵称重复的表达
        """
        banned_names = set()
        bot_nickname = (global_config.bot.nickname or "").strip()
        if bot_nickname:
            banned_names.add(bot_nickname)

        alias_names = global_config.bot.alias_names or []
        for alias in alias_names:
            alias = alias.strip()
            if alias:
                banned_names.add(alias)

        banned_casefold = {name.casefold() for name in banned_names if name}

        filtered: List[Tuple[str, str, str]] = []
        removed_count = 0
        for situation, style, source_id in expressions:
            normalized_style = (style or "").strip()
            if normalized_style and normalized_style.casefold() not in banned_casefold:
                filtered.append((situation, style, source_id))
            else:
                removed_count += 1

        if removed_count:
            logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")

        return filtered
        return filtered_expressions

    async def _upsert_expression_record(
        self,
        situation: str,
        style: str,
        context: str,
        current_time: float,
    ) -> None:
        # 第一层:检查是否有完全一致的 style(检查 style 字段和 style_list)
        expr_obj = await self._find_exact_style_match(style)
        # 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list)
        # 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理
        expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75)

        if expr_obj:
            # 找到完全匹配的 style,合并到现有记录(不使用 LLM 总结)
            # 根据相似度决定是否使用 LLM 总结
            # 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
            use_llm_summary = similarity < 1.0
            await self._update_existing_expression(
                expr_obj=expr_obj,
                situation=situation,
                style=style,
                context=context,
                current_time=current_time,
                use_llm_summary=False,
            )
            return

        # 第二层:检查是否有相似的 style(相似度 >= 0.75,检查 style 字段和 style_list)
        similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)

        if similar_expr_obj:
            # 找到相似的 style,合并到现有记录(使用 LLM 总结)
            await self._update_existing_expression(
                expr_obj=similar_expr_obj,
                situation=situation,
                style=style,
                context=context,
                current_time=current_time,
                use_llm_summary=True,
                use_llm_summary=use_llm_summary,
            )
            return

@@ -447,7 +314,6 @@ class ExpressionLearner:
            await self._create_expression_record(
                situation=situation,
                style=style,
                context=context,
                current_time=current_time,
            )

@@ -455,7 +321,6 @@ class ExpressionLearner:
        self,
        situation: str,
        style: str,
        context: str,
        current_time: float,
    ) -> None:
        content_list = [situation]
@@ -466,26 +331,22 @@ class ExpressionLearner:
            situation=formatted_situation,
            style=style,
            content_list=json.dumps(content_list, ensure_ascii=False),
            style_list=None,  # 新记录初始时 style_list 为空
            count=1,
            last_active_time=current_time,
            chat_id=self.chat_id,
            create_date=current_time,
            context=context,
        )

    async def _update_existing_expression(
        self,
        expr_obj: Expression,
        situation: str,
        style: str,
        context: str,
        current_time: float,
        use_llm_summary: bool = True,
    ) -> None:
        """
        更新现有 Expression 记录(style 完全匹配或相似的情况)
        将新的 situation 添加到 content_list,将新的 style 添加到 style_list(如果不同)
        更新现有 Expression 记录(situation 完全匹配或相似的情况)
        将新的 situation 添加到 content_list,不合并 style

        Args:
            use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True
@@ -495,43 +356,24 @@ class ExpressionLearner:
        content_list.append(situation)
        expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)

        # Update style_list (append the style if it differs)
        style_list = self._parse_style_list(expr_obj.style_list)
        # Also add the record's original style to style_list (if not already present)
        if expr_obj.style and expr_obj.style not in style_list:
            style_list.append(expr_obj.style)
        # Append the new style if it is not yet in style_list
        if style not in style_list:
            style_list.append(style)
        expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)

        # Update the remaining fields
        expr_obj.count = (expr_obj.count or 0) + 1
        expr_obj.checked = False  # reset checked to False whenever count increases
        expr_obj.last_active_time = current_time
        expr_obj.context = context

        if use_llm_summary:
            # For similar matches, recompose the situation and style with the LLM
            # For similar matches, recompose the situation with the LLM
            new_situation = await self._compose_situation_text(
                content_list=content_list,
                count=expr_obj.count,
                fallback=expr_obj.situation,
            )
            expr_obj.situation = new_situation

            new_style = await self._compose_style_text(
                style_list=style_list,
                count=expr_obj.count,
                fallback=expr_obj.style or style,
            )
            expr_obj.style = new_style
        else:
            # For exact matches, skip LLM summarization and keep the original situation and style
            # Only content_list and style_list are updated
            pass

        expr_obj.save()

        # Run a check immediately after count increases
        await self._check_expression_immediately(expr_obj)

    def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
        if not stored_list:
            return []
@@ -541,49 +383,19 @@ class ExpressionLearner:
            return []
        return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []

    def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
        """Parse the style_list JSON string into a list; same logic as _parse_content_list"""
        if not stored_list:
            return []
        try:
            data = json.loads(stored_list)
        except json.JSONDecodeError:
            return []
        return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []

    async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
    async def _find_similar_situation_expression(self, situation: str, similarity_threshold: float = 0.75) -> Tuple[Optional[Expression], float]:
        """
        Find an Expression record with an exactly matching style
        Only checks each item in style_list (not the style field, which may be a summarized description)
        Find an Expression record with a similar situation
        Checks each item in content_list

        Args:
            style: the style to look up

        Returns:
            The matching Expression object, or None if not found
        """
        # Query all records for the same chat_id
        all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)

        for expr in all_expressions:
            # Only check each item in style_list
            style_list = self._parse_style_list(expr.style_list)
            if style in style_list:
                return expr

        return None

    async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
        """
        Find an Expression record with a similar style
        Only checks each item in style_list (not the style field, which may be a summarized description)

        Args:
            style: the style to look up
            situation: the situation to look up
            similarity_threshold: similarity threshold, default 0.75

        Returns:
            The most similar Expression object, or None if not found
            Tuple[Optional[Expression], float]:
                - the most similar Expression object, or None if not found
                - the similarity value (between similarity_threshold and 1.0 when a match is found)
        """
        # Query all records for the same chat_id
        all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
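The unified situation lookup above (threshold 0.75, with similarity 1.0 treated as an exact match that skips LLM summarization) reduces to a best-match scan over candidate strings. A minimal standalone sketch using difflib directly; the repository's `calculate_similarity` helper may normalize its inputs differently, and `find_best_match` is an illustrative name:

```python
import difflib
from typing import List, Optional, Tuple

def find_best_match(target: str, candidates: List[str], threshold: float = 0.75) -> Tuple[Optional[str], float]:
    """Return the candidate most similar to target, or (None, 0.0) if all fall below the threshold."""
    best_match: Optional[str] = None
    best_similarity = 0.0
    for candidate in candidates:
        # SequenceMatcher.ratio() is in [0.0, 1.0]; 1.0 means an exact match
        similarity = difflib.SequenceMatcher(None, target, candidate).ratio()
        if similarity >= threshold and similarity > best_similarity:
            best_match, best_similarity = candidate, similarity
    return best_match, best_similarity

match, score = find_best_match("hello world", ["hello world", "goodbye"])
# A caller can then mirror the logic above: use_llm_summary = score < 1.0
```
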
@@ -592,96 +404,28 @@ class ExpressionLearner:
        best_similarity = 0.0

        for expr in all_expressions:
            # Only check each item in style_list
            style_list = self._parse_style_list(expr.style_list)
            for existing_style in style_list:
                similarity = calculate_style_similarity(style, existing_style)
            # Check each item in content_list
            content_list = self._parse_content_list(expr.content_list)
            for existing_situation in content_list:
                similarity = calculate_similarity(situation, existing_situation)
                if similarity >= similarity_threshold and similarity > best_similarity:
                    best_similarity = similarity
                    best_match = expr

        if best_match:
            logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
            logger.debug(f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'")

        return best_match
        return best_match, best_similarity

    async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
    async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str:
        sanitized = [c.strip() for c in content_list if c.strip()]
        summary = await self._summarize_situations(sanitized)
        if summary:
            return summary
        return "/".join(sanitized) if sanitized else fallback

    async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
        """
        Compose the style text; if style_list has more than one element, try to summarize
        """
        sanitized = [s.strip() for s in style_list if s.strip()]
        if len(sanitized) > 1:
            # Only try to summarize when there are multiple styles
            summary = await self._summarize_styles(sanitized)
            if summary:
                return summary
        # With a single style, or if summarization fails, return the first item or the fallback
        return sanitized[0] if sanitized else fallback

    async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
        """Summarize multiple styles into one generalized style description"""
        if not styles or len(styles) <= 1:
            return None

        # Length of the longest input item
        max_input_length = max(len(s) for s in styles) if styles else 0
        max_summary_length = max_input_length * 2

        # Retry at most 3 times
        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            # On retries, stress brevity in the prompt
            length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字,"

            prompt = (
                "请阅读以下多个语言风格/表达方式,对其进行总结。"
                "不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
                f"{length_hint}保留共同特点:\n"
                f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
            )

            try:
                summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
                summary = summary.strip()
                if summary:
                    # Check whether the summary exceeds the length limit
                    if len(summary) <= max_summary_length:
                        return summary
                    else:
                        retry_count += 1
                        logger.debug(
                            f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
                            f"(输入最长项长度: {max_input_length}),重试第 {retry_count} 次"
                        )
                        continue
            except Exception as e:
                logger.error(f"概括表达风格失败: {e}")
                return None

        # If still over the limit after several retries, return None (skip summarization)
        logger.warning(
            f"总结多次后仍超过长度限制,放弃总结。"
            f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
        )
        return None

    async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
        if not situations:
            return None
        if not sanitized:
            return fallback

        prompt = (
            "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
            "长度不超过20个字,保留共同特点:\n"
            f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
            f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
        )

        try:
@@ -691,7 +435,126 @@ class ExpressionLearner:
            return summary
        except Exception as e:
            logger.error(f"概括表达情境失败: {e}")
            return None
        return "/".join(sanitized) if sanitized else fallback

    async def _init_check_model(self) -> None:
        """Initialize the LLM instance used for checking"""
        if self.check_model is None:
            try:
                self.check_model = LLMRequest(
                    model_set=model_config.model_task_config.tool_use,
                    request_type="expression.check"
                )
                logger.debug("检查用 LLM 实例初始化成功")
            except Exception as e:
                logger.error(f"创建检查用 LLM 实例失败: {e}")

    async def _check_expression_immediately(self, expr_obj: Expression) -> None:
        """
        Check an expression immediately (called after count increases)

        Args:
            expr_obj: the expression object to check
        """
        try:
            # Check whether automatic checking is enabled
            if not global_config.expression.expression_self_reflect:
                logger.debug("表达方式自动检查未启用,跳过立即检查")
                return

            # Initialize the checking LLM
            await self._init_check_model()
            if self.check_model is None:
                logger.warning("检查用 LLM 实例初始化失败,跳过立即检查")
                return

            # Run the LLM evaluation
            suitable, reason, error = await single_expression_check(
                expr_obj.situation,
                expr_obj.style
            )

            # Update the database
            expr_obj.checked = True
            expr_obj.rejected = not suitable  # rejected=False when it passes, True when it fails
            expr_obj.save()

            status = "通过" if suitable else "不通过"
            logger.info(
                f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | "
                f"Situation: {expr_obj.situation[:30]}... | "
                f"Style: {expr_obj.style[:30]}... | "
                f"Reason: {reason[:50] if reason else '无'}..."
            )

            if error:
                logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}")

        except Exception as e:
            logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True)
            # On failure, keep checked=False so the periodic check task handles it later

    def _check_cached_jargons_in_messages(self, messages: List[Any]) -> List[Tuple[str, str]]:
        """
        Check whether cached jargon terms appear in the messages

        Args:
            messages: the message list

        Returns:
            List[Tuple[str, str]]: matched jargon entries, each element is (content, source_id)
        """
        if not messages:
            return []

        # Get the jargon_miner instance
        jargon_miner = miner_manager.get_miner(self.chat_id)

        # Get all jargon terms from the cache
        cached_jargons = jargon_miner.get_cached_jargons()
        if not cached_jargons:
            return []

        matched_entries: List[Tuple[str, str]] = []

        # Walk the messages and check whether any cached jargon appears
        for i, msg in enumerate(messages):
            # Skip the bot's own messages
            if is_bot_message(msg):
                continue

            # Get the message text
            msg_text = (
                getattr(msg, "processed_plain_text", None) or
                ""
            ).strip()

            if not msg_text:
                continue

            # Check whether each cached jargon appears in the message text
            for jargon in cached_jargons:
                if not jargon or not jargon.strip():
                    continue

                jargon_content = jargon.strip()

                # Regex match with word boundaries in mind (similar to the logic in jargon_explainer)
                pattern = re.escape(jargon_content)
                # For Chinese, use a looser match; for English/digits, use word boundaries
                if re.search(r"[\u4e00-\u9fff]", jargon_content):
                    # Contains Chinese: use the looser match
                    search_pattern = pattern
                else:
                    # Pure English/digits: use word boundaries
                    search_pattern = r"\b" + pattern + r"\b"

                if re.search(search_pattern, msg_text, re.IGNORECASE):
                    # Match found; build the entry (source_id starts at 1, matching build_anonymous_messages numbering)
                    source_id = str(i + 1)
                    matched_entries.append((jargon_content, source_id))

        return matched_entries
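The CJK-versus-ASCII boundary heuristic used above can be isolated into a small predicate: `\b` word boundaries are meaningless inside CJK text, so Chinese terms fall back to a plain escaped-substring search. A sketch with a hypothetical `jargon_matches` helper name:

```python
import re

def jargon_matches(jargon: str, text: str) -> bool:
    """Match a jargon term in text: substring search for CJK terms, word boundaries for ASCII ones."""
    pattern = re.escape(jargon.strip())
    if re.search(r"[\u4e00-\u9fff]", jargon):
        # CJK has no word boundaries, so use the plain (escaped) pattern
        search_pattern = pattern
    else:
        # Pure ASCII terms use \b so e.g. "nb" does not match inside "unbox"
        search_pattern = r"\b" + pattern + r"\b"
    return re.search(search_pattern, text, re.IGNORECASE) is not None
```

`re.escape` keeps terms containing regex metacharacters (such as "C++") from corrupting the pattern.
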

    async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
        """

@@ -28,11 +28,11 @@ class ExpressionReflector:
        try:
            logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")

            if not global_config.expression.reflect:
            if not global_config.expression.expression_self_reflect:
                logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
                return False

            operator_config = global_config.expression.reflect_operator_id
            operator_config = global_config.expression.manual_reflect_operator_id
            if not operator_config:
                logger.debug("[Expression Reflection] Operator ID 未配置,跳过")
                return False
@@ -45,7 +45,7 @@ def init_prompt():
class ExpressionSelector:
    def __init__(self):
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
            model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
@@ -123,9 +123,11 @@ class ExpressionSelector:
        related_chat_ids = self.get_related_chat_ids(chat_id)

        # Query expressions for all related chat_ids, excluding rejected=1 and keeping only count > 1
        style_query = Expression.select().where(
            (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
        )
        # If expression_checked_only is True, select only checked=True and rejected=False
        base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
        if global_config.expression.expression_checked_only:
            base_conditions = base_conditions & (Expression.checked)
        style_query = Expression.select().where(base_conditions)

        style_exprs = [
            {
@@ -202,7 +204,11 @@ class ExpressionSelector:
        related_chat_ids = self.get_related_chat_ids(chat_id)

        # Optimization: query expressions for all related chat_ids at once, excluding rejected=1
        style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
        # If expression_checked_only is True, select only checked=True and rejected=False
        base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
        if global_config.expression.expression_checked_only:
            base_conditions = base_conditions & (Expression.checked)
        style_query = Expression.select().where(base_conditions)

        style_exprs = [
            {
@@ -295,7 +301,11 @@ class ExpressionSelector:
        # think_level == 1: pick high-count expressions first, then randomly sample from all expressions
        # 1. Fetch all expressions and split them into count > 1 and count <= 1
        related_chat_ids = self.get_related_chat_ids(chat_id)
        style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
        # If expression_checked_only is True, select only checked=True and rejected=False
        base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
        if global_config.expression.expression_checked_only:
            base_conditions = base_conditions & (Expression.checked)
        style_query = Expression.select().where(base_conditions)

        all_style_exprs = [
            {
@@ -407,8 +417,8 @@ class ExpressionSelector:
        # 4. Call the LLM
        content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

        print(prompt)
        print(content)
        # print(prompt)
        # print(content)

        if not content:
            logger.warning("LLM返回空结果")
@@ -45,7 +45,7 @@ class JargonExplainer:
    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.llm = LLMRequest(
            model_set=model_config.model_task_config.utils,
            model_set=model_config.model_task_config.tool_use,
            request_type="jargon.explain",
        )
@@ -341,7 +341,7 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
                meaning = result.get("meaning", "").strip()
                if found_content and meaning:
                    output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
            results.append(",".join(output_parts))
            results.append("\n".join(output_parts))  # separate each jargon explanation with a newline
            logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
        else:
            # Exact match
@@ -350,7 +350,8 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
                meaning = result.get("meaning", "").strip()
                if meaning:
                    output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
            results.append(";".join(output_parts) if len(output_parts) > 1 else output_parts[0])
            # Separate each jargon explanation with a newline
            results.append("\n".join(output_parts) if len(output_parts) > 1 else output_parts[0])
            exact_matches.append(concept)  # collect exact-match concepts and log them together later
        else:
            # Not found: log only, do not return placeholder text
@@ -361,5 +362,5 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
        logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")

    if results:
        return "【概念检索结果】\n" + "\n".join(results) + "\n"
        return "你了解以下词语可能的含义:\n" + "\n".join(results) + "\n"
    return ""
@@ -2,7 +2,7 @@ import json
import asyncio
import random
from collections import OrderedDict
from typing import List, Dict, Optional, Any, Callable
from typing import List, Dict, Optional, Callable
from json_repair import repair_json
from peewee import fn
@@ -11,14 +11,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
    build_readable_messages_with_id,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
    is_bot_message,
    build_context_paragraph,
    contains_bot_self_name,
    parse_chat_id_list,
    chat_id_list_contains,
    update_chat_id_list,
@@ -51,32 +45,32 @@ def _is_single_char_jargon(content: str) -> bool:
    )


def _init_prompt() -> None:
    prompt_str = """
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
{chat_str}
# def _init_prompt() -> None:
#     prompt_str = """
# **聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
# {chat_str}

请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
- 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
# 请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
# - 必须为对话中真实出现过的短词或短语
# - 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
# - 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
# - 每个词条长度建议 2-8 个字符(不强制),尽量短小

黑话必须为以下几种类型:
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
- 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
# 黑话必须为以下几种类型:
# - 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
# - 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
# - 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷

以 JSON 数组输出,元素为对象(严格按以下结构):
请你提取出可能的黑话,最多30个黑话,请尽量提取所有
[
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
{{"content": "词条2", "msg_id": "m15"}}
]
# 以 JSON 数组输出,元素为对象(严格按以下结构):
# 请你提取出可能的黑话,最多30个黑话,请尽量提取所有
# [
# {{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
# {{"content": "词条2", "msg_id": "m15"}}
# ]

现在请输出:
"""
    Prompt(prompt_str, "extract_jargon_prompt")
# 现在请输出:
# """
#     Prompt(prompt_str, "extract_jargon_prompt")


def _init_inference_prompts() -> None:
@@ -142,7 +136,6 @@ def _init_inference_prompts():
    Prompt(prompt3_str, "jargon_compare_inference_prompt")


_init_prompt()
_init_inference_prompts()

@@ -229,34 +222,9 @@ class JargonMiner:
        if len(self.cache) > self.cache_limit:
            self.cache.popitem(last=False)

    def _collect_cached_entries(self, messages: List[Any]) -> List[Dict[str, List[str]]]:
        """Check whether cached jargon appears in the current message window and build the matching context"""
        if not self.cache or not messages:
            return []

        cached_entries: List[Dict[str, List[str]]] = []
        processed_pairs = set()

        for idx, msg in enumerate(messages):
            msg_text = (
                getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
            ).strip()
            if not msg_text or is_bot_message(msg):
                continue

            for content in self.cache.keys():
                if not content:
                    continue
                if (content, idx) in processed_pairs:
                    continue
                if content in msg_text:
                    paragraph = build_context_paragraph(messages, idx)
                    if not paragraph:
                        continue
                    cached_entries.append({"content": content, "raw_content": [paragraph]})
                    processed_pairs.add((content, idx))

        return cached_entries
    def get_cached_jargons(self) -> List[str]:
        """Return all jargon terms currently in the cache"""
        return list(self.cache.keys())
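The cache used here is an `OrderedDict` trimmed with `popitem(last=False)`, i.e. a size-bounded structure that evicts the oldest entry once the limit is exceeded. A minimal standalone sketch; the `BoundedCache` class name is illustrative and not the repository's exact `_add_to_cache` implementation:

```python
from collections import OrderedDict

class BoundedCache:
    """Size-bounded cache: once the limit is exceeded, the oldest entry is evicted."""
    def __init__(self, limit: int = 3):
        self.cache = OrderedDict()
        self.cache_limit = limit

    def add(self, content: str) -> None:
        if content in self.cache:
            return  # already cached; insertion order (and thus eviction order) is unchanged
        self.cache[content] = True
        if len(self.cache) > self.cache_limit:
            self.cache.popitem(last=False)  # evict the oldest entry

    def keys(self):
        return list(self.cache.keys())
```

`popitem(last=False)` pops from the front of the insertion order, which is what keeps the cache a fixed-size sliding window over recently seen terms.
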

    async def _infer_meaning_by_id(self, jargon_id: int) -> None:
        """Load the object by ID and run inference"""
@@ -480,263 +448,6 @@ class JargonMiner:

            traceback.print_exc()

    async def run_once(
        self,
        messages: List[Any],
        person_name_filter: Optional[Callable[[str], bool]] = None
    ) -> None:
        """
        Run one round of jargon extraction

        Args:
            messages: externally supplied message list (required)
            person_name_filter: optional filter that checks whether content contains a person name
        """
        # Use an async lock to prevent concurrent runs
        async with self._extraction_lock:
            try:
                if not messages:
                    return

                # Sort by time so numbering stays consistent with the context
                messages = sorted(messages, key=lambda msg: msg.time or 0)

                chat_str, message_id_list = build_readable_messages_with_id(
                    messages=messages,
                    replace_bot_name=True,
                    timestamp_mode="relative",
                    truncate=False,
                    show_actions=False,
                    show_pic=True,
                    pic_single=True,
                )
                if not chat_str.strip():
                    return

                msg_id_to_index: Dict[str, int] = {}
                for idx, (msg_id, _msg) in enumerate(message_id_list or []):
                    if not msg_id:
                        continue
                    msg_id_to_index[msg_id] = idx
                if not msg_id_to_index:
                    logger.warning("未能生成消息ID映射,跳过本次提取")
                    return

                prompt: str = await global_prompt_manager.format_prompt(
                    "extract_jargon_prompt",
                    bot_name=global_config.bot.nickname,
                    chat_str=chat_str,
                )

                response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
                if not response:
                    return

                if global_config.debug.show_jargon_prompt:
                    logger.info(f"jargon提取提示词: {prompt}")
                    logger.info(f"jargon提取结果: {response}")

                # Parse into JSON
                entries: List[dict] = []
                try:
                    resp = response.strip()
                    parsed = None
                    if resp.startswith("[") and resp.endswith("]"):
                        parsed = json.loads(resp)
                    else:
                        repaired = repair_json(resp)
                        if isinstance(repaired, str):
                            parsed = json.loads(repaired)
                        else:
                            parsed = repaired

                    if isinstance(parsed, dict):
                        parsed = [parsed]

                    if not isinstance(parsed, list):
                        return

                    for item in parsed:
                        if not isinstance(item, dict):
                            continue

                        content = str(item.get("content", "")).strip()
                        msg_id_value = item.get("msg_id")

                        if not content:
                            continue

                        if contains_bot_self_name(content):
                            logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
                            continue

                        # Check whether the content contains a person name
                        if person_name_filter and person_name_filter(content):
                            logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
                            continue

                        msg_id_str = str(msg_id_value or "").strip()
                        if not msg_id_str:
                            logger.warning(f"解析jargon失败:msg_id缺失,content={content}")
                            continue

                        msg_index = msg_id_to_index.get(msg_id_str)
                        if msg_index is None:
                            logger.warning(f"解析jargon失败:msg_id未找到,content={content}, msg_id={msg_id_str}")
                            continue

                        target_msg = messages[msg_index]
                        if is_bot_message(target_msg):
                            logger.info(f"解析阶段跳过引用机器人自身消息的词条: content={content}, msg_id={msg_id_str}")
                            continue

                        context_paragraph = build_context_paragraph(messages, msg_index)
                        if not context_paragraph:
                            logger.warning(f"解析jargon失败:上下文为空,content={content}, msg_id={msg_id_str}")
                            continue

                        entries.append({"content": content, "raw_content": [context_paragraph]})
                    cached_entries = self._collect_cached_entries(messages)
                    if cached_entries:
                        entries.extend(cached_entries)
                except Exception as e:
                    logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
                    return

                if not entries:
                    return

                # Deduplicate and merge raw_content (aggregated by content)
                merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
                for entry in entries:
                    content_key = entry["content"]
                    raw_list = entry.get("raw_content", []) or []
                    if content_key in merged_entries:
                        merged_entries[content_key]["raw_content"].extend(raw_list)
                    else:
                        merged_entries[content_key] = {
                            "content": content_key,
                            "raw_content": list(raw_list),
                        }

                uniq_entries = []
                for merged_entry in merged_entries.values():
                    raw_content_list = merged_entry["raw_content"]
                    if raw_content_list:
                        merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
                    uniq_entries.append(merged_entry)

                saved = 0
                updated = 0
                for entry in uniq_entries:
                    content = entry["content"]
                    raw_content_list = entry["raw_content"]  # already a list

                    try:
                        # Query all records whose content matches
                        query = Jargon.select().where(Jargon.content == content)

                        # Find a matching record
                        matched_obj = None
                        for obj in query:
                            if global_config.expression.all_global_jargon:
                                # all_global on: any record with matching content qualifies
                                matched_obj = obj
                                break
                            else:
                                # all_global off: the chat_id list must contain the target chat_id
                                chat_id_list = parse_chat_id_list(obj.chat_id)
                                if chat_id_list_contains(chat_id_list, self.chat_id):
                                    matched_obj = obj
                                    break

                        if matched_obj:
                            obj = matched_obj
                            try:
                                obj.count = (obj.count or 0) + 1
                            except Exception:
                                obj.count = 1

                            # Merge the raw_content lists: read the existing list, append new values, deduplicate
                            existing_raw_content = []
                            if obj.raw_content:
                                try:
                                    existing_raw_content = (
                                        json.loads(obj.raw_content)
                                        if isinstance(obj.raw_content, str)
                                        else obj.raw_content
                                    )
                                    if not isinstance(existing_raw_content, list):
                                        existing_raw_content = [existing_raw_content] if existing_raw_content else []
                                except (json.JSONDecodeError, TypeError):
                                    existing_raw_content = [obj.raw_content] if obj.raw_content else []

                            # Merge and deduplicate
                            merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
                            obj.raw_content = json.dumps(merged_list, ensure_ascii=False)

                            # Update the chat_id list: bump the count for the current chat_id
                            chat_id_list = parse_chat_id_list(obj.chat_id)
                            updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
                            obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)

                            # With all_global on, make sure the record is marked is_global=True
                            if global_config.expression.all_global_jargon:
                                obj.is_global = True
                            # With all_global off, leave the existing is_global untouched

                            obj.save()

                            # Check whether inference is needed (threshold reached and above the last judged value)
                            if _should_infer_meaning(obj):
                                # Trigger inference asynchronously without blocking the main flow
                                # Reload the object by ID to make sure the data is fresh
                                jargon_id = obj.id
                                asyncio.create_task(self._infer_meaning_by_id(jargon_id))

                            updated += 1
                        else:
                            # No matching record found; create a new one
                            if global_config.expression.all_global_jargon:
                                # all_global on: new records default to is_global=True
                                is_global_new = True
                            else:
                                # all_global off: new records get is_global=False
                                is_global_new = False

                            # Create the chat_id list in the new format: [[chat_id, count]]
                            chat_id_list = [[self.chat_id, 1]]
                            chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)

                            Jargon.create(
                                content=content,
                                raw_content=json.dumps(raw_content_list, ensure_ascii=False),
                                chat_id=chat_id_json,
                                is_global=is_global_new,
                                count=1,
                            )
                            saved += 1
                    except Exception as e:
                        logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
                        continue
                    finally:
                        self._add_to_cache(content)

                # Always log the extracted jargon in readable form (whenever anything was extracted)
                if uniq_entries:
                    # Collect every extracted jargon term
                    jargon_list = [entry["content"] for entry in uniq_entries]
                    jargon_str = ",".join(jargon_list)

                    # Log the formatted result (logger.info automatically applies the jargon module's color)
                    logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")

                if saved or updated:
                    logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
            except Exception as e:
                logger.error(f"JargonMiner 运行失败: {e}")
                # Keep the timestamp updated even on failure to avoid frequent retries
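The `[[chat_id, count]]` storage format above is maintained by the imported `parse_chat_id_list`/`update_chat_id_list` helpers, whose implementations are not part of this diff. A sketch of the update semantics as the call sites suggest them; this is an assumption about the helpers, not the repository's actual code:

```python
import json
from typing import List

def update_chat_id_list(chat_id_list: List[List], chat_id: str, increment: int = 1) -> List[List]:
    """Sketch: bump the count for chat_id, or append a new [chat_id, increment] pair."""
    for pair in chat_id_list:
        if pair[0] == chat_id:
            pair[1] += increment
            return chat_id_list
    chat_id_list.append([chat_id, increment])
    return chat_id_list

lst = [["g1", 2]]
lst = update_chat_id_list(lst, "g1")       # existing entry: count bumped
lst = update_chat_id_list(lst, "g2")       # new entry appended
serialized = json.dumps(lst, ensure_ascii=False)  # stored form, as in obj.chat_id
```

Round-tripping through `json.dumps`/`json.loads` is lossless for this nested-list format, which is why the code above can keep it in a plain text column.
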

    async def process_extracted_entries(
        self,
        entries: List[Dict[str, List[str]]],
@@ -2,8 +2,7 @@ import re
import difflib
import random
import json
from datetime import datetime
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple

from src.common.logger import get_logger
from src.config.config import global_config
@@ -11,6 +10,7 @@ from src.chat.utils.chat_message_builder import (
    build_readable_messages,
)
from src.chat.utils.utils import parse_platform_accounts
from json_repair import repair_json


logger = get_logger("learner_utils")
@@ -88,33 +88,15 @@ def calculate_style_similarity(style1: str, style2: str) -> float:
    return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()


def format_create_date(timestamp: float) -> str:
    """
    Format a timestamp into a readable date string

    Args:
        timestamp: the timestamp

    Returns:
        str: the formatted date string
    """
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError):
        return "未知时间"


def _compute_weights(population: List[Dict]) -> List[float]:
    """
    Compute weights from each expression's count, constrained to the range 1-5.
    Higher counts get higher weights, capped at 5x the base weight.
    If an expression is checked, its weight is multiplied by another 3x.
    """
    if not population:
        return []

    counts = []
    checked_flags = []
    for item in population:
        count = item.get("count", 1)
        try:
@@ -122,29 +104,19 @@ def _compute_weights(population: List[Dict]) -> List[float]:
        except (TypeError, ValueError):
            count_value = 1.0
        counts.append(max(count_value, 0.0))
        # Read the checked flag
        checked = item.get("checked", False)
        checked_flags.append(bool(checked))

    min_count = min(counts)
    max_count = max(counts)

    if max_count == min_count:
        base_weights = [1.0 for _ in counts]
        weights = [1.0 for _ in counts]
    else:
        base_weights = []
        weights = []
        for count_value in counts:
            # Map linearly into the [1, 5] range
            normalized = (count_value - min_count) / (max_count - min_count)
            base_weights.append(1.0 + normalized * 4.0)  # 1~5
            weights.append(1.0 + normalized * 4.0)  # 1~5

    # Multiply the weight by 3 if checked
    weights = []
    for base_weight, checked in zip(base_weights, checked_flags, strict=False):
        if checked:
            weights.append(base_weight * 3.0)
        else:
            weights.append(base_weight)
    return weights
|
||||
|
||||
|
||||
|
|
@ -378,3 +350,149 @@ def is_bot_message(msg: Any) -> bool:
|
|||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
result.append('\\"')
|
||||
elif char == '"': # 中文右引号 U+201D
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return [], []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
elif item.get("content"):
|
||||
# 黑话条目(有 content 字段)
|
||||
content = str(item.get("content", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
|
||||
return expressions, jargon_entries
|
||||
|
|
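The quote-repair helper in the hunk above is a character-level state machine: it tracks whether the cursor is inside a JSON string value and rewrites CJK quotation marks there as escaped ASCII quotes. A minimal standalone sketch of the same idea follows; the function name `escape_cjk_quotes` is illustrative, not part of the patch.

```python
import json


def escape_cjk_quotes(text: str) -> str:
    """Escape U+201C/U+201D quotes that appear inside JSON string values."""
    result = []
    in_string = False
    escape_next = False
    for char in text:
        if escape_next:
            # character right after a backslash is copied verbatim
            result.append(char)
            escape_next = False
            continue
        if char == "\\":
            result.append(char)
            escape_next = True
            continue
        if char == '"':
            # an unescaped ASCII quote toggles the in-string state
            in_string = not in_string
            result.append(char)
            continue
        if in_string and char in "\u201c\u201d":
            # replace the CJK quote with an escaped ASCII quote
            result.append('\\"')
        else:
            result.append(char)
    return "".join(result)


broken = '[{"content": "他说\u201c你好\u201d", "source_id": "1"}]'
fixed = escape_cjk_quotes(broken)
parsed = json.loads(fixed)
```

Quotes outside string values (e.g. the structural delimiters themselves) are left untouched, which is why the state tracking matters.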
@@ -116,20 +116,12 @@ class MessageRecorder:
                f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
            )

            # 分别触发 expression_learner 和 jargon_miner 的处理
            # 传递提取的消息,避免它们重复获取
            # 触发 expression 学习(如果启用)
            # 触发 expression_learner 和 jargon_miner 的处理
            if self.enable_expression_learning:
                asyncio.create_task(
                    self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
                    self._trigger_expression_learning(messages)
                )

            # 触发 jargon 提取(如果启用),传递消息
            # if self.enable_jargon_learning:
            #     asyncio.create_task(
            #         self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
            #     )

        except Exception as e:
            logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
            import traceback
@@ -138,7 +130,7 @@ class MessageRecorder:
            # 即使失败也保持时间戳更新,避免频繁重试

    async def _trigger_expression_learning(
        self, timestamp_start: float, timestamp_end: float, messages: List[Any]
        self, messages: List[Any]
    ) -> None:
        """
        触发 expression 学习,使用指定的消息列表
@@ -162,27 +154,6 @@ class MessageRecorder:

            traceback.print_exc()

    async def _trigger_jargon_extraction(
        self, timestamp_start: float, timestamp_end: float, messages: List[Any]
    ) -> None:
        """
        触发 jargon 提取,使用指定的消息列表

        Args:
            timestamp_start: 开始时间戳
            timestamp_end: 结束时间戳
            messages: 消息列表
        """
        try:
            # 传递消息给 JargonMiner,避免它重复获取
            await self.jargon_miner.run_once(messages=messages)

        except Exception as e:
            logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
            import traceback

            traceback.print_exc()


class MessageRecorderManager:
    """MessageRecorder 管理器"""

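The recorder schedules learning with `asyncio.create_task` and does not await it, so extraction returns immediately while learning runs in the background. A tiny sketch of that fire-and-forget pattern (names here are hypothetical stand-ins for the real trigger methods):

```python
import asyncio

results = []


async def learn(messages):
    # stand-in for _trigger_expression_learning
    results.extend(messages)


async def on_new_messages(messages):
    # scheduling returns immediately; the coroutine runs on the event loop
    return asyncio.create_task(learn(messages))


async def main():
    task = await on_new_messages(["hi", "hello"])
    # the real code drops the reference; here we await so the demo is deterministic
    await task


asyncio.run(main())
```

In production code the task's exceptions should still be surfaced somewhere (a done-callback or the wrapped try/except above), since nothing awaits the task.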
@@ -28,7 +28,7 @@ class ReflectTracker:
        self.max_duration = 15 * 60  # 15 minutes

        # LLM for judging response
        self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
        self.judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="reflect.tracker")

        self._init_prompts()

@@ -134,12 +134,14 @@ class ReflectTracker:
        if judgment == "Approve":
            self.expression.checked = True
            self.expression.rejected = False
            self.expression.modified_by = 'ai'  # 通过LLM判断也标记为ai
            self.expression.save()
            logger.info(f"Expression {self.expression.id} approved by operator.")
            return True

        elif judgment == "Reject":
            self.expression.checked = True
            self.expression.modified_by = 'ai'  # 通过LLM判断也标记为ai
            corrected_situation = json_obj.get("corrected_situation")
            corrected_style = json_obj.get("corrected_style")

@@ -11,6 +11,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
@@ -261,6 +262,7 @@ class BrainPlanner:
        """
        规划器 (Planner): 使用LLM根据上下文决定做出什么动作(ReAct模式)。
        """
        plan_start = time.perf_counter()

        # 获取聊天上下文
        message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -298,6 +300,7 @@ class BrainPlanner:

        logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")

        prompt_build_start = time.perf_counter()
        # 构建包含所有动作的提示词:使用统一的 ReAct Prompt
        prompt_key = "brain_planner_prompt_react"
        # 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
@@ -308,9 +311,10 @@ class BrainPlanner:
            message_id_list=message_id_list,
            prompt_key=prompt_key,
        )
        prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000

        # 调用LLM获取决策
        reasoning, actions = await self._execute_main_planner(
        reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
            prompt=prompt,
            message_id_list=message_id_list,
            filtered_actions=filtered_actions,
@@ -324,6 +328,25 @@ class BrainPlanner:
        )
        self.add_plan_log(reasoning, actions)

        try:
            PlanReplyLogger.log_plan(
                chat_id=self.chat_id,
                prompt=prompt,
                reasoning=reasoning,
                raw_output=llm_raw_output,
                raw_reasoning=llm_reasoning,
                actions=actions,
                timing={
                    "prompt_build_ms": round(prompt_build_ms, 2),
                    "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
                    "total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
                    "loop_start_time": loop_start_time,
                },
                extra=None,
            )
        except Exception:
            logger.exception(f"{self.log_prefix}记录plan日志失败")

        return actions

    async def build_planner_prompt(
@@ -421,7 +444,7 @@ class BrainPlanner:
            if action_info.activation_type == ActionActivationType.NEVER:
                logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
                continue
            elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
            elif action_info.activation_type == ActionActivationType.ALWAYS:
                filtered_actions[action_name] = action_info
            elif action_info.activation_type == ActionActivationType.RANDOM:
                if random.random() < action_info.random_activation_probability:
@@ -479,15 +502,20 @@ class BrainPlanner:
        filtered_actions: Dict[str, ActionInfo],
        available_actions: Dict[str, ActionInfo],
        loop_start_time: float,
    ) -> Tuple[str, List[ActionPlannerInfo]]:
    ) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
        """执行主规划器"""
        llm_content = None
        actions: List[ActionPlannerInfo] = []
        extracted_reasoning = ""
        llm_reasoning = None
        llm_duration_ms = None

        try:
            # 调用LLM
            llm_start = time.perf_counter()
            llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
            llm_duration_ms = (time.perf_counter() - llm_start) * 1000
            llm_reasoning = reasoning_content

            logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
            logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
@@ -514,7 +542,7 @@ class BrainPlanner:
                    action_message=None,
                    available_actions=available_actions,
                )
            ]
            ], llm_content, llm_reasoning, llm_duration_ms

        # 解析LLM响应
        if llm_content:
@@ -553,7 +581,7 @@ class BrainPlanner:
            f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
        )

        return extracted_reasoning, actions
        return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms

    def _create_complete_talk(
        self, reasoning: str, available_actions: Dict[str, ActionInfo]

@@ -382,7 +382,7 @@ class EmojiManager:
        self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji.see")
        self.llm_emotion_judge = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="emoji"
        )  # 更高的温度,更少的token(后续可以根据情绪来调整温度)
        )

        self.emoji_num = 0
        self.emoji_num_max = global_config.emoji.max_reg_num

@@ -30,7 +30,7 @@ from src.chat.utils.chat_message_builder import (
    get_raw_msg_before_timestamp_with_chat,
)
from src.chat.utils.utils import record_replyer_action_temp
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer

if TYPE_CHECKING:
    from src.common.data_models.database_data_model import DatabaseMessages
@@ -244,12 +244,14 @@ class HeartFChatting:
        thinking_id,
        actions,
        selected_expressions: Optional[List[int]] = None,
        quote_message: Optional[bool] = None,
    ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
        with Timer("回复发送", cycle_timers):
            reply_text = await self._send_response(
                reply_set=response_set,
                message_data=action_message,
                selected_expressions=selected_expressions,
                quote_message=quote_message,
            )

        # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
@@ -526,15 +528,26 @@ class HeartFChatting:
        reply_set: "ReplySetModel",
        message_data: "DatabaseMessages",
        selected_expressions: Optional[List[int]] = None,
        quote_message: Optional[bool] = None,
    ) -> str:
        new_message_count = message_api.count_new_messages(
            chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
        )

        need_reply = new_message_count >= random.randint(2, 3)

        if need_reply:
            logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
        # 根据 llm_quote 配置决定是否使用 quote_message 参数
        if global_config.chat.llm_quote:
            # 如果配置为 true,使用 llm_quote 参数决定是否引用回复
            if quote_message is None:
                logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
                need_reply = False
            else:
                need_reply = quote_message
                if need_reply:
                    logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
        else:
            # 如果配置为 false,使用原来的模式
            new_message_count = message_api.count_new_messages(
                chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
            )
            need_reply = new_message_count >= random.randint(2, 3)
            if need_reply:
                logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")

        reply_text = ""
        first_replied = False
@@ -640,6 +653,7 @@ class HeartFChatting:

            # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
            unknown_words = None
            quote_message = None
            if isinstance(action_planner_info.action_data, dict):
                uw = action_planner_info.action_data.get("unknown_words")
                if isinstance(uw, list):
@@ -652,6 +666,19 @@ class HeartFChatting:
                    if cleaned_uw:
                        unknown_words = cleaned_uw

                # 从 Planner 的 action_data 中提取 quote_message 参数
                qm = action_planner_info.action_data.get("quote")
                if qm is not None:
                    # 支持多种格式:true/false, "true"/"false", 1/0
                    if isinstance(qm, bool):
                        quote_message = qm
                    elif isinstance(qm, str):
                        quote_message = qm.lower() in ("true", "1", "yes")
                    elif isinstance(qm, (int, float)):
                        quote_message = bool(qm)

                    logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")

            success, llm_response = await generator_api.generate_reply(
                chat_stream=self.chat_stream,
                reply_message=action_planner_info.action_message,
@@ -682,12 +709,13 @@ class HeartFChatting:
                thinking_id=thinking_id,
                actions=chosen_action_plan_infos,
                selected_expressions=selected_expressions,
                quote_message=quote_message,
            )
            self.last_active_time = time.time()
            return {
                "action_type": "reply",
                "success": True,
                "result": f"你回复内容{reply_text}",
                "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'",
                "loop_info": loop_info,
            }

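The planner's `quote` flag above comes out of LLM-produced JSON, so the patch normalizes bool, string, and numeric spellings to a real `Optional[bool]`. The same coercion, extracted into a standalone sketch (the helper name `coerce_quote_flag` is illustrative):

```python
from typing import Any, Optional


def coerce_quote_flag(qm: Any) -> Optional[bool]:
    """Normalize an LLM-provided quote flag; None means 'not specified'."""
    if qm is None:
        return None
    if isinstance(qm, bool):
        # bool must be checked before int, since bool is an int subclass
        return qm
    if isinstance(qm, str):
        return qm.lower() in ("true", "1", "yes")
    if isinstance(qm, (int, float)):
        return bool(qm)
    return None
```

Keeping the "unspecified" case distinct from `False` lets the caller fall back to the message-count heuristic only when the model said nothing.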
@@ -7,9 +7,8 @@ from .kg_manager import KGManager

# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
from src.config.config import global_config

MAX_KNOWLEDGE_LENGTH = 10000  # 最大知识长度

@@ -22,7 +21,6 @@ class QAManager:
    ):
        self.embed_manager = embed_manager
        self.kg_manager = kg_manager
        self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")

    async def process_query(
        self, question: str

@@ -0,0 +1,139 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4

from src.config.config import global_config


class PlanReplyLogger:
    """独立的Plan/Reply日志记录器,负责落盘和容量控制。"""

    _BASE_DIR = Path("logs")
    _PLAN_DIR = _BASE_DIR / "plan"
    _REPLY_DIR = _BASE_DIR / "reply"
    _TRIM_COUNT = 100

    @classmethod
    def _get_max_per_chat(cls) -> int:
        """从配置中获取每个聊天流最大保存的日志数量"""
        return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)

    @classmethod
    def log_plan(
        cls,
        chat_id: str,
        prompt: str,
        reasoning: str,
        raw_output: Optional[str],
        raw_reasoning: Optional[str],
        actions: List[Any],
        timing: Optional[Dict[str, Any]] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        payload = {
            "type": "plan",
            "chat_id": chat_id,
            "timestamp": time.time(),
            "prompt": prompt,
            "reasoning": reasoning,
            "raw_output": raw_output,
            "raw_reasoning": raw_reasoning,
            "actions": [cls._serialize_action(action) for action in actions],
            "timing": timing or {},
            "extra": cls._safe_data(extra),
        }
        cls._write_json(cls._PLAN_DIR, chat_id, payload)

    @classmethod
    def log_reply(
        cls,
        chat_id: str,
        prompt: str,
        output: Optional[str],
        processed_output: Optional[List[Any]],
        model: Optional[str],
        timing: Optional[Dict[str, Any]] = None,
        reasoning: Optional[str] = None,
        think_level: Optional[int] = None,
        error: Optional[str] = None,
        success: bool = True,
    ) -> None:
        payload = {
            "type": "reply",
            "chat_id": chat_id,
            "timestamp": time.time(),
            "prompt": prompt,
            "output": output,
            "processed_output": cls._safe_data(processed_output),
            "model": model,
            "reasoning": reasoning,
            "think_level": think_level,
            "timing": timing or {},
            "error": error if not success else None,
            "success": success,
        }
        cls._write_json(cls._REPLY_DIR, chat_id, payload)

    @classmethod
    def _write_json(cls, base_dir: Path, chat_id: str, payload: Dict[str, Any]) -> None:
        chat_dir = base_dir / chat_id
        chat_dir.mkdir(parents=True, exist_ok=True)
        file_path = chat_dir / f"{int(time.time() * 1000)}_{uuid4().hex[:8]}.json"
        try:
            with file_path.open("w", encoding="utf-8") as f:
                json.dump(cls._safe_data(payload), f, ensure_ascii=False, indent=2)
        finally:
            cls._trim_overflow(chat_dir)

    @classmethod
    def _trim_overflow(cls, chat_dir: Path) -> None:
        """超过阈值时删除最老的若干文件,避免目录无限增长。"""
        files = sorted(chat_dir.glob("*.json"), key=lambda p: p.stat().st_mtime)
        max_per_chat = cls._get_max_per_chat()
        if len(files) <= max_per_chat:
            return
        # 删除最老的 TRIM_COUNT 条
        for old_file in files[: cls._TRIM_COUNT]:
            try:
                old_file.unlink()
            except FileNotFoundError:
                continue

    @classmethod
    def _serialize_action(cls, action: Any) -> Dict[str, Any]:
        # ActionPlannerInfo 结构的轻量序列化,避免引用复杂对象
        message_info = None
        action_message = getattr(action, "action_message", None)
        if action_message:
            user_info = getattr(action_message, "user_info", None)
            message_info = {
                "message_id": getattr(action_message, "message_id", None),
                "user_id": getattr(user_info, "user_id", None) if user_info else None,
                "platform": getattr(user_info, "platform", None) if user_info else None,
                "text": getattr(action_message, "processed_plain_text", None),
            }

        return {
            "action_type": getattr(action, "action_type", None),
            "reasoning": getattr(action, "reasoning", None),
            "action_data": cls._safe_data(getattr(action, "action_data", None)),
            "action_message": message_info,
            "available_actions": cls._safe_data(getattr(action, "available_actions", None)),
            "action_reasoning": getattr(action, "action_reasoning", None),
        }

    @classmethod
    def _safe_data(cls, value: Any) -> Any:
        if isinstance(value, (str, int, float, bool)) or value is None:
            return value
        if isinstance(value, dict):
            return {str(k): cls._safe_data(v) for k, v in value.items()}
        if isinstance(value, (list, tuple, set)):
            return [cls._safe_data(v) for v in value]
        if isinstance(value, Path):
            return str(value)
        # Fallback to string for other complex types
        return str(value)

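The new logger caps each per-chat directory by deleting a fixed batch of the oldest files once a threshold is crossed, rather than trimming to the exact limit on every write. A self-contained sketch of that trim strategy, with small constants so it can be exercised directly:

```python
import os
from pathlib import Path
from tempfile import TemporaryDirectory

TRIM_COUNT = 2  # how many of the oldest files to drop per trim


def trim_overflow(chat_dir: Path, max_files: int) -> None:
    # sort by mtime so the oldest logs are deleted first
    files = sorted(chat_dir.glob("*.json"), key=lambda p: p.stat().st_mtime)
    if len(files) <= max_files:
        return
    for old_file in files[:TRIM_COUNT]:
        old_file.unlink(missing_ok=True)


with TemporaryDirectory() as tmp:
    chat_dir = Path(tmp)
    for i in range(5):
        p = chat_dir / f"{i}.json"
        p.write_text("{}")
        os.utime(p, (i, i))  # force distinct, increasing mtimes
    trim_overflow(chat_dir, max_files=3)
    remaining = sorted(q.name for q in chat_dir.glob("*.json"))
```

Batch trimming keeps the common write path cheap: most writes do a single directory scan and no deletions, and the directory oscillates slightly above the cap instead of being trimmed exactly on every call.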
@@ -1,12 +1,9 @@
import random
import asyncio
import hashlib
import time
from typing import List, Dict, TYPE_CHECKING, Tuple

from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
@@ -35,14 +32,6 @@ class ActionModifier:

        self.action_manager = action_manager

        # 用于LLM判定的小模型
        self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")

        # 缓存相关属性
        self._llm_judge_cache = {}  # 缓存LLM判定结果
        self._cache_expiry_time = 30  # 缓存过期时间(秒)
        self._last_context_hash = None  # 上次上下文的哈希值

    async def modify_actions(
        self,
        message_content: str = "",
@@ -159,9 +148,6 @@ class ActionModifier:
        """
        deactivated_actions = []

        # 分类处理不同激活类型的actions
        llm_judge_actions: Dict[str, ActionInfo] = {}

        actions_to_check = list(actions_with_info.items())
        random.shuffle(actions_to_check)

@@ -185,9 +171,6 @@ class ActionModifier:
                deactivated_actions.append((action_name, reason))
                logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")

            elif activation_type == ActionActivationType.LLM_JUDGE:
                llm_judge_actions[action_name] = action_info

            elif activation_type == ActionActivationType.NEVER:
                reason = "激活类型为never"
                deactivated_actions.append((action_name, reason))
@@ -196,194 +179,8 @@ class ActionModifier:
            else:
                logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")

        # 并行处理LLM_JUDGE类型
        if llm_judge_actions:
            llm_results = await self._process_llm_judge_actions_parallel(
                llm_judge_actions,
                chat_content,
            )
            for action_name, should_activate in llm_results.items():
                if not should_activate:
                    reason = "LLM判定未激活"
                    deactivated_actions.append((action_name, reason))
                    logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")

        return deactivated_actions

    def _generate_context_hash(self, chat_content: str) -> str:
        """生成上下文的哈希值用于缓存"""
        context_content = f"{chat_content}"
        return hashlib.md5(context_content.encode("utf-8")).hexdigest()

    async def _process_llm_judge_actions_parallel(
        self,
        llm_judge_actions: Dict[str, ActionInfo],
        chat_content: str = "",
    ) -> Dict[str, bool]:
        """
        并行处理LLM判定actions,支持智能缓存

        Args:
            llm_judge_actions: 需要LLM判定的actions
            chat_content: 聊天内容

        Returns:
            Dict[str, bool]: action名称到激活结果的映射
        """

        # 生成当前上下文的哈希值
        current_context_hash = self._generate_context_hash(chat_content)
        current_time = time.time()

        results = {}
        tasks_to_run: Dict[str, ActionInfo] = {}

        # 检查缓存
        for action_name, action_info in llm_judge_actions.items():
            cache_key = f"{action_name}_{current_context_hash}"

            # 检查是否有有效的缓存
            if (
                cache_key in self._llm_judge_cache
                and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
            ):
                results[action_name] = self._llm_judge_cache[cache_key]["result"]
                logger.debug(
                    f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
                )
            else:
                # 需要进行LLM判定
                tasks_to_run[action_name] = action_info

        # 如果有需要运行的任务,并行执行
        if tasks_to_run:
            logger.debug(f"{self.log_prefix}并行执行LLM判定,任务数: {len(tasks_to_run)}")

            # 创建并行任务
            tasks = []
            task_names = []

            for action_name, action_info in tasks_to_run.items():
                task = self._llm_judge_action(
                    action_name,
                    action_info,
                    chat_content,
                )
                tasks.append(task)
                task_names.append(action_name)

            # 并行执行所有任务
            try:
                task_results = await asyncio.gather(*tasks, return_exceptions=True)

                # 处理结果并更新缓存
                for action_name, result in zip(task_names, task_results, strict=False):
                    if isinstance(result, Exception):
                        logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
                        results[action_name] = False
                    else:
                        results[action_name] = result

                        # 更新缓存
                        cache_key = f"{action_name}_{current_context_hash}"
                        self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}

                logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s")

            except Exception as e:
                logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
                # 如果并行执行失败,为所有任务返回False
                for action_name in tasks_to_run:
                    results[action_name] = False

        # 清理过期缓存
        self._cleanup_expired_cache(current_time)

        return results

    def _cleanup_expired_cache(self, current_time: float):
        """清理过期的缓存条目"""
        expired_keys = []
        expired_keys.extend(
            cache_key
            for cache_key, cache_data in self._llm_judge_cache.items()
            if current_time - cache_data["timestamp"] > self._cache_expiry_time
        )
        for key in expired_keys:
            del self._llm_judge_cache[key]

        if expired_keys:
            logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")

    async def _llm_judge_action(
        self,
        action_name: str,
        action_info: ActionInfo,
        chat_content: str = "",
    ) -> bool:  # sourcery skip: move-assign-in-block, use-named-expression
        """
        使用LLM判定是否应该激活某个action

        Args:
            action_name: 动作名称
            action_info: 动作信息
            observed_messages_str: 观察到的聊天消息
            chat_context: 聊天上下文
            extra_context: 额外上下文

        Returns:
            bool: 是否应该激活此action
        """

        try:
            # 构建判定提示词
            action_description = action_info.description
            action_require = action_info.action_require
            custom_prompt = action_info.llm_judge_prompt

            # 构建基础判定提示词
            base_prompt = f"""
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。

动作描述:{action_description}

动作使用场景:
"""
            for req in action_require:
                base_prompt += f"- {req}\n"

            if custom_prompt:
                base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"

            if chat_content:
                base_prompt += f"\n当前聊天记录:\n{chat_content}\n"

            base_prompt += """
请根据以上信息判断是否应该激活这个动作。
只需要回答"是"或"否",不要有其他内容。
"""

            # 调用LLM进行判定
            response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)

            # 解析响应
            response = response.strip().lower()

            # print(base_prompt)
            # print(f"LLM判定动作 {action_name}:响应='{response}'")

            should_activate = "是" in response or "yes" in response or "true" in response

            logger.debug(
                f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
            )
            return should_activate

        except Exception as e:
            logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
            # 出错时默认不激活
            return False

    def _check_keyword_activation(
        self,
        action_name: str,

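The LLM-judge path deleted above cached verdicts under a key combining the action name with an MD5 of the chat context, expiring entries after 30 seconds. Its core idea, reduced to a standalone sketch (class name `VerdictCache` is illustrative; `now` is passed explicitly so the TTL can be tested deterministically):

```python
import hashlib
from typing import Dict, Optional, Tuple


class VerdictCache:
    """TTL cache keyed by (action name, md5 of chat context)."""

    def __init__(self, ttl: float = 30.0):
        self.ttl = ttl
        self._store: Dict[str, Tuple[bool, float]] = {}

    @staticmethod
    def _context_hash(chat_content: str) -> str:
        # same context → same digest → cache hit for every judged action
        return hashlib.md5(chat_content.encode("utf-8")).hexdigest()

    def get(self, action: str, chat_content: str, now: float) -> Optional[bool]:
        key = f"{action}_{self._context_hash(chat_content)}"
        hit = self._store.get(key)
        if hit is not None and now - hit[1] < self.ttl:
            return hit[0]
        return None  # miss or expired

    def put(self, action: str, chat_content: str, result: bool, now: float) -> None:
        key = f"{action}_{self._context_hash(chat_content)}"
        self._store[key] = (result, now)


cache = VerdictCache(ttl=30.0)
cache.put("reply", "hello", True, now=100.0)
fresh = cache.get("reply", "hello", now=101.0)   # within TTL
stale = cache.get("reply", "hello", now=131.0)   # past TTL
```

Hashing the context rather than storing it keeps keys small, at the cost that any single-character change in the chat invalidates every cached verdict for that context.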
@@ -10,6 +10,7 @@ from json_repair import repair_json
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config, model_config
 from src.common.logger import get_logger
+from src.chat.logger.plan_reply_logger import PlanReplyLogger
 from src.common.data_models.info_data_model import ActionPlannerInfo
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.chat.utils.chat_message_builder import (
@@ -17,7 +18,7 @@ from src.chat.utils.chat_message_builder import (
     get_raw_msg_before_timestamp_with_chat,
     replace_user_references,
 )
-from src.chat.utils.utils import get_chat_type_and_target_info
+from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
 from src.chat.planner_actions.action_manager import ActionManager
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
@@ -52,7 +53,7 @@ reply
 4.不要选择回复你自己发送的消息
 5.不要单独对表情包进行回复
 6.将上下文中所有含义不明的,疑似黑话的,缩写词均写入unknown_words中
-7.用一句简单的话来描述当前回复场景,不超过10个字
+7.如果你对上下文存在疑问,有需要查询的问题,写入question中
 {reply_action_example}

 no_reply
@@ -223,6 +224,25 @@ class ActionPlanner:
             else:
                 reasoning = "未提供原因"
             action_data = {key: value for key, value in action_json.items() if key not in ["action"]}

+            # 验证和清理 question
+            if "question" in action_data:
+                q = action_data.get("question")
+                if isinstance(q, str):
+                    cleaned_q = q.strip()
+                    if cleaned_q:
+                        action_data["question"] = cleaned_q
+                    else:
+                        # 如果清理后为空字符串,移除该字段
+                        action_data.pop("question", None)
+                elif q is None:
+                    # 如果为 None,移除该字段
+                    action_data.pop("question", None)
+                else:
+                    # 如果不是字符串类型,记录警告并移除
+                    logger.warning(f"{self.log_prefix}question 格式不正确,应为字符串类型,已忽略")
+                    action_data.pop("question", None)
+
             # 非no_reply动作需要target_message_id
             target_message = None
@@ -291,11 +311,9 @@ class ActionPlanner:
         return action_planner_infos

     def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
-        """判断消息是否由机器人自身发送"""
+        """判断消息是否由机器人自身发送(支持多平台,包括 WebUI)"""
         try:
-            return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
-                message.user_info.platform or ""
-            ) == (global_config.bot.platform or "")
+            return is_bot_self(message.user_info.platform or "", str(message.user_info.user_id))
         except AttributeError:
             logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
             return False
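Several hunks in this commit replace per-platform account comparisons with a single `is_bot_self(platform, user_id)` helper imported from `src.chat.utils.utils`. Its implementation is not part of this diff; a rough sketch of what such a helper plausibly does, with stand-in account values (the mapping below is an assumption, not the project's real config):

```python
# Hypothetical stand-in for the bot's per-platform account configuration.
BOT_ACCOUNTS = {
    "qq": "10001",
    "telegram": "tg_bot",
    "webui": "webui_bot",
}

def is_bot_self(platform: str, user_id) -> bool:
    """Return True when (platform, user_id) identifies the bot itself.
    Compares as strings so int and str ids behave the same."""
    return str(user_id) == str(BOT_ACCOUNTS.get(platform, ""))
```

Centralizing the check is what lets WebUI messages be recognized as the bot's own without sprinkling `platform == "qq"` / `platform == "telegram"` branches through every call site.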
@@ -310,6 +328,7 @@ class ActionPlanner:
         """
         规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
         """
+        plan_start = time.perf_counter()

         # 获取聊天上下文
         message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -345,6 +364,7 @@ class ActionPlanner:

         logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")

+        prompt_build_start = time.perf_counter()
         # 构建包含所有动作的提示词
         prompt, message_id_list = await self.build_planner_prompt(
             is_group_chat=is_group_chat,
@@ -353,9 +373,10 @@ class ActionPlanner:
             chat_content_block=chat_content_block,
             message_id_list=message_id_list,
         )
+        prompt_build_ms = (time.perf_counter() - prompt_build_start) * 1000

         # 调用LLM获取决策
-        reasoning, actions = await self._execute_main_planner(
+        reasoning, actions, llm_raw_output, llm_reasoning, llm_duration_ms = await self._execute_main_planner(
             prompt=prompt,
             message_id_list=message_id_list,
             filtered_actions=filtered_actions,
@@ -397,6 +418,25 @@ class ActionPlanner:

         self.add_plan_log(reasoning, actions)

+        try:
+            PlanReplyLogger.log_plan(
+                chat_id=self.chat_id,
+                prompt=prompt,
+                reasoning=reasoning,
+                raw_output=llm_raw_output,
+                raw_reasoning=llm_reasoning,
+                actions=actions,
+                timing={
+                    "prompt_build_ms": round(prompt_build_ms, 2),
+                    "llm_duration_ms": round(llm_duration_ms, 2) if llm_duration_ms is not None else None,
+                    "total_plan_ms": round((time.perf_counter() - plan_start) * 1000, 2),
+                    "loop_start_time": loop_start_time,
+                },
+                extra=None,
+            )
+        except Exception:
+            logger.exception(f"{self.log_prefix}记录plan日志失败")
+
         return actions

     def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
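The planning path now brackets prompt building, the LLM call, and the whole plan with `time.perf_counter()` and hands rounded millisecond values to the plan logger. The measurement pattern in isolation (the wrapper function is illustrative, not project code):

```python
import time

def timed_ms(fn, *args, **kwargs):
    """Run fn and return (result, elapsed milliseconds) - the same
    perf_counter-based pattern behind prompt_build_ms and llm_duration_ms."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return result, round(elapsed_ms, 2)
```

`perf_counter()` is preferred over `time.time()` here because it is monotonic, so the measured spans cannot go negative if the wall clock is adjusted mid-call.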
@@ -480,19 +520,34 @@ class ActionPlanner:
         name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"

         # 根据 think_mode 配置决定 reply action 的示例 JSON
-        # 在 JSON 中直接作为 action 参数携带 unknown_words
+        # 在 JSON 中直接作为 action 参数携带 unknown_words 和 question
         if global_config.chat.think_mode == "classic":
-            reply_action_example = (
+            reply_action_example = ""
+            if global_config.chat.llm_quote:
+                reply_action_example += "5.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
+            reply_action_example += (
                 '{{"action":"reply", "target_message_id":"消息id(m+数字)", '
-                '"unknown_words":["词语1","词语2"]}}'
+                '"unknown_words":["词语1","词语2"], '
+                '"question":"需要查询的问题"'
             )
+            if global_config.chat.llm_quote:
+                reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
+            reply_action_example += "}"
         else:
             reply_action_example = (
                 "5.think_level表示思考深度,0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n"
-                + '{{"action":"reply", "think_level":数值等级(0或1), '
-                '"target_message_id":"消息id(m+数字)", '
-                '"unknown_words":["词语1","词语2"]}}'
             )
+            if global_config.chat.llm_quote:
+                reply_action_example += "6.如果要明确回复消息,使用quote,如果消息不多不需要明确回复,设置quote为false\n"
+            reply_action_example += (
+                '{{"action":"reply", "think_level":数值等级(0或1), '
+                '"target_message_id":"消息id(m+数字)", '
+                '"unknown_words":["词语1","词语2"], '
+                '"question":"需要查询的问题"'
+            )
+            if global_config.chat.llm_quote:
+                reply_action_example += ', "quote":"如果需要引用该message,设置为true"'
+            reply_action_example += "}"

         planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
         prompt = planner_prompt_template.format(
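With this change the example JSON shown to the planner can carry a `question` field and, when `llm_quote` is enabled, a `quote` field. A small sketch of how such a planner reply looks when parsed on the receiving side (plain `json` is used here for the sketch; the project itself routes LLM output through `json_repair`):

```python
import json

# Example of the reply-action JSON shape the new prompt asks the LLM to emit;
# "question" and "quote" are the fields added in this hunk.
raw = (
    '{"action": "reply", "target_message_id": "m12", '
    '"unknown_words": ["黑话1"], "question": " 需要查询的问题 ", "quote": true}'
)
action_json = json.loads(raw)
action = action_json.get("action")
# Everything except "action" becomes action_data, mirroring the parsing hunk,
# and question is stripped (empty strings would be dropped by the validator).
action_data = {k: v for k, v in action_json.items() if k != "action"}
if isinstance(action_data.get("question"), str):
    action_data["question"] = action_data["question"].strip()
```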
@@ -547,7 +602,7 @@ class ActionPlanner:
             if action_info.activation_type == ActionActivationType.NEVER:
                 logger.debug(f"{self.log_prefix}动作 {action_name} 设置为 NEVER 激活类型,跳过")
                 continue
-            elif action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
+            elif action_info.activation_type == ActionActivationType.ALWAYS:
                 filtered_actions[action_name] = action_info
             elif action_info.activation_type == ActionActivationType.RANDOM:
                 if random.random() < action_info.random_activation_probability:
@@ -610,14 +665,19 @@ class ActionPlanner:
         filtered_actions: Dict[str, ActionInfo],
         available_actions: Dict[str, ActionInfo],
         loop_start_time: float,
-    ) -> Tuple[str, List[ActionPlannerInfo]]:
+    ) -> Tuple[str, List[ActionPlannerInfo], Optional[str], Optional[str], Optional[float]]:
         """执行主规划器"""
         llm_content = None
         actions: List[ActionPlannerInfo] = []
+        llm_reasoning = None
+        llm_duration_ms = None

         try:
             # 调用LLM
+            llm_start = time.perf_counter()
             llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
+            llm_duration_ms = (time.perf_counter() - llm_start) * 1000
+            llm_reasoning = reasoning_content

             if global_config.debug.show_planner_prompt:
                 logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
@@ -640,7 +700,7 @@ class ActionPlanner:
                     action_message=None,
                     available_actions=available_actions,
                 )
-            ]
+            ], llm_content, llm_reasoning, llm_duration_ms

         # 解析LLM响应
         extracted_reasoning = ""
@@ -685,7 +745,7 @@ class ActionPlanner:

         logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")

-        return extracted_reasoning, actions
+        return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms

     def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
         """创建no_reply"""
@@ -16,7 +16,7 @@ from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, Message
 from src.chat.message_receive.chat_stream import ChatStream
 from src.chat.message_receive.uni_message_sender import UniversalMessageSender
 from src.chat.utils.timer_calculator import Timer  # <--- Import Timer
-from src.chat.utils.utils import get_chat_type_and_target_info
+from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
 from src.chat.utils.prompt_builder import global_prompt_manager
 from src.chat.utils.chat_message_builder import (
     build_readable_messages,
@@ -31,6 +31,7 @@ from src.person_info.person_info import Person
 from src.plugin_system.base.component_types import ActionInfo, EventType
 from src.plugin_system.apis import llm_api

+from src.chat.logger.plan_reply_logger import PlanReplyLogger
 from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
 from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
 from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
@@ -74,6 +75,7 @@ class DefaultReplyer:
         reply_time_point: Optional[float] = time.time(),
         think_level: int = 1,
         unknown_words: Optional[List[str]] = None,
+        log_reply: bool = True,
     ) -> Tuple[bool, LLMGenerationDataModel]:
         # sourcery skip: merge-nested-ifs
         """
@@ -92,6 +94,9 @@ class DefaultReplyer:
             Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
         """

+        overall_start = time.perf_counter()
+        prompt_duration_ms: Optional[float] = None
+        llm_duration_ms: Optional[float] = None
         prompt = None
         selected_expressions: Optional[List[int]] = None
         llm_response = LLMGenerationDataModel()
@@ -101,6 +106,7 @@ class DefaultReplyer:
         # 3. 构建 Prompt
         timing_logs = []
         almost_zero_str = ""
+        prompt_start = time.perf_counter()
         with Timer("构建Prompt", {}):  # 内部计时器,可选保留
             prompt, selected_expressions, timing_logs, almost_zero_str = await self.build_prompt_reply_context(
                 extra_info=extra_info,
@@ -113,11 +119,37 @@ class DefaultReplyer:
                 think_level=think_level,
                 unknown_words=unknown_words,
             )
+        prompt_duration_ms = (time.perf_counter() - prompt_start) * 1000
         llm_response.prompt = prompt
         llm_response.selected_expressions = selected_expressions
+        llm_response.timing = {
+            "prompt_ms": round(prompt_duration_ms or 0.0, 2),
+            "overall_ms": None,  # 占位,稍后写入
+        }
+        llm_response.timing_logs = timing_logs
+        llm_response.timing["timing_logs"] = timing_logs

         if not prompt:
             logger.warning("构建prompt失败,跳过回复生成")
+            llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
+            llm_response.timing["almost_zero"] = almost_zero_str
+            llm_response.timing["timing_logs"] = timing_logs
+            if log_reply:
+                try:
+                    PlanReplyLogger.log_reply(
+                        chat_id=self.chat_stream.stream_id,
+                        prompt="",
+                        output=None,
+                        processed_output=None,
+                        model=None,
+                        timing=llm_response.timing,
+                        reasoning=None,
+                        think_level=think_level,
+                        error="build_prompt_failed",
+                        success=False,
+                    )
+                except Exception:
+                    logger.exception("记录reply日志失败")
             return False, llm_response
         from src.plugin_system.core.events_manager import events_manager

@@ -137,7 +169,9 @@ class DefaultReplyer:
         model_name = "unknown_model"

         try:
+            llm_start = time.perf_counter()
             content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
+            llm_duration_ms = (time.perf_counter() - llm_start) * 1000
             # logger.debug(f"replyer生成内容: {content}")

             # 统一输出所有日志信息,使用try-except确保即使某个步骤出错也能输出
@@ -161,6 +195,26 @@ class DefaultReplyer:
             llm_response.reasoning = reasoning_content
             llm_response.model = model_name
             llm_response.tool_calls = tool_call
+            llm_response.timing["llm_ms"] = round(llm_duration_ms or 0.0, 2)
+            llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
+            llm_response.timing_logs = timing_logs
+            llm_response.timing["timing_logs"] = timing_logs
+            llm_response.timing["almost_zero"] = almost_zero_str
+            try:
+                if log_reply:
+                    PlanReplyLogger.log_reply(
+                        chat_id=self.chat_stream.stream_id,
+                        prompt=prompt,
+                        output=content,
+                        processed_output=None,
+                        model=model_name,
+                        timing=llm_response.timing,
+                        reasoning=reasoning_content,
+                        think_level=think_level,
+                        success=True,
+                    )
+            except Exception:
+                logger.exception("记录reply日志失败")
             continue_flag, modified_message = await events_manager.handle_mai_events(
                 EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
             )
@@ -194,6 +248,27 @@ class DefaultReplyer:
         except Exception as log_e:
             logger.warning(f"输出日志时出错: {log_e}")

+            llm_response.timing["llm_ms"] = round(llm_duration_ms or 0.0, 2)
+            llm_response.timing["overall_ms"] = round((time.perf_counter() - overall_start) * 1000, 2)
+            llm_response.timing_logs = timing_logs
+            llm_response.timing["timing_logs"] = timing_logs
+            llm_response.timing["almost_zero"] = almost_zero_str
+            if log_reply:
+                try:
+                    PlanReplyLogger.log_reply(
+                        chat_id=self.chat_stream.stream_id,
+                        prompt=prompt or "",
+                        output=None,
+                        processed_output=None,
+                        model=model_name,
+                        timing=llm_response.timing,
+                        reasoning=None,
+                        think_level=think_level,
+                        error=str(llm_e),
+                        success=False,
+                    )
+                except Exception:
+                    logger.exception("记录reply日志失败")
             return False, llm_response  # LLM 调用失败则无法生成回复

         return True, llm_response
@@ -541,107 +616,6 @@ class DefaultReplyer:
             logger.error(f"上下文黑话解释失败: {e}")
             return ""

-    def build_chat_history_prompts(
-        self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
-    ) -> Tuple[str, str]:
-        """
-
-        Args:
-            message_list_before_now: 历史消息列表
-            target_user_id: 目标用户ID(当前对话对象)
-
-        Returns:
-            Tuple[str, str]: (核心对话prompt, 背景对话prompt)
-        """
-        # 构建背景对话 prompt
-        all_dialogue_prompt = ""
-        if message_list_before_now:
-            latest_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
-            all_dialogue_prompt = build_readable_messages(
-                latest_msgs,
-                replace_bot_name=True,
-                timestamp_mode="normal_no_YMD",
-                truncate=True,
-            )
-
-        return all_dialogue_prompt
-
-    def core_background_build_chat_history_prompts(
-        self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
-    ) -> Tuple[str, str]:
-        """
-
-        Args:
-            message_list_before_now: 历史消息列表
-            target_user_id: 目标用户ID(当前对话对象)
-
-        Returns:
-            Tuple[str, str]: (核心对话prompt, 背景对话prompt)
-        """
-        core_dialogue_list: List[DatabaseMessages] = []
-        bot_id = str(global_config.bot.qq_account)
-
-        # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
-        for msg in message_list_before_now:
-            try:
-                msg_user_id = str(msg.user_info.user_id)
-                reply_to = msg.reply_to
-                _platform, reply_to_user_id = self._parse_reply_target(reply_to)
-                if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
-                    # bot 和目标用户的对话
-                    core_dialogue_list.append(msg)
-            except Exception as e:
-                logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
-
-        # 构建核心对话 prompt
-        core_dialogue_prompt = ""
-        if core_dialogue_list:
-            # 检查最新五条消息中是否包含bot自己说的消息
-            latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
-            has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages)
-
-            # logger.info(f"最新五条消息:{latest_5_messages}")
-            # logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
-
-            # 如果最新五条消息中不包含bot的消息,则返回空字符串
-            if not has_bot_message:
-                core_dialogue_prompt = ""
-            else:
-                core_dialogue_list = core_dialogue_list[
-                    -int(global_config.chat.max_context_size * 0.6) :
-                ]  # 限制消息数量
-
-                core_dialogue_prompt_str = build_readable_messages(
-                    core_dialogue_list,
-                    replace_bot_name=True,
-                    timestamp_mode="normal_no_YMD",
-                    read_mark=0.0,
-                    truncate=True,
-                    show_actions=True,
-                )
-                core_dialogue_prompt = f"""--------------------------------
-这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解:
-{core_dialogue_prompt_str}
---------------------------------
-"""
-
-        # 构建背景对话 prompt
-        all_dialogue_prompt = ""
-        if message_list_before_now:
-            latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
-            all_dialogue_prompt_str = build_readable_messages(
-                latest_25_msgs,
-                replace_bot_name=True,
-                timestamp_mode="normal_no_YMD",
-                truncate=True,
-            )
-            if core_dialogue_prompt:
-                all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
-            else:
-                all_dialogue_prompt = f"{all_dialogue_prompt_str}"
-
-        return core_dialogue_prompt, all_dialogue_prompt
-
     async def build_actions_prompt(
         self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
     ) -> str:
@@ -841,10 +815,8 @@ class DefaultReplyer:

         person_list_short: List[Person] = []
         for msg in message_list_before_short:
-            if (
-                global_config.bot.qq_account == msg.user_info.user_id
-                and global_config.bot.platform == msg.user_info.platform
-            ):
+            # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
+            if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
                 continue
             if (
                 reply_message
@@ -865,6 +837,7 @@ class DefaultReplyer:
             timestamp_mode="relative",
             read_mark=0.0,
             show_actions=True,
+            long_time_notice=True,
         )

         # 统一黑话解释构建:根据配置选择上下文或 Planner 模式
@@ -872,6 +845,18 @@ class DefaultReplyer:
             chat_id, message_list_before_short, chat_talking_prompt_short, unknown_words
         )

+        # 从 chosen_actions 中提取 question(仅在 reply 动作中)
+        question = None
+        if chosen_actions:
+            for action_info in chosen_actions:
+                if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
+                    q = action_info.action_data.get("question")
+                    if isinstance(q, str):
+                        cleaned_q = q.strip()
+                        if cleaned_q:
+                            question = cleaned_q
+                            break
+
         # 并行执行构建任务(包括黑话解释,可配置关闭)
         task_results = await asyncio.gather(
             self._time_and_run_task(
@@ -886,7 +871,7 @@ class DefaultReplyer:
             self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
             self._time_and_run_task(
                 build_memory_retrieval_prompt(
-                    chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level
+                    chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level, unknown_words=unknown_words, question=question
                 ),
                 "memory_retrieval",
             ),
@@ -960,8 +945,16 @@ class DefaultReplyer:
         else:
             reply_target_block = ""

-        # 构建分离的对话 prompt
-        dialogue_prompt = self.build_chat_history_prompts(message_list_before_now_long, user_id, sender)
+        if message_list_before_now_long:
+            latest_msgs = message_list_before_now_long[-int(global_config.chat.max_context_size) :]
+            dialogue_prompt = build_readable_messages(
+                latest_msgs,
+                replace_bot_name=True,
+                timestamp_mode="normal_no_YMD",
+                truncate=True,
+                long_time_notice=True,
+            )

         # 获取匹配的额外prompt
         chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
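Both replyers now pull an optional `question` out of the chosen reply action before building the memory-retrieval prompt. The extraction reduces to a small helper; plain dicts stand in here for the real `ActionPlannerInfo` objects:

```python
from typing import Any, Dict, List, Optional

def extract_question(chosen_actions: List[Dict[str, Any]]) -> Optional[str]:
    """Return the first non-empty, stripped "question" carried by a
    reply action, or None when no reply action supplies one."""
    for info in chosen_actions:
        if info.get("action_type") == "reply" and isinstance(info.get("action_data"), dict):
            q = info["action_data"].get("question")
            if isinstance(q, str) and q.strip():
                return q.strip()
    return None
```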
@@ -16,7 +16,7 @@ from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, Message
 from src.chat.message_receive.chat_stream import ChatStream
 from src.chat.message_receive.uni_message_sender import UniversalMessageSender
 from src.chat.utils.timer_calculator import Timer  # <--- Import Timer
-from src.chat.utils.utils import get_chat_type_and_target_info
+from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
 from src.chat.utils.prompt_builder import global_prompt_manager
 from src.chat.utils.chat_message_builder import (
     build_readable_messages,
@@ -76,6 +76,7 @@ class PrivateReplyer:
         reply_message: Optional[DatabaseMessages] = None,
         reply_time_point: Optional[float] = time.time(),
+        unknown_words: Optional[List[str]] = None,
         log_reply: bool = True,
     ) -> Tuple[bool, LLMGenerationDataModel]:
         # sourcery skip: merge-nested-ifs
         """
@@ -109,6 +110,7 @@ class PrivateReplyer:
                 enable_tool=enable_tool,
                 reply_message=reply_message,
                 reply_reason=reply_reason,
+                unknown_words=unknown_words,
             )
         llm_response.prompt = prompt
         llm_response.selected_expressions = selected_expressions
@@ -610,6 +612,7 @@ class PrivateReplyer:
         available_actions: Optional[Dict[str, ActionInfo]] = None,
         chosen_actions: Optional[List[ActionPlannerInfo]] = None,
         enable_tool: bool = True,
+        unknown_words: Optional[List[str]] = None,
     ) -> Tuple[str, List[int]]:
         """
         构建回复器上下文
@@ -664,6 +667,7 @@ class PrivateReplyer:
             timestamp_mode="relative",
             read_mark=0.0,
             show_actions=True,
+            long_time_notice=True
         )

         message_list_before_short = get_raw_msg_before_timestamp_with_chat(
@@ -675,10 +679,8 @@ class PrivateReplyer:

         person_list_short: List[Person] = []
         for msg in message_list_before_short:
-            if (
-                global_config.bot.qq_account == msg.user_info.user_id
-                and global_config.bot.platform == msg.user_info.platform
-            ):
+            # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
+            if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
                 continue
             if (
                 reply_message
@@ -708,12 +710,24 @@ class PrivateReplyer:
         else:
             jargon_coroutine = self._build_disabled_jargon_explanation()

+        # 从 chosen_actions 中提取 question(仅在 reply 动作中)
+        question = None
+        if chosen_actions:
+            for action_info in chosen_actions:
+                if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
+                    q = action_info.action_data.get("question")
+                    if isinstance(q, str):
+                        cleaned_q = q.strip()
+                        if cleaned_q:
+                            question = cleaned_q
+                            break
+
         # 并行执行九个构建任务(包括黑话解释,可配置关闭)
         task_results = await asyncio.gather(
             self._time_and_run_task(
                 self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
             ),
-            self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
+            # self._time_and_run_task(self.build_relation_info(chat_talking_prompt_short, sender), "relation_info"),
             self._time_and_run_task(
                 self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
             ),
@@ -722,7 +736,7 @@ class PrivateReplyer:
             self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
             self._time_and_run_task(
                 build_memory_retrieval_prompt(
-                    chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
+                    chat_talking_prompt_short, sender, target, self.chat_stream, think_level=1, unknown_words=unknown_words, question=question
                 ),
                 "memory_retrieval",
             ),
@@ -759,7 +773,7 @@ class PrivateReplyer:
         expression_habits_block, selected_expressions = results_dict["expression_habits"]
         expression_habits_block: str
         selected_expressions: List[int]
-        relation_info: str = results_dict["relation_info"]
+        relation_info: str = results_dict.get("relation_info") or ""
         tool_info: str = results_dict["tool_info"]
         prompt_info: str = results_dict["prompt_info"]  # 直接使用格式化后的结果
         actions_info: str = results_dict["actions_info"]
@@ -796,7 +810,19 @@ class PrivateReplyer:
         chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
         chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""

-        if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
+        # 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
+        reply_style = global_config.personality.reply_style
+        multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
+        multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
+        if multi_styles and multi_prob > 0 and random.random() < multi_prob:
+            try:
+                reply_style = random.choice(list(multi_styles))
+            except Exception:
+                # 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
+                reply_style = global_config.personality.reply_style
+
+        # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
+        if is_bot_self(platform, user_id):
             return await global_prompt_manager.format_prompt(
                 "private_replyer_self_prompt",
                 expression_habits_block=expression_habits_block,
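The private replyer now builds `reply_style` per call: with probability `multiple_probability` it swaps the configured default for a random entry of `multiple_reply_style`, falling back to the default on any error. The selection logic on its own (names follow the config fields used in the hunk above; the standalone function is illustrative):

```python
import random

def pick_reply_style(default_style, multi_styles, multi_prob, rng=random):
    """With probability multi_prob, replace default_style with a random
    entry from multi_styles; otherwise (or on any error) keep the default."""
    if multi_styles and multi_prob > 0 and rng.random() < multi_prob:
        try:
            return rng.choice(list(multi_styles))
        except Exception:
            return default_style
    return default_style
```

An empty style list or a probability of 0 makes the function a no-op, which is why the diff guards with `multi_styles and multi_prob > 0` before ever calling `random.random()`.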
@@ -812,7 +838,7 @@ class PrivateReplyer:
                 target=target,
                 reason=reply_reason,
                 sender_name=sender,
-                reply_style=global_config.personality.reply_style,
+                reply_style=reply_style,
                 keywords_reaction_prompt=keywords_reaction_prompt,
                 moderation_prompt=moderation_prompt_block,
                 memory_retrieval=memory_retrieval,
@@ -832,7 +858,7 @@ class PrivateReplyer:
             jargon_explanation=jargon_explanation,
             time_block=time_block,
             reply_target_block=reply_target_block,
-            reply_style=global_config.personality.reply_style,
+            reply_style=reply_style,
             keywords_reaction_prompt=keywords_reaction_prompt,
             moderation_prompt=moderation_prompt_block,
             sender_name=sender,
@@ -917,6 +943,17 @@ class PrivateReplyer:

         template_name = "default_expressor_prompt"

+        # 根据配置构建最终的 reply_style:支持 multiple_reply_style 按概率随机替换
+        reply_style = global_config.personality.reply_style
+        multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
+        multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
+        if multi_styles and multi_prob > 0 and random.random() < multi_prob:
+            try:
+                reply_style = random.choice(list(multi_styles))
+            except Exception:
+                # 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
+                reply_style = global_config.personality.reply_style
+
         return await global_prompt_manager.format_prompt(
             template_name,
             expression_habits_block=expression_habits_block,
@@ -929,7 +966,7 @@ class PrivateReplyer:
             reply_target_block=reply_target_block,
             raw_reply=raw_reply,
             reason=reason,
-            reply_style=global_config.personality.reply_style,
+            reply_style=reply_style,
             keywords_reaction_prompt=keywords_reaction_prompt,
             moderation_prompt=moderation_prompt_block,
         )
@@ -13,7 +13,7 @@ from src.common.data_models.message_data_model import MessageAndActionModel
 from src.common.database.database_model import ActionRecords
 from src.common.database.database_model import Images
 from src.person_info.person_info import Person, get_person_id
-from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
+from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self

 install(extra_lines=3)
 logger = get_logger("chat_message_builder")
@@ -43,12 +43,9 @@ def replace_user_references(
     if name_resolver is None:

         def default_resolver(platform: str, user_id: str) -> str:
-            # 检查是否是机器人自己(支持多平台)
-            if replace_bot_name:
-                if platform == "qq" and user_id == global_config.bot.qq_account:
-                    return f"{global_config.bot.nickname}(你)"
-                if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
-                    return f"{global_config.bot.nickname}(你)"
+            # 检查是否是机器人自己(支持多平台,包括 WebUI)
+            if replace_bot_name and is_bot_self(platform, user_id):
+                return f"{global_config.bot.nickname}(你)"
             person = Person(platform=platform, user_id=user_id)
             return person.person_name or user_id  # type: ignore

@@ -61,8 +58,8 @@ def replace_user_references(
             aaa = match[1]
             bbb = match[2]
             try:
-                # 检查是否是机器人自己
-                if replace_bot_name and bbb == global_config.bot.qq_account:
+                # 检查是否是机器人自己(支持多平台,包括 WebUI)
+                if replace_bot_name and is_bot_self(platform, bbb):
                     reply_person_name = f"{global_config.bot.nickname}(你)"
                 else:
                     reply_person_name = name_resolver(platform, bbb) or aaa
@@ -370,6 +367,7 @@ def _build_readable_messages_internal(
     show_pic: bool = True,
     message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
     pic_single: bool = False,
+    long_time_notice: bool = False,
 ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
     # sourcery skip: use-getitem-for-re-match-groups
     """
@@ -467,10 +465,8 @@ def _build_readable_messages_internal(
         person_name = (
             person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
         )
-        if replace_bot_name and (
-            (platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
-            or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
-        ):
+        # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
+        if replace_bot_name and is_bot_self(platform, user_id):
             person_name = f"{global_config.bot.nickname}(你)"

         # 使用独立函数处理用户引用格式
@@ -523,7 +519,30 @@ def _build_readable_messages_internal(
     # 3: 格式化为字符串
     output_lines: List[str] = []

+    prev_timestamp: Optional[float] = None
     for timestamp, name, content, is_action in detailed_message:
+        # 检查是否需要插入长时间间隔提示
+        if long_time_notice and prev_timestamp is not None:
+            time_diff = timestamp - prev_timestamp
+            time_diff_hours = time_diff / 3600
+
+            # 检查是否跨天
+            prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
+            current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
+            is_cross_day = prev_date != current_date
+
+            # 如果间隔大于8小时或跨天,插入提示
+            if time_diff_hours > 8 or is_cross_day:
+                # 格式化日期为中文格式:xxxx年xx月xx日(去掉前导零)
+                current_time_struct = time.localtime(timestamp)
+                year = current_time_struct.tm_year
+                month = current_time_struct.tm_mon
+                day = current_time_struct.tm_mday
+                date_str = f"{year}年{month}月{day}日"
+                hours_str = f"{int(time_diff_hours)}h"
+                notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
+                output_lines.append(notice)
+
         readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)

         # 查找消息id(如果有)并构建id_prefix
@@ -537,6 +556,8 @@ def _build_readable_messages_internal(
         output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
         output_lines.append("\n")  # 在每个消息块后添加换行,保持可读性

+        prev_timestamp = timestamp
+
     formatted_string = "".join(output_lines).strip()

     # 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器
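`long_time_notice` inserts a separator line whenever adjacent messages are more than 8 hours apart or fall on different local days. The trigger condition, extracted into a standalone predicate (the function itself is illustrative; the real code runs inline in the formatting loop above):

```python
import time

def needs_gap_notice(prev_ts: float, ts: float) -> bool:
    """True when a time-gap notice should precede the message at ts:
    either the gap exceeds 8 hours, or the two timestamps fall on
    different local calendar days."""
    if (ts - prev_ts) / 3600 > 8:
        return True
    prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_ts))
    curr_date = time.strftime("%Y-%m-%d", time.localtime(ts))
    return prev_date != curr_date
```

Comparing formatted local dates rather than raw deltas is what makes a one-hour gap across midnight still produce the notice.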
@ -651,6 +672,7 @@ async def build_readable_messages_with_list(
|
|||
show_pic=True,
|
||||
message_id_list=None,
|
||||
pic_single=pic_single,
|
||||
long_time_notice=False,
|
||||
)
|
||||
|
||||
if not pic_single:
|
||||
|
|
@ -704,6 +726,7 @@ def build_readable_messages(
|
|||
message_id_list: Optional[List[Tuple[str, DatabaseMessages]]] = None,
|
||||
remove_emoji_stickers: bool = False,
|
||||
pic_single: bool = False,
|
||||
long_time_notice: bool = False,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
|
|
@ -719,6 +742,7 @@ def build_readable_messages(
|
|||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
remove_emoji_stickers: 是否移除表情包并过滤空消息
|
||||
long_time_notice: 是否在消息间隔过长(>8小时)或跨天时插入时间提示
|
||||
"""
|
||||
# WIP HERE and BELOW ----------------------------------------------
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
|
|
@ -812,6 +836,7 @@ def build_readable_messages(
|
|||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
pic_single=pic_single,
|
||||
long_time_notice=long_time_notice,
|
||||
)
|
||||
|
||||
if not pic_single:
|
||||
|
|
@ -839,6 +864,7 @@ def build_readable_messages(
|
|||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
pic_single=pic_single,
|
||||
long_time_notice=long_time_notice,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
|
|
@ -850,6 +876,7 @@ def build_readable_messages(
|
|||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
pic_single=pic_single,
|
||||
long_time_notice=long_time_notice,
|
||||
)
|
||||
|
||||
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
||||
|
|
|
|||
|
|
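The long-gap check in the hunk above can be exercised in isolation; a minimal sketch of the same 8-hour/cross-day rule (the function name is illustrative, not from the codebase):

```python
import time


def needs_gap_notice(prev_ts: float, cur_ts: float) -> bool:
    """True when the gap exceeds 8 hours or the timestamps fall on different local days."""
    if (cur_ts - prev_ts) / 3600 > 8:
        return True
    prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_ts))
    cur_date = time.strftime("%Y-%m-%d", time.localtime(cur_ts))
    return prev_date != cur_date


# A 9-hour gap always triggers the notice, regardless of day boundaries.
print(needs_gap_notice(1700000000.0, 1700000000.0 + 9 * 3600))  # → True
```

Note that the cross-day branch depends on the local timezone, which matches the production code's use of `time.localtime`.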
@@ -743,13 +743,13 @@ class StatisticOutputTask(AsyncTask):
        """
        if stats[TOTAL_REQ_CNT] <= 0:
            return ""
        data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
        data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12} {:>12}"

        total_replies = stats.get(TOTAL_REPLY_CNT, 0)

        output = [
            "按模型分类统计:",
            " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
            " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数 每次调用平均Token",
        ]
        for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
            name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name

@@ -764,6 +764,9 @@ class StatisticOutputTask(AsyncTask):
            avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
            avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0

            # Average tokens per call
            avg_tokens_per_call = tokens / count if count > 0 else 0.0

            # Format large numbers
            formatted_count = _format_large_number(count)
            formatted_in_tokens = _format_large_number(in_tokens)

@@ -771,6 +774,7 @@ class StatisticOutputTask(AsyncTask):
            formatted_tokens = _format_large_number(tokens)
            formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
            formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
            formatted_avg_tokens_per_call = _format_large_number(avg_tokens_per_call) if count > 0 else "N/A"

            output.append(
                data_fmt.format(

@@ -784,6 +788,7 @@ class StatisticOutputTask(AsyncTask):
                    std_time_cost,
                    formatted_avg_count,
                    formatted_avg_tokens,
                    formatted_avg_tokens_per_call,
                )
            )

@@ -797,13 +802,13 @@ class StatisticOutputTask(AsyncTask):
        """
        if stats[TOTAL_REQ_CNT] <= 0:
            return ""
        data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
        data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12} {:>12}"

        total_replies = stats.get(TOTAL_REPLY_CNT, 0)

        output = [
            "按模块分类统计:",
            " 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
            " 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数 每次调用平均Token",
        ]
        for module_name, count in sorted(stats[REQ_CNT_BY_MODULE].items()):
            name = f"{module_name[:29]}..." if len(module_name) > 32 else module_name

@@ -818,6 +823,9 @@ class StatisticOutputTask(AsyncTask):
            avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
            avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0

            # Average tokens per call
            avg_tokens_per_call = tokens / count if count > 0 else 0.0

            # Format large numbers
            formatted_count = _format_large_number(count)
            formatted_in_tokens = _format_large_number(in_tokens)

@@ -825,6 +833,7 @@ class StatisticOutputTask(AsyncTask):
            formatted_tokens = _format_large_number(tokens)
            formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
            formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
            formatted_avg_tokens_per_call = _format_large_number(avg_tokens_per_call) if count > 0 else "N/A"

            output.append(
                data_fmt.format(

@@ -838,6 +847,7 @@ class StatisticOutputTask(AsyncTask):
                    std_time_cost,
                    formatted_avg_count,
                    formatted_avg_tokens,
                    formatted_avg_tokens_per_call,
                )
            )

@@ -935,11 +945,12 @@ class StatisticOutputTask(AsyncTask):
                    f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
                    f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / count, html=True) if count > 0 else 'N/A'}</td>"
                    f"</tr>"
                    for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
                ]
                if stat_data[REQ_CNT_BY_MODEL]
                else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
                else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
            )
            # Statistics grouped by request type
            type_rows = "\n".join(

@@ -955,11 +966,12 @@ class StatisticOutputTask(AsyncTask):
                    f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
                    f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / count, html=True) if count > 0 else 'N/A'}</td>"
                    f"</tr>"
                    for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
                ]
                if stat_data[REQ_CNT_BY_TYPE]
                else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
                else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
            )
            # Statistics grouped by module
            module_rows = "\n".join(

@@ -975,11 +987,12 @@ class StatisticOutputTask(AsyncTask):
                    f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
                    f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
                    f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / count, html=True) if count > 0 else 'N/A'}</td>"
                    f"</tr>"
                    for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
                ]
                if stat_data[REQ_CNT_BY_MODULE]
                else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
                else ["<tr><td colspan='11' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
            )

            # Chat message statistics

@@ -1054,7 +1067,7 @@ class StatisticOutputTask(AsyncTask):
            <h2>按模型分类统计</h2>
            <div class=\"table-wrap\">
            <table>
            <thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr></thead>
            <thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr></thead>
            <tbody>
            {model_rows}
            </tbody>

@@ -1065,7 +1078,7 @@ class StatisticOutputTask(AsyncTask):
            <div class=\"table-wrap\">
            <table>
            <thead>
            <tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
            <tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr>
            </thead>
            <tbody>
            {module_rows}

@@ -1077,7 +1090,7 @@ class StatisticOutputTask(AsyncTask):
            <div class=\"table-wrap\">
            <table>
            <thead>
            <tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
            <tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th><th>每次调用平均Token</th></tr>
            </thead>
            <tbody>
            {type_rows}
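The text table above relies entirely on Python's format-spec alignment: `<32` left-justifies the name, `>N` right-justifies each number. A minimal sketch with a shortened format string and made-up values:

```python
# A cut-down version of data_fmt: name column plus two numeric columns.
data_fmt = "{:<32} {:>10} {:>12}"
row = data_fmt.format("model-x", 120, 45000)

# Total width is fixed: 32 + 1 + 10 + 1 + 12 characters.
print(repr(row))
```

This is why adding the per-call-average column only requires appending one more `{:>12}` spec and one more header label.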
@@ -67,6 +67,53 @@ def get_current_platform_account(platform: str, platform_accounts: dict[str, str
    return platform_accounts.get(platform, "")


def is_bot_self(platform: str, user_id: str) -> bool:
    """Return whether the given platform/user_id pair is the bot itself.

    Unifies bot-identity detection across all platforms (QQ, Telegram, WebUI, ...).

    Args:
        platform: message platform (e.g. "qq", "telegram", "webui")
        user_id: user ID

    Returns:
        bool: True if the user is the bot itself, otherwise False
    """
    if not platform or not user_id:
        return False

    # Compare user_id as a string
    user_id_str = str(user_id)

    # The bot's QQ account (primary account)
    qq_account = str(global_config.bot.qq_account or "")

    # QQ platform: compare the QQ account directly
    if platform == "qq":
        return user_id_str == qq_account

    # WebUI platform: the bot replies with its QQ account, so compare against it as well
    if platform == "webui":
        return user_id_str == qq_account

    # Per-platform account mapping
    platforms_list = getattr(global_config.bot, "platforms", []) or []
    platform_accounts = parse_platform_accounts(platforms_list)

    # Telegram platform
    if platform == "telegram":
        tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
        return user_id_str == tg_account if tg_account else False

    # Other platforms: look up the account in the platforms config
    platform_account = platform_accounts.get(platform, "")
    if platform_account:
        return user_id_str == platform_account

    # Fallback: compare against the primary QQ account (for compatibility)
    return user_id_str == qq_account


def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
    """Check whether the message mentions the bot (unified multi-platform implementation)."""
    text = message.processed_plain_text or ""
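The lookup order above (QQ/WebUI shortcut, then per-platform mapping, then QQ fallback) can be sketched without the project's config objects; `accounts` stands in for the parsed platform mapping and the function name is purely illustrative:

```python
def is_bot_self_sketch(platform: str, user_id, qq_account: str, accounts: dict) -> bool:
    """Simplified stand-in for is_bot_self: same matching order, plain arguments."""
    # QQ and WebUI both resolve to the primary QQ account.
    if platform in ("qq", "webui"):
        return str(user_id) == qq_account
    # Telegram accepts either the "tg" or "telegram" key.
    if platform == "telegram":
        tg = accounts.get("tg", "") or accounts.get("telegram", "")
        return str(user_id) == tg if tg else False
    # Other platforms: exact mapping first, then fall back to the QQ account.
    acc = accounts.get(platform, "")
    return str(user_id) == acc if acc else str(user_id) == qq_account


print(is_bot_self_sketch("qq", 12345, "12345", {}))  # → True
```

Note that `user_id` is coerced to `str` before every comparison, mirroring the production code's handling of numeric IDs.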
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Any

from . import BaseDataModel

@@ -17,3 +17,6 @@ class LLMGenerationDataModel(BaseDataModel):
    prompt: Optional[str] = None
    selected_expressions: Optional[List[int]] = None
    reply_set: Optional["ReplySetModel"] = None
    timing: Optional[Dict[str, Any]] = None
    processed_output: Optional[List[str]] = None
    timing_logs: Optional[List[str]] = None
@@ -321,18 +321,14 @@ class Expression(BaseModel):

    situation = TextField()
    style = TextField()

    # new mode fields
    context = TextField(null=True)

    content_list = TextField(null=True)
    style_list = TextField(null=True)  # Stores similar styles, same format as content_list (JSON array)
    count = IntegerField(default=1)
    last_active_time = FloatField()
    chat_id = TextField(index=True)
    create_date = FloatField(null=True)  # Creation date; nullable for compatibility with old data
    checked = BooleanField(default=False)  # Whether the expression has been reviewed
    rejected = BooleanField(default=False)  # Rejected but not yet updated
    modified_by = TextField(null=True)  # Last modification source: 'ai' or 'user'; empty means unreviewed

    class Meta:
        table_name = "expression"
@@ -97,6 +97,9 @@ class TaskConfig(ConfigBase):
    slow_threshold: float = 15.0
    """Slow-request threshold in seconds; requests exceeding it log a warning"""

    selection_strategy: str = field(default="balance")
    """Model selection strategy: balance (load balancing) or random (random choice)"""


@dataclass
class ModelTaskConfig(ConfigBase):

@@ -105,9 +108,6 @@ class ModelTaskConfig(ConfigBase):
    utils: TaskConfig
    """Component model configuration"""

    utils_small: TaskConfig
    """Small component model configuration"""

    replyer: TaskConfig
    """Primary normal_chat reply model configuration"""

@@ -132,9 +132,6 @@ class ModelTaskConfig(ConfigBase):
    lpmm_rdf_build: TaskConfig
    """LPMM RDF build model configuration"""

    lpmm_qa: TaskConfig
    """LPMM QA model configuration"""

    def get_task(self, task_name: str) -> TaskConfig:
        """Get the configuration for the given task"""
        if hasattr(self, task_name):
@@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")

# mai_version in the config file is not updated automatically, so the version is hard-coded here
# Updates to this field must strictly follow Semantic Versioning: https://semver.org/lang/zh-CN/
MMC_VERSION = "0.12.0"
MMC_VERSION = "0.12.1"


def get_key_comment(toml_table, key):
@@ -1,5 +1,6 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Union
import types

T = TypeVar("T", bound="ConfigBase")

@@ -108,6 +109,39 @@ class ConfigBase:

            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # Handle Union/Optional types (including the Python 3.10+ `float | None` syntax)
        # Note:
        # - Optional[float] is equivalent to Union[float, None]; get_origin() returns typing.Union
        # - float | None is a types.UnionType; get_origin() returns types.UnionType instead
        is_union_type = (
            field_origin_type is Union  # typing.Optional / typing.Union
            or isinstance(field_type, types.UnionType)  # Python 3.10+ `|` syntax
        )

        if is_union_type:
            union_args = field_type_args if field_type_args else get_args(field_type)

            # Safety check: only `T | None` style Optionals are allowed; multi-type unions such as float | str are rejected
            non_none_types = [arg for arg in union_args if arg is not type(None)]
            if len(non_none_types) > 1:
                raise TypeError(
                    f"配置字段不支持多类型 Union(如 float | str),只支持 Optional 类型(如 float | None)。"
                    f"当前类型: {field_type}"
                )

            # If the value is None and None is part of the Union, return it directly
            if value is None and type(None) in union_args:
                return None
            # Try converting to each non-None member of the Union in turn
            for arg in union_args:
                if arg is not type(None):
                    try:
                        return cls._convert_field(value, arg)
                    except (ValueError, TypeError):
                        continue
            # All conversions failed
            raise TypeError("Cannot convert value to any type in Union")

        # Handle basic types such as int, str, etc.
        if field_origin_type is type(None) and value is None:  # Handle Optional types
            return None
@@ -122,6 +122,12 @@ class ChatConfig(ConfigBase):
    - dynamic: think_level is decided dynamically by the planner (based on the think_level the planner returns)
    """

    plan_reply_log_max_per_chat: int = 1024
    """Maximum number of Plan/Reply logs kept per chat stream; the oldest logs are deleted beyond this limit"""

    llm_quote: bool = False
    """Whether to enable the quote parameter in the reply action; when enabled the LLM decides whether to quote a message"""

    def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
        """Generate a chat_id from "platform:id:type", consistent with ChatStream.get_stream_id."""
        try:

@@ -260,12 +266,32 @@ class MemoryConfig(ConfigBase):
    agent_timeout_seconds: float = 120.0
    """Agent timeout in seconds"""

    enable_jargon_detection: bool = True
    """Whether jargon detection is enabled during memory retrieval"""

    global_memory: bool = False
    """Whether memory retrieval may query chat history globally (ignoring the current chat_id; only affects tools such as search_chat_history)"""

    global_memory_blacklist: list[str] = field(default_factory=lambda: [])
    """
    Global-memory blacklist: when global memory is enabled, the listed chat streams are excluded from retrieval
    Format: ["platform:id:type", ...]

    Example:
    [
        "qq:1919810:private",  # exclude a specific private chat
        "qq:114514:group",  # exclude a specific group chat
    ]

    Notes:
    - When global memory is enabled, blacklisted chat streams are not searched
    - Queries issued from a blacklisted chat stream use only that stream's local memory
    """

    planner_question: bool = True
    """
    Whether to use the question provided by the Planner as the memory-retrieval query
    - True: when the Planner supplies a question in the reply action, use it directly and skip the LLM question-generation step
    - False: legacy behaviour, generate the question with an LLM
    """

    def __post_init__(self):
        """Validate configuration values"""
        if self.max_agent_iterations < 1:

@@ -303,10 +329,13 @@ class ExpressionConfig(ConfigBase):
    Format: [["qq:12345:group", "qq:67890:private"]]
    """

    reflect: bool = False
    """Whether expression reflection is enabled"""
    expression_self_reflect: bool = False
    """Whether automatic expression optimization is enabled"""

    reflect_operator_id: str = ""
    expression_manual_reflect: bool = False
    """Whether manual expression optimization is enabled"""

    manual_reflect_operator_id: str = ""
    """Operator ID for expression reflection"""

    allow_reflect: list[str] = field(default_factory=list)

@@ -330,6 +359,34 @@ class ExpressionConfig(ConfigBase):
    - "planner": only use the unknown_words list given by the Planner in the reply action for jargon retrieval
    """

    expression_checked_only: bool = False
    """
    Whether to select only expressions that are reviewed and not rejected
    When true, only expressions with checked=True and rejected=False are eligible
    When false, the legacy filter applies (only expressions with rejected=True are excluded)
    """

    expression_auto_check_interval: int = 3600
    """
    Interval between automatic expression checks, in seconds
    Default: 3600 seconds (1 hour)
    """

    expression_auto_check_count: int = 10
    """
    Number of expressions randomly sampled per automatic check
    Default: 10
    """

    expression_auto_check_custom_criteria: list[str] = field(default_factory=list)
    """
    Additional custom evaluation criteria for the automatic expression check
    Format: ["criterion 1", "criterion 2", "criterion 3", ...]
    These criteria are appended to the evaluation prompt as extra requirements
    Default: empty list
    """

    def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
        """
        Parse a stream config string and generate the corresponding chat_id
@@ -44,18 +44,6 @@ def get_random_dream_styles(count: int = 2) -> List[str]:
    """Randomly pick the given number of styles from the dream-style list"""
    return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))


def get_dream_summary_model() -> LLMRequest:
    """Get the utils model instance used to generate dream summaries"""
    global _dream_summary_model
    if _dream_summary_model is None:
        _dream_summary_model = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="dream.summary",
        )
    return _dream_summary_model


def init_dream_summary_prompt() -> None:
    """Initialize the dream-summary prompt"""
    Prompt(

@@ -186,10 +174,12 @@ async def generate_dream_summary(
    )

    # Call the model to generate the dream
    summary_model = get_dream_summary_model()
    summary_model = LLMRequest(
        model_set=model_config.model_task_config.replyer,
        request_type="dream.summary",
    )
    dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
        dream_prompt,
        max_tokens=512,
        temperature=0.8,
    )
@@ -1,362 +0,0 @@
"""
Memory forgetting task
Runs a forgetting check every 5 minutes and deletes memories according to the forgetting stage
"""

import time
import random
from typing import List

from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.manager.async_task_manager import AsyncTask

logger = get_logger("memory_forget_task")


class MemoryForgetTask(AsyncTask):
    """Memory forgetting task, executed every 5 minutes"""

    def __init__(self):
        # Run every 5 minutes (300 seconds)
        super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)

    async def run(self):
        """Run the forgetting check"""
        try:
            current_time = time.time()
            # logger.info("[记忆遗忘] 开始遗忘检查...")

            # Run the four forgetting stages
            # await self._forget_stage_1(current_time)
            # await self._forget_stage_2(current_time)
            # await self._forget_stage_3(current_time)
            # await self._forget_stage_4(current_time)

            # logger.info("[记忆遗忘] 遗忘检查完成")
        except Exception as e:
            logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)

    async def _forget_stage_1(self, current_time: float):
        """
        First forgetting check:
        Collect all memories that have never been checked (forget_times=0) and are older than 30 minutes
        Delete the top 25% and bottom 25% by count, then mark forget_times as 1
        """
        try:
            # 30 minutes = 1800 seconds
            time_threshold = current_time - 1800

            # Query eligible memories: forget_times=0 and end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")

            # Sort by count
            candidates.sort(key=lambda x: x.count, reverse=True)

            # Number of records to delete (top 25% and bottom 25%)
            total_count = len(candidates)
            delete_count = int(total_count * 0.25)  # 25%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过")
                return

            # Select the records to delete (ties on count are broken randomly)
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # Deduplicate by id to avoid double deletion
            seen_ids = set()
            unique_to_delete = []
            for record in to_delete:
                if record.id not in seen_ids:
                    seen_ids.add(record.id)
                    unique_to_delete.append(record)
            to_delete = unique_to_delete

            # Delete the records and update forget_times
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")

            # Set forget_times of the remaining records to 1
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                # Batch update
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)

    async def _forget_stage_2(self, current_time: float):
        """
        Second forgetting check:
        Collect all memories with forget_times=1 that are older than 8 hours
        Delete the top 7% and bottom 7% by count, then mark forget_times as 2
        """
        try:
            # 8 hours = 28800 seconds
            time_threshold = current_time - 28800

            # Query eligible memories: forget_times=1 and end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")

            # Sort by count
            candidates.sort(key=lambda x: x.count, reverse=True)

            # Number of records to delete (top 7% and bottom 7%)
            total_count = len(candidates)
            delete_count = int(total_count * 0.07)  # 7%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过")
                return

            # Select the records to delete
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # Deduplicate
            to_delete = list(set(to_delete))

            # Delete the records
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")

            # Set forget_times of the remaining records to 2
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)

    async def _forget_stage_3(self, current_time: float):
        """
        Third forgetting check:
        Collect all memories with forget_times=2 that are older than 48 hours
        Delete the top 5% and bottom 5% by count, then mark forget_times as 3
        """
        try:
            # 48 hours = 172800 seconds
            time_threshold = current_time - 172800

            # Query eligible memories: forget_times=2 and end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")

            # Sort by count
            candidates.sort(key=lambda x: x.count, reverse=True)

            # Number of records to delete (top 5% and bottom 5%)
            total_count = len(candidates)
            delete_count = int(total_count * 0.05)  # 5%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过")
                return

            # Select the records to delete
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # Deduplicate
            to_delete = list(set(to_delete))

            # Delete the records
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")

            # Set forget_times of the remaining records to 3
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)

    async def _forget_stage_4(self, current_time: float):
        """
        Fourth forgetting check:
        Collect all memories with forget_times=3 that are older than 7 days
        Delete the top 2% and bottom 2% by count, then mark forget_times as 4
        """
        try:
            # 7 days = 604800 seconds
            time_threshold = current_time - 604800

            # Query eligible memories: forget_times=3 and end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")

            # Sort by count
            candidates.sort(key=lambda x: x.count, reverse=True)

            # Number of records to delete (top 2% and bottom 2%)
            total_count = len(candidates)
            delete_count = int(total_count * 0.02)  # 2%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过")
                return

            # Select the records to delete
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # Deduplicate
            to_delete = list(set(to_delete))

            # Delete the records
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")

            # Set forget_times of the remaining records to 4
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)

    def _handle_same_count_random(
        self, candidates: List[ChatHistory], delete_count: int, mode: str
    ) -> List[ChatHistory]:
        """
        Break ties on count by randomly choosing which records to delete

        Args:
            candidates: candidate records (already sorted by count)
            delete_count: number of records to delete
            mode: "high" selects the records with the highest count, "low" those with the lowest

        Returns:
            the records to delete
        """
        if not candidates or delete_count == 0:
            return []

        to_delete = []

        if mode == "high":
            # Select starting from the highest count
            start_idx = 0
            while start_idx < len(candidates) and len(to_delete) < delete_count:
                # Collect all records sharing the current count
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx < len(candidates) and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx += 1

                # If the tied group fits in the remaining quota, take all of it
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    to_delete.extend(same_count_records)
                else:
                    # Otherwise randomly sample the needed number
                    to_delete.extend(random.sample(same_count_records, needed))

                start_idx = idx

        else:  # mode == "low"
            # Select starting from the lowest count
            start_idx = len(candidates) - 1
            while start_idx >= 0 and len(to_delete) < delete_count:
                # Collect all records sharing the current count
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx >= 0 and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx -= 1

                # If the tied group fits in the remaining quota, take all of it
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    to_delete.extend(same_count_records)
                else:
                    # Otherwise randomly sample the needed number
                    to_delete.extend(random.sample(same_count_records, needed))

                start_idx = idx

        return to_delete
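Each stage of the removed task amounts to trimming both tails of a count-sorted list. A simplified sketch on plain integers (no database, tie randomization omitted, the fraction is a parameter; all names are illustrative):

```python
import random


def trim_tails(counts, fraction):
    """Drop roughly `fraction` of items from each end of the count-sorted list."""
    ordered = sorted(counts, reverse=True)
    k = int(len(ordered) * fraction)
    if k == 0:
        return ordered  # too few items, nothing is forgotten this round
    survivors = ordered[k:len(ordered) - k]
    random.shuffle(survivors)  # order no longer matters after trimming
    return survivors


# With 100 items and fraction 0.25, the top 25 and bottom 25 are dropped.
print(len(trim_tails(list(range(100)), 0.25)))  # → 50
```

The production code additionally randomizes within groups of equal counts so that ties at the cut line are not always resolved in insertion order.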
@ -447,8 +447,20 @@ def _default_normal_response_parser(
|
|||
for call in message_part.tool_calls:
|
||||
try:
|
||||
arguments = json.loads(repair_json(call.function.arguments))
|
||||
# 【新增修复逻辑】如果解析出来还是字符串,说明发生了双重编码,尝试二次解析
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
# 尝试对字符串内容再次进行修复和解析
|
||||
arguments = json.loads(repair_json(arguments))
|
||||
except Exception:
|
||||
# 如果二次解析失败,保留原值,让下方的 isinstance(dict) 抛出更具体的错误
|
||||
pass
|
||||
if not isinstance(arguments, dict):
|
||||
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
|
||||
# 此时为了调试方便,建议打印出 arguments 的类型
|
||||
raise RespParseException(
|
||||
resp,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型 type={type(arguments)} arguments={arguments}"
|
||||
)
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
|
||||
except json.JSONDecodeError as e:
|
||||
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
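The double-encoding guard added in this hunk can be reproduced with only the stdlib `json` module (the project itself routes through `json_repair.repair_json` first; that dependency is omitted here, so this is a simplified sketch of the same recovery path):

```python
import json
from typing import Any

def parse_tool_arguments(raw: str) -> dict:
    """Parse tool-call arguments, retrying once if the payload was JSON-encoded twice."""
    parsed: Any = json.loads(raw)
    if isinstance(parsed, str):
        # Double encoding: the first decode yielded a JSON string, not an object.
        try:
            parsed = json.loads(parsed)
        except json.JSONDecodeError:
            # Keep the string so the type check below raises a precise error.
            pass
    if not isinstance(parsed, dict):
        raise TypeError(f"arguments did not decode to a dict, got {type(parsed).__name__}")
    return parsed
```

Double encoding happens when a model serializes the arguments object to a JSON string and that string is serialized again, so one extra decode is enough to recover it.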
@@ -1,6 +1,7 @@
 import re
 import asyncio
 import time
+import random

 from enum import Enum
 from rich.traceback import install
@@ -266,7 +267,7 @@ class LLMRequest:

     def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
         """
-        根据总tokens和惩罚值选择的模型
+        根据配置的策略选择模型:balance(负载均衡)或 random(随机选择)
         """
         available_models = {
             model: scores
@@ -276,15 +277,30 @@
         if not available_models:
             raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")

-        least_used_model_name = min(
-            available_models,
-            key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
-        )
-        model_info = model_config.get_model_info(least_used_model_name)
+        strategy = self.model_for_task.selection_strategy.lower()
+
+        if strategy == "random":
+            # 随机选择策略
+            selected_model_name = random.choice(list(available_models.keys()))
+        elif strategy == "balance":
+            # 负载均衡策略:根据总tokens和惩罚值选择
+            selected_model_name = min(
+                available_models,
+                key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
+            )
+        else:
+            # 默认使用负载均衡策略
+            logger.warning(f"未知的选择策略 '{strategy}',使用默认的负载均衡策略")
+            selected_model_name = min(
+                available_models,
+                key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
+            )
+
+        model_info = model_config.get_model_info(selected_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
         force_new_client = self.request_type == "embedding"
         client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-        logger.debug(f"选择请求模型: {model_info.name}")
+        logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
         self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
         return model_info, api_provider, client
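A standalone sketch of the selection strategies above, using the same weighted score `tokens + penalty * 300 + usage_penalty * 1000`. The `Usage` tuple shape mirrors the `self.model_usage` values seen in this hunk, but is an assumption here:

```python
import random
from typing import Dict, Tuple

# (total_tokens, penalty, usage_penalty) — assumed shape of the per-model usage record.
Usage = Tuple[int, int, int]

def select_model(usage: Dict[str, Usage], strategy: str = "balance") -> str:
    """Pick a model name: 'random' chooses uniformly; anything else load-balances on the weighted score."""
    if not usage:
        raise RuntimeError("no models available")
    if strategy == "random":
        return random.choice(list(usage))
    # "balance" (and the fallback for unknown strategies): lowest weighted load wins.
    return min(usage, key=lambda k: usage[k][0] + usage[k][1] * 300 + usage[k][2] * 1000)
```

The large multipliers make recent failures (`penalty`) and in-flight usage (`usage_penalty`) dominate raw token volume, so a struggling or busy model is deprioritized even if it has consumed few tokens overall.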
src/main.py
@@ -24,6 +24,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager
 # 导入消息API和traceback模块
 from src.common.message import get_global_api
 from src.dream.dream_agent import start_dream_scheduler
+from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask

 # 插件系统现在使用统一的插件加载器

@@ -87,16 +88,11 @@ class MainSystem:
         # 添加统计信息输出任务
         await async_task_manager.add_task(StatisticOutputTask())

-        # 添加聊天流统计任务(每5分钟生成一次报告,统计最近30天的数据)
-        # await async_task_manager.add_task(TokenStatisticsTask())
-
         # 添加遥测心跳任务
         await async_task_manager.add_task(TelemetryHeartBeatTask())

         # 添加记忆遗忘任务
         from src.hippo_memorizer.memory_forget_task import MemoryForgetTask

         await async_task_manager.add_task(MemoryForgetTask())
+        # 添加表达方式自动检查任务
+        await async_task_manager.add_task(ExpressionAutoCheckTask())

         # 启动API服务器
         # start_api_server()
@@ -15,10 +15,11 @@ from json_repair import repair_json

 from src.common.logger import get_logger
 from src.common.data_models.database_data_model import DatabaseMessages
-from src.config.config import global_config, model_config
+from src.config.config import model_config
 from src.llm_models.utils_model import LLMRequest
 from src.plugin_system.apis import message_api
 from src.chat.utils.chat_message_builder import build_readable_messages
+from src.chat.utils.utils import is_bot_self
 from src.person_info.person_info import Person
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -415,11 +416,11 @@ class ChatHistorySummarizer:
         # 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推)
         # 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言,
         # 说明 bot 没有参与这段对话,不应该记录
-        bot_user_id = str(global_config.bot.qq_account)
         has_bot_message = False

         for msg in messages:
-            if msg.user_info.user_id == bot_user_id:
+            # 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
+            if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
                 has_bot_message = True
                 break

@@ -848,11 +849,7 @@ class ChatHistorySummarizer:
         )

         try:
-            response, _ = await self.summarizer_llm.generate_response_async(
-                prompt=prompt,
-                temperature=0.3,
-                max_tokens=500,
-            )
+            response, _ = await self.summarizer_llm.generate_response_async(prompt=prompt)

             # 解析JSON响应
             json_str = response.strip()
@@ -912,9 +909,12 @@ class ChatHistorySummarizer:
             result = _parse_with_quote_fix(extracted_json)

             keywords = result.get("keywords", [])
-            summary = result.get("summary", "无概括")
+            summary = result.get("summary", "")
             key_point = result.get("key_point", [])

+            if not (keywords and summary) and key_point:
+                logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少字段,原文\n{response}")
+
             # 确保keywords和key_point是列表
             if isinstance(keywords, str):
                 keywords = [keywords]
@@ -2,7 +2,7 @@ import time
 import json
 import asyncio
 import re
-from typing import List, Dict, Any, Optional, Tuple, Set
+from typing import List, Dict, Any, Optional, Tuple
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -11,7 +11,8 @@ from src.common.database.database_model import ThinkingBack
 from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
 from src.memory_system.memory_utils import parse_questions_json
 from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
-from src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
+from src.chat.message_receive.chat_stream import get_chat_manager
+from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon

 logger = get_logger("memory_retrieval")

@@ -100,6 +101,7 @@ def init_memory_retrieval_prompt():
 **工具说明:**
 - 如果涉及过往事件,或者查询某个过去可能提到过的概念,或者某段时间发生的事件。可以使用聊天记录查询工具查询过往事件
 - 如果涉及人物,可以使用人物信息查询工具查询人物信息
+- 如果遇到不熟悉的词语、缩写、黑话或网络用语,可以使用query_words工具查询其含义
 - 如果没有可靠信息,且查询时间充足,或者不确定查询类别,也可以使用lpmm知识库查询,作为辅助信息

 **思考**
@@ -202,7 +204,6 @@
     max_iterations: int = 5,
     timeout: float = 30.0,
     initial_info: str = "",
-    initial_jargon_concepts: Optional[List[str]] = None,
 ) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
     """使用ReAct架构的Agent来解决问题

@@ -211,28 +212,29 @@
         chat_id: 聊天ID
         max_iterations: 最大迭代次数
         timeout: 超时时间(秒)
-        initial_info: 初始信息(如概念检索结果),将作为collected_info的初始值
-        initial_jargon_concepts: 预先已解析过的黑话列表,避免重复解释
+        initial_info: 初始信息,将作为collected_info的初始值

     Returns:
         Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
     """
     start_time = time.time()
     collected_info = initial_info if initial_info else ""
-    enable_jargon_detection = global_config.memory.enable_jargon_detection
-    seen_jargon_concepts: Set[str] = set()
-    if enable_jargon_detection and initial_jargon_concepts:
-        for concept in initial_jargon_concepts:
-            concept = (concept or "").strip()
-            if concept:
-                seen_jargon_concepts.add(concept)
+    # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
+    try:
+        chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
+    except Exception:
+        chat_name = chat_id
+    react_log_prefix = f"[{chat_name}] "
     thinking_steps = []
     is_timeout = False
     conversation_messages: List[Message] = []
     first_head_prompt: Optional[str] = None  # 保存第一次使用的head_prompt(用于日志显示)
+    last_tool_name: Optional[str] = None  # 记录最后一次使用的工具名称

-    # 正常迭代:max_iterations 次(最终评估单独处理,不算在迭代中)
-    for iteration in range(max_iterations):
+    # 使用 while 循环,支持额外迭代
+    iteration = 0
+    max_iterations_with_extra = max_iterations
+    while iteration < max_iterations_with_extra:
         # 检查超时
         if time.time() - start_time > timeout:
             logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
@@ -475,7 +477,7 @@
                         step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
                         thinking_steps.append(step)
                         logger.info(
-                            f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
+                            f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
                         )

                         _log_conversation_messages(
@@ -488,7 +490,7 @@
                     else:
                         # found_answer为True但没有提供answer,视为错误,继续迭代
                         logger.warning(
-                            f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
+                            f"{react_log_prefix}第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
                         )
                 else:
                     # 未找到答案,直接退出查询
@@ -497,7 +499,9 @@
                     )
                     step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
                     thinking_steps.append(step)
-                    logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
+                    logger.info(
+                        f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案"
+                    )

                     _log_conversation_messages(
                         conversation_messages,
@@ -509,12 +513,15 @@

             # 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代
             step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
-            logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
+            logger.info(
+                f"{react_log_prefix}第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}"
+            )
             collected_info += f"思考: {response}"
         else:
-            logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 无工具调用且无响应")
+            logger.warning(f"{react_log_prefix}第 {iteration + 1} 次迭代 无工具调用且无响应")
             step["observations"] = ["无响应且无工具调用"]
             thinking_steps.append(step)
+            iteration += 1  # 在continue之前增加迭代计数,避免跳过iteration += 1
             continue

         # 处理工具调用
@@ -541,7 +548,7 @@
                         step["observations"] = ["检测到finish_search工具调用,找到答案"]
                         thinking_steps.append(step)
                         logger.info(
-                            f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
+                            f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
                         )

                         _log_conversation_messages(
@@ -554,14 +561,16 @@
                 else:
                     # found_answer为True但没有提供answer,视为错误
                     logger.warning(
-                        f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
+                        f"{react_log_prefix}第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
                     )
             else:
                 # 未找到答案,直接退出查询
                 step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
                 step["observations"] = ["检测到finish_search工具调用,未找到答案"]
                 thinking_steps.append(step)
-                logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
+                logger.info(
+                    f"{react_log_prefix}第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案"
+                )

                 _log_conversation_messages(
                     conversation_messages,
@@ -578,13 +587,16 @@
             tool_args = tool_call.args or {}

             logger.debug(
-                f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
+                f"{react_log_prefix}第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
             )

             # 跳过finish_search工具调用(已经在上面处理过了)
             if tool_name == "finish_search":
                 continue

+            # 记录最后一次使用的工具名称(用于判断是否需要额外迭代)
+            last_tool_name = tool_name
+
             # 普通工具调用
             tool = tool_registry.get_tool(tool_name)
             if tool:
@@ -604,14 +616,18 @@
                         return f"查询{tool_name_str}({param_str})的结果:{observation}"
                     except Exception as e:
                         error_msg = f"工具执行失败: {str(e)}"
-                        logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}")
+                        logger.error(
+                            f"{react_log_prefix}第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}"
+                        )
                         return f"查询{tool_name_str}失败: {error_msg}"

                 tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration))
                 step["actions"].append({"action_type": tool_name, "action_params": tool_args})
             else:
                 error_msg = f"未知的工具类型: {tool_name}"
-                logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
+                logger.warning(
+                    f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}"
+                )
                 tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))

         # 并行执行所有工具
@@ -622,31 +638,16 @@
         for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
             if isinstance(observation, Exception):
                 observation = f"工具执行异常: {str(observation)}"
-                logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
+                logger.error(
+                    f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}"
+                )

             observation_text = observation if isinstance(observation, str) else str(observation)
             stripped_observation = observation_text.strip()
             step["observations"].append(observation_text)
             collected_info += f"\n{observation_text}\n"
-            if stripped_observation:
-                # 检查工具输出中是否有新的jargon,如果有则追加到工具结果中
-                if enable_jargon_detection:
-                    jargon_concepts = match_jargon_from_text(stripped_observation, chat_id)
-                    if jargon_concepts:
-                        new_concepts = []
-                        for concept in jargon_concepts:
-                            normalized_concept = concept.strip()
-                            if normalized_concept and normalized_concept not in seen_jargon_concepts:
-                                new_concepts.append(normalized_concept)
-                                seen_jargon_concepts.add(normalized_concept)
-                        if new_concepts:
-                            jargon_info = await retrieve_concepts_with_jargon(new_concepts, chat_id)
-                            if jargon_info:
-                                # 将jargon查询结果追加到工具结果中
-                                observation_text += f"\n\n{jargon_info}"
-                                collected_info += f"\n{jargon_info}\n"
-                                logger.info(f"工具输出触发黑话解析: {new_concepts}")

+            # 不再自动检测工具输出中的jargon,改为通过 query_words 工具主动查询
             tool_builder = MessageBuilder()
             tool_builder.set_role(RoleType.Tool)
             tool_builder.add_text_content(observation_text)
@@ -655,15 +656,24 @@

         thinking_steps.append(step)

+        # 检查是否需要额外迭代:如果最后一次使用的工具是 search_chat_history 且达到最大迭代次数,额外增加一回合
+        if iteration + 1 >= max_iterations and last_tool_name == "search_chat_history" and not is_timeout:
+            max_iterations_with_extra = max_iterations + 1
+            logger.info(
+                f"{react_log_prefix}达到最大迭代次数(已迭代{iteration + 1}次),最后一次使用工具为 search_chat_history,额外增加一回合尝试"
+            )
+
+        iteration += 1
+
     # 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估
     # 最终评估单独处理,不算在迭代中
     should_do_final_evaluation = False
     if is_timeout:
         should_do_final_evaluation = True
-        logger.warning(f"ReAct Agent超时,已迭代{iteration + 1}次,进入最终评估")
-    elif iteration + 1 >= max_iterations:
+        logger.warning(f"{react_log_prefix}超时,已迭代{iteration}次,进入最终评估")
+    elif iteration >= max_iterations:
         should_do_final_evaluation = True
-        logger.info(f"ReAct Agent达到最大迭代次数(已迭代{iteration + 1}次),进入最终评估")
+        logger.info(f"{react_log_prefix}达到最大迭代次数(已迭代{iteration}次),进入最终评估")

     if should_do_final_evaluation:
         # 获取必要变量用于最终评估
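The bonus-iteration rule introduced here (one extra round when the loop would end right after a `search_chat_history` call) can be sketched as a plain loop. `tools_used` is a hypothetical stand-in for the tool the agent happens to pick each round; the real code derives it from LLM tool calls:

```python
from typing import List, Optional

def run_agent_loop(tools_used: List[Optional[str]], max_iterations: int = 5) -> List[Optional[str]]:
    """Iterate up to max_iterations, granting one extra round if the final round used search_chat_history."""
    iteration = 0
    limit = max_iterations
    rounds: List[Optional[str]] = []
    while iteration < limit:
        tool = tools_used[iteration] if iteration < len(tools_used) else None
        rounds.append(tool)
        # On the last scheduled round: if the agent just fetched chat history,
        # grant exactly one bonus round so it can act on what it retrieved.
        if iteration + 1 >= max_iterations and tool == "search_chat_history" and limit == max_iterations:
            limit = max_iterations + 1
        iteration += 1
    return rounds
```

The `limit == max_iterations` guard mirrors the original's one-shot grant: the extension can fire at most once, so the loop is still strictly bounded.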
@@ -766,8 +776,8 @@
             return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout

         if global_config.debug.show_memory_prompt:
-            logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
-            logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
+            logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}")
+            logger.info(f"{react_log_prefix}最终评估响应: {eval_response}")

         # 从最终评估响应中提取found_answer或not_enough_info
         found_answer_content = None
@@ -998,7 +1008,6 @@
     chat_id: str,
     context: str,
    initial_info: str = "",
-    initial_jargon_concepts: Optional[List[str]] = None,
     max_iterations: Optional[int] = None,
 ) -> Optional[str]:
     """处理单个问题的查询

@@ -1007,12 +1016,17 @@
         question: 要查询的问题
         chat_id: 聊天ID
         context: 上下文信息
-        initial_info: 初始信息(如概念检索结果),将传递给ReAct Agent
-        initial_jargon_concepts: 已经处理过的黑话概念列表,用于ReAct阶段的去重
+        initial_info: 初始信息,将传递给ReAct Agent
         max_iterations: 最大迭代次数

     Returns:
         Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None
     """
+    # 如果question为空或None,直接返回None,不进行查询
+    if not question or not question.strip():
+        logger.debug("问题为空,跳过查询")
+        return None
+
     # logger.info(f"开始处理问题: {question}")

     _cleanup_stale_not_found_thinking_back()
@@ -1022,8 +1036,6 @@
     # 直接使用ReAct Agent查询(不再从thinking_back获取缓存)
     # logger.info(f"使用ReAct Agent查询,问题: {question[:50]}...")

-    jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None
-
     # 如果未指定max_iterations,使用配置的默认值
     if max_iterations is None:
         max_iterations = global_config.memory.max_agent_iterations
@@ -1034,7 +1046,6 @@
         max_iterations=max_iterations,
         timeout=global_config.memory.agent_timeout_seconds,
         initial_info=question_initial_info,
-        initial_jargon_concepts=jargon_concepts_for_agent,
     )

     # 存储查询历史到数据库(超时时不存储)
@@ -1062,6 +1073,8 @@
     target: str,
     chat_stream,
     think_level: int = 1,
+    unknown_words: Optional[List[str]] = None,
+    question: Optional[str] = None,
 ) -> str:
     """构建记忆检索提示
     使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案
@@ -1071,14 +1084,33 @@
         sender: 发送者名称
         target: 目标消息内容
         chat_stream: 聊天流对象
         tool_executor: 工具执行器(保留参数以兼容接口)
         think_level: 思考深度等级
+        unknown_words: Planner 提供的未知词语列表,优先使用此列表而不是从聊天记录匹配
+        question: Planner 提供的问题,当 planner_question 配置开启时,直接使用此问题进行检索

     Returns:
         str: 记忆检索结果字符串
     """
     start_time = time.time()

-    logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
+    # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流(优先群名称/用户昵称)
+    try:
+        group_info = chat_stream.group_info
+        user_info = chat_stream.user_info
+        # 群聊优先使用群名称
+        if group_info is not None and getattr(group_info, "group_name", None):
+            stream_name = group_info.group_name.strip() or str(group_info.group_id)
+        # 私聊使用用户昵称
+        elif user_info is not None and getattr(user_info, "user_nickname", None):
+            stream_name = user_info.user_nickname.strip() or str(user_info.user_id)
+        # 兜底使用 stream_id
+        else:
+            stream_name = chat_stream.stream_id
+    except Exception:
+        stream_name = chat_stream.stream_id
+    log_prefix = f"[{stream_name}] " if stream_name else ""
+
+    logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
     try:
         time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
         bot_name = global_config.bot.nickname
@@ -1089,66 +1121,81 @@
         if not recent_query_history:
             recent_query_history = "最近没有查询记录。"

-        # 第一步:生成问题
-        question_prompt = await global_prompt_manager.format_prompt(
-            "memory_retrieval_question_prompt",
-            bot_name=bot_name,
-            time_now=time_now,
-            chat_history=message,
-            recent_query_history=recent_query_history,
-            sender=sender,
-            target_message=target,
-        )
+        # 第一步:生成问题或使用 Planner 提供的问题
+        single_question: Optional[str] = None

-        success, response, reasoning_content, model_name = await llm_api.generate_with_model(
-            question_prompt,
-            model_config=model_config.model_task_config.tool_use,
-            request_type="memory.question",
-        )
-
-        if global_config.debug.show_memory_prompt:
-            logger.info(f"记忆检索问题生成提示词: {question_prompt}")
-            # logger.info(f"记忆检索问题生成响应: {response}")
-
-        if not success:
-            logger.error(f"LLM生成问题失败: {response}")
-            return ""
-
-        # 解析概念列表和问题列表
-        _, questions = parse_questions_json(response)
-        if questions:
-            logger.info(f"解析到 {len(questions)} 个问题: {questions}")
-
-        enable_jargon_detection = global_config.memory.enable_jargon_detection
-        concepts: List[str] = []
-
-        if enable_jargon_detection:
-            # 使用匹配逻辑自动识别聊天中的黑话概念
-            concepts = match_jargon_from_text(message, chat_id)
-            if concepts:
-                logger.info(f"黑话匹配命中 {len(concepts)} 个概念: {concepts}")
+        # 如果 planner_question 配置开启,只使用 Planner 提供的问题,不使用旧模式
+        if global_config.memory.planner_question:
+            if question and isinstance(question, str) and question.strip():
+                # 清理和验证 question
+                single_question = question.strip()
+                logger.info(f"{log_prefix}使用 Planner 提供的 question: {single_question}")
             else:
-                logger.debug("黑话匹配未命中任何概念")
+                # planner_question 开启但没有提供 question,跳过记忆检索
+                logger.debug(f"{log_prefix}planner_question 已开启但未提供 question,跳过记忆检索")
+                end_time = time.time()
+                logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}秒")
+                return ""
         else:
-            logger.debug("已禁用记忆检索中的黑话识别")
+            # planner_question 关闭,使用旧模式:LLM 生成问题
+            question_prompt = await global_prompt_manager.format_prompt(
+                "memory_retrieval_question_prompt",
+                bot_name=bot_name,
+                time_now=time_now,
+                chat_history=message,
+                recent_query_history=recent_query_history,
+                sender=sender,
+                target_message=target,
+            )

-        # 对匹配到的概念进行jargon检索,作为初始信息
+            success, response, reasoning_content, model_name = await llm_api.generate_with_model(
+                question_prompt,
+                model_config=model_config.model_task_config.tool_use,
+                request_type="memory.question",
+            )
+
+            if global_config.debug.show_memory_prompt:
+                logger.info(f"{log_prefix}记忆检索问题生成提示词: {question_prompt}")
+                # logger.info(f"记忆检索问题生成响应: {response}")
+
+            if not success:
+                logger.error(f"{log_prefix}LLM生成问题失败: {response}")
+                return ""
+
+            # 解析概念列表和问题列表,只取第一个问题
+            _, questions = parse_questions_json(response)
+            if questions and len(questions) > 0:
+                single_question = questions[0].strip()
+                logger.info(f"{log_prefix}解析到问题: {single_question}")

+        # 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供)
         initial_info = ""
-        if enable_jargon_detection and concepts:
-            concept_info = await retrieve_concepts_with_jargon(concepts, chat_id)
-            if concept_info:
-                initial_info += concept_info
-                logger.debug(f"概念检索完成,结果: {concept_info}")
-            else:
-                logger.debug("概念检索未找到任何结果")
+        if unknown_words and len(unknown_words) > 0:
+            # 清理和去重 unknown_words
+            cleaned_concepts = []
+            for word in unknown_words:
+                if isinstance(word, str):
+                    cleaned = word.strip()
+                    if cleaned:
+                        cleaned_concepts.append(cleaned)
+            if cleaned_concepts:
+                # 对匹配到的概念进行jargon检索,作为初始信息
+                concept_info = await retrieve_concepts_with_jargon(cleaned_concepts, chat_id)
+                if concept_info:
+                    initial_info += concept_info
+                    logger.info(
+                        f"{log_prefix}使用 Planner 提供的 unknown_words,共 {len(cleaned_concepts)} 个概念,检索结果: {concept_info[:100]}..."
+                    )
+                else:
+                    logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果")

-        if not questions:
-            logger.debug("模型认为不需要检索记忆或解析失败,不返回任何查询结果")
+        if not single_question:
+            logger.debug(f"{log_prefix}模型认为不需要检索记忆或解析失败,不返回任何查询结果")
             end_time = time.time()
-            logger.info(f"无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}秒")
+            logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}秒")
             return ""

-        # 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间)
+        # 第二步:处理问题(使用配置的最大迭代次数和超时时间)
         base_max_iterations = global_config.memory.max_agent_iterations
         # 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半
         if think_level == 0:
@@ -1157,32 +1204,21 @@
             max_iterations = base_max_iterations
         timeout_seconds = global_config.memory.agent_timeout_seconds
         logger.debug(
-            f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
+            f"{log_prefix}问题: {single_question},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
         )

-        # 并行处理所有问题,将概念检索结果作为初始信息传递
-        question_tasks = [
-            _process_single_question(
-                question=question,
+        # 处理单个问题
+        try:
+            result = await _process_single_question(
+                question=single_question,
                 chat_id=chat_id,
                 context=message,
                 initial_info=initial_info,
-                initial_jargon_concepts=concepts if enable_jargon_detection else None,
                 max_iterations=max_iterations,
             )
-            for question in questions
-        ]
-
-        # 并行执行所有查询任务
-        results = await asyncio.gather(*question_tasks, return_exceptions=True)
-
-        # 收集所有有效结果
-        question_results: List[str] = []
-        for i, result in enumerate(results):
-            if isinstance(result, Exception):
-                logger.error(f"处理问题 '{questions[i]}' 时发生异常: {result}")
-            elif result is not None:
-                question_results.append(result)
+        except Exception as e:
+            logger.error(f"{log_prefix}处理问题 '{single_question}' 时发生异常: {e}")
+            result = None

         # 获取最近10分钟内已找到答案的缓存记录
         cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
@@ -1191,39 +1227,39 @@
         all_results = []

         # 先添加当前查询的结果
-        current_questions = set()
-        for result in question_results:
-            if result:
-                all_results.append(result)
-                # 提取问题(格式为 "问题:xxx\n答案:xxx")
-                if result.startswith("问题:"):
-                    question_end = result.find("\n答案:")
-                    if question_end != -1:
-                        current_questions.add(result[4:question_end])
+        current_question = None
+        if result:
+            all_results.append(result)
+            # 提取问题(格式为 "问题:xxx\n答案:xxx")
+            if result.startswith("问题:"):
+                question_end = result.find("\n答案:")
+                if question_end != -1:
+                    current_question = result[4:question_end]

-        # 添加缓存答案(排除当前查询中已存在的问题)
+        # 添加缓存答案(排除当前查询的问题)
         for cached_answer in cached_answers:
             if cached_answer.startswith("问题:"):
                 question_end = cached_answer.find("\n答案:")
                 if question_end != -1:
                     cached_question = cached_answer[4:question_end]
-                    if cached_question not in current_questions:
+                    if cached_question != current_question:
                         all_results.append(cached_answer)

         end_time = time.time()

         if all_results:
             retrieved_memory = "\n\n".join(all_results)
-            current_count = len(question_results)
+            current_count = 1 if result else 0
             cached_count = len(all_results) - current_count
             logger.info(
-                f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,"
+                f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}秒,"
                 f"当前查询 {current_count} 条记忆,缓存 {cached_count} 条记忆,共 {len(all_results)} 条记忆"
             )
             return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
         else:
-            logger.debug("所有问题均未找到答案,且无缓存答案")
+            logger.debug(f"{log_prefix}问题未找到答案,且无缓存答案")
             return ""

     except Exception as e:
-        logger.error(f"记忆检索时发生异常: {str(e)}")
+        logger.error(f"{log_prefix}记忆检索时发生异常: {str(e)}")
         return ""
@@ -14,6 +14,7 @@ from .tool_registry import (
 from .query_chat_history import register_tool as register_query_chat_history
 from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
 from .query_person_info import register_tool as register_query_person_info
+from .query_words import register_tool as register_query_words
 from .found_answer import register_tool as register_finish_search
 from src.config.config import global_config

@@ -22,6 +23,7 @@ def init_all_tools():
     """初始化并注册所有记忆检索工具"""
     register_query_chat_history()
     register_query_person_info()
+    register_query_words()  # 注册query_words工具
     register_finish_search()  # 注册finish_search工具

     if global_config.lpmm_knowledge.lpmm_mode == "agent":
@@ -4,7 +4,7 @@
 """

 import json
-from typing import Optional
+from typing import Optional, Set
 from datetime import datetime

 from src.common.logger import get_logger
@@ -16,36 +16,183 @@ from .tool_registry import register_memory_retrieval_tool
 logger = get_logger("memory_retrieval_tools")


-async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
+def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
+    """将黑名单配置(platform:id:type格式)转换为chat_id集合
+
+    Args:
+        blacklist: 黑名单配置列表,格式为 ["platform:id:type", ...]
+
+    Returns:
+        Set[str]: chat_id集合
+    """
+    chat_ids = set()
+    if not blacklist:
+        return chat_ids
+
+    try:
+        from src.chat.message_receive.chat_stream import get_chat_manager
+
+        chat_manager = get_chat_manager()
+        for blacklist_item in blacklist:
+            if not isinstance(blacklist_item, str):
+                continue
+
+            try:
+                parts = blacklist_item.split(":")
+                if len(parts) != 3:
+                    logger.warning(f"黑名单配置格式错误,应为 platform:id:type,实际: {blacklist_item}")
+                    continue
+
+                platform = parts[0]
+                id_str = parts[1]
+                stream_type = parts[2]
+
+                # 判断是否为群聊
+                is_group = stream_type == "group"
+
+                # 转换为chat_id
+                chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group)
+                if chat_id:
+                    chat_ids.add(chat_id)
+                else:
+                    logger.warning(f"无法将黑名单配置转换为chat_id: {blacklist_item}")
+            except Exception as e:
+                logger.warning(f"解析黑名单配置失败: {blacklist_item}, 错误: {e}")
+
+    except Exception as e:
+        logger.error(f"初始化黑名单chat_id集合失败: {e}")
+
+    return chat_ids
+
+
+def _is_chat_id_in_blacklist(chat_id: str) -> bool:
+    """检查chat_id是否在全局记忆黑名单中
+
+    Args:
+        chat_id: 要检查的chat_id
+
+    Returns:
+        bool: 如果chat_id在黑名单中返回True,否则返回False
+    """
+    blacklist = getattr(global_config.memory, "global_memory_blacklist", [])
+    if not blacklist:
+        return False
+
+    blacklist_chat_ids = _parse_blacklist_to_chat_ids(blacklist)
+    return chat_id in blacklist_chat_ids
|
||||
|
||||
|
||||
async def search_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
participant: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
) -> str:
|
||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配)
|
||||
participant: 参与人昵称(可选)
|
||||
start_time: 开始时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供start_time,查询该时间点之后的记录
|
||||
end_time: 结束时间(可选,格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01')。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含记忆id、theme和keywords
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not participant:
|
||||
return "未指定查询参数(需要提供keyword或participant之一)"
|
||||
if not keyword and not participant and not start_time and not end_time:
|
||||
return "未指定查询参数(需要提供keyword、participant、start_time或end_time之一)"
|
||||
|
||||
# 解析时间参数
|
||||
start_timestamp = None
|
||||
end_timestamp = None
|
||||
|
||||
if start_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
start_timestamp = parse_datetime_to_timestamp(start_time)
|
||||
except ValueError as e:
|
||||
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
from src.memory_system.memory_utils import parse_datetime_to_timestamp
|
||||
end_timestamp = parse_datetime_to_timestamp(end_time)
|
||||
except ValueError as e:
|
||||
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'"
|
||||
|
||||
# 验证时间范围
|
||||
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
|
||||
return "开始时间不能晚于结束时间"
|
||||
|
||||
# 构建查询条件
|
||||
# 检查当前chat_id是否在黑名单中
|
||||
is_current_chat_in_blacklist = _is_chat_id_in_blacklist(chat_id)
|
||||
|
||||
# 根据配置决定是否限制在当前 chat_id 内查询
|
||||
use_global_search = global_config.memory.global_memory
|
||||
# 如果当前chat_id在黑名单中,强制使用本地查询
|
||||
use_global_search = global_config.memory.global_memory and not is_current_chat_in_blacklist
|
||||
|
||||
if use_global_search:
|
||||
# 全局查询所有聊天记录
|
||||
query = ChatHistory.select()
|
||||
logger.debug(
|
||||
f"search_chat_history 启用全局查询模式,忽略 chat_id 过滤,keyword={keyword}, participant={participant}"
|
||||
)
|
||||
# 全局查询所有聊天记录,但排除黑名单中的聊天流
|
||||
blacklist_chat_ids = _parse_blacklist_to_chat_ids(global_config.memory.global_memory_blacklist)
|
||||
if blacklist_chat_ids:
|
||||
# 排除黑名单中的chat_id
|
||||
query = ChatHistory.select().where(~(ChatHistory.chat_id.in_(blacklist_chat_ids)))
|
||||
logger.debug(
|
||||
f"search_chat_history 启用全局查询模式(排除黑名单 {len(blacklist_chat_ids)} 个聊天流),keyword={keyword}, participant={participant}"
|
||||
)
|
||||
else:
|
||||
# 没有黑名单,查询所有
|
||||
query = ChatHistory.select()
|
||||
logger.debug(
|
||||
f"search_chat_history 启用全局查询模式,忽略 chat_id 过滤,keyword={keyword}, participant={participant}"
|
||||
)
|
||||
else:
|
||||
# 仅在当前聊天流内查询
|
||||
if is_current_chat_in_blacklist:
|
||||
logger.debug(
|
||||
f"search_chat_history 当前聊天流在黑名单中,强制使用本地查询,chat_id={chat_id}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 添加时间过滤条件
|
||||
if start_timestamp is not None and end_timestamp is not None:
|
||||
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
|
||||
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
|
||||
query = query.where(
|
||||
(
|
||||
(ChatHistory.start_time >= start_timestamp)
|
||||
& (ChatHistory.start_time <= end_timestamp)
|
||||
) # 记录开始时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.end_time >= start_timestamp)
|
||||
& (ChatHistory.end_time <= end_timestamp)
|
||||
) # 记录结束时间在查询时间段内
|
||||
| (
|
||||
(ChatHistory.start_time <= start_timestamp)
|
||||
& (ChatHistory.end_time >= end_timestamp)
|
||||
) # 记录完全包含查询时间段
|
||||
)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加时间范围过滤: {start_timestamp} - {end_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
elif start_timestamp is not None:
|
||||
# 只提供开始时间,查询该时间点之后的记录(记录的开始时间或结束时间在该时间点之后)
|
||||
query = query.where(ChatHistory.end_time >= start_timestamp)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加开始时间过滤: >= {start_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
elif end_timestamp is not None:
|
||||
# 只提供结束时间,查询该时间点之前的记录(记录的开始时间或结束时间在该时间点之前)
|
||||
query = query.where(ChatHistory.start_time <= end_timestamp)
|
||||
logger.debug(
|
||||
f"search_chat_history 添加结束时间过滤: <= {end_timestamp}, keyword={keyword}, participant={participant}"
|
||||
)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
|
|
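The three OR'd clauses in the time filter above reduce to the standard interval-intersection test: two closed ranges [a, b] and [c, d] overlap exactly when a <= d and c <= b. A minimal sketch in plain Python (independent of the ORM used in the diff), with the three-clause form alongside for comparison:

```python
def ranges_overlap(rec_start: float, rec_end: float,
                   q_start: float, q_end: float) -> bool:
    """True when [rec_start, rec_end] intersects [q_start, q_end]."""
    return rec_start <= q_end and q_start <= rec_end


def ranges_overlap_three_clause(rs: float, re_: float,
                                qs: float, qe: float) -> bool:
    """The three OR'd conditions from the diff, for well-formed intervals."""
    starts_inside = qs <= rs <= qe        # record starts inside the query range
    ends_inside = qs <= re_ <= qe         # record ends inside the query range
    contains = rs <= qs and re_ >= qe     # record fully contains the query range
    return starts_inside or ends_inside or contains
```

For well-formed intervals (start <= end on both sides) the two predicates agree, so the simpler form could replace the three-clause query condition.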
@@ -134,21 +281,31 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
                    filtered_records.append(record)

        if not filtered_records:
            if keyword and participant:
                keywords_str = "、".join(parse_keywords_string(keyword) if keyword else [])
                return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
            elif keyword:
            # Build a description of the query conditions
            conditions = []
            if keyword:
                keywords_str = "、".join(parse_keywords_string(keyword))
                keywords_list = parse_keywords_string(keyword)
                if len(keywords_list) > 2:
                    required_count = len(keywords_list) - 1
                    return (
                        f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
                    )
                else:
                    return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
            elif participant:
                return f"未找到参与人包含'{participant}'的聊天记录"
                conditions.append(f"关键词'{keywords_str}'")
            if participant:
                conditions.append(f"参与人'{participant}'")
            if start_timestamp or end_timestamp:
                time_desc = ""
                if start_timestamp and end_timestamp:
                    start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
                    end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
                    time_desc = f"时间范围'{start_str}' 至 '{end_str}'"
                elif start_timestamp:
                    start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
                    time_desc = f"时间>='{start_str}'"
                elif end_timestamp:
                    end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
                    time_desc = f"时间<='{end_str}'"
                if time_desc:
                    conditions.append(time_desc)

            if conditions:
                conditions_str = "且".join(conditions)
                return f"未找到满足条件({conditions_str})的聊天记录"
            else:
                return "未找到相关聊天记录"
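The "not found" messages above mirror the matching rule stated in the tool's docstring: with at most two keywords, all must match; with more, one miss is tolerated (n-1 matching). A hedged sketch of that rule — the diff's `parse_keywords_string` helper is assumed, so a simple regex splitter stands in for it here:

```python
import re


def parse_keywords(raw: str) -> list[str]:
    # stand-in for parse_keywords_string: split on commas (ASCII or
    # full-width) and whitespace, dropping empty fragments
    return [w for w in re.split(r"[,,\s]+", raw) if w]


def keywords_match(text: str, raw_keywords: str) -> bool:
    """Apply the n-1 tolerance rule described in the tool docstring."""
    keywords = parse_keywords(raw_keywords)
    if not keywords:
        return True
    hits = sum(1 for kw in keywords if kw in text)
    if len(keywords) <= 2:
        return hits == len(keywords)       # <=2 keywords: all must match
    return hits >= len(keywords) - 1       # >2 keywords: tolerate one miss
```

So a record is still returned when, say, two of three supplied keywords appear in it, which is what the "至少n-1个关键词" fallback message reports.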
@@ -336,7 +493,7 @@ def register_tool():
    # Register tool 1: search memories
    register_memory_retrieval_tool(
        name="search_chat_history",
        description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。",
        description="根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则:如果关键词数量<=2,必须全部匹配;如果关键词数量>2,允许n-1个关键词匹配(容错匹配)。支持按时间点或时间段进行查询。",
        parameters=[
            {
                "name": "keyword",

@@ -350,6 +507,18 @@ def register_tool():
                "description": "参与人昵称(可选),用于查询包含该参与人的记忆",
                "required": False,
            },
            {
                "name": "start_time",
                "type": "string",
                "description": "开始时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供start_time,查询该时间点之后的记录。如果同时提供start_time和end_time,查询该时间段内的记录",
                "required": False,
            },
            {
                "name": "end_time",
                "type": "string",
                "description": "结束时间(可选),格式如:'2025-01-01' 或 '2025-01-01 12:00:00' 或 '2025/01/01'。如果只提供end_time,查询该时间点之前的记录。如果同时提供start_time和end_time,查询该时间段内的记录",
                "required": False,
            },
        ],
        execute_func=search_chat_history,
    )
@@ -0,0 +1,79 @@
"""
Query jargon/concept meanings - tool implementation
Used during memory retrieval to proactively look up the meaning of unknown words or slang
"""

from src.common.logger import get_logger
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool

logger = get_logger("memory_retrieval_tools")


async def query_words(chat_id: str, words: str) -> str:
    """Look up the meaning of words or jargon

    Args:
        chat_id: chat ID
        words: the word(s) to look up; a single word or several, separated by commas, spaces, etc.

    Returns:
        str: the query result, containing explanations of the words
    """
    try:
        if not words or not words.strip():
            return "未提供要查询的词语"

        # Parse the word list (commas, spaces, etc. act as separators)
        words_list = []
        for separator in [",", ",", " ", "\n", "\t"]:
            if separator in words:
                words_list = [w.strip() for w in words.split(separator) if w.strip()]
                break

        # No separator found: treat the whole string as one word
        if not words_list:
            words_list = [words.strip()]

        # De-duplicate
        unique_words = []
        seen = set()
        for word in words_list:
            if word and word not in seen:
                unique_words.append(word)
                seen.add(word)

        if not unique_words:
            return "未提供有效的词语"

        logger.info(f"查询词语含义: {unique_words}")

        # Call the retrieval function
        result = await retrieve_concepts_with_jargon(unique_words, chat_id)

        if result:
            return result
        else:
            return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"

    except Exception as e:
        logger.error(f"查询词语含义失败: {e}")
        return f"查询失败: {str(e)}"


def register_tool():
    """Register the tool"""
    register_memory_retrieval_tool(
        name="query_words",
        description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
        parameters=[
            {
                "name": "words",
                "type": "string",
                "description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')",
                "required": True,
            },
        ],
        execute_func=query_words,
    )
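`query_words` above splits on the first separator it finds (full-width comma, ASCII comma, space, newline, tab, in that priority order) and then de-duplicates while preserving first-seen order. The parsing, extracted into a standalone sketch:

```python
def parse_word_list(words: str) -> list[str]:
    """Split a raw word string on the first separator found, then dedupe."""
    words = words.strip()
    if not words:
        return []
    # split on the first separator present, in priority order
    words_list = []
    for separator in [",", ",", " ", "\n", "\t"]:
        if separator in words:
            words_list = [w.strip() for w in words.split(separator) if w.strip()]
            break
    # no separator found: the whole string is a single word
    if not words_list:
        words_list = [words]
    # order-preserving de-duplication
    seen = set()
    unique = []
    for word in words_list:
        if word not in seen:
            unique.append(word)
            seen.add(word)
    return unique
```

Note one consequence of the `break`: only the highest-priority separator present is used, so an input mixing commas and spaces is split on the commas and any spaces survive inside the fragments (they are then trimmed only at fragment edges).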
@@ -19,7 +19,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("person_info")

relation_selection_model = LLMRequest(
    model_set=model_config.model_task_config.utils_small, request_type="relation_selection"
    model_set=model_config.model_task_config.tool_use, request_type="relation_selection"
)
@@ -228,8 +228,59 @@ class Person:

        return person

    def _is_bot_self(self, platform: str, user_id: str) -> bool:
        """Decide whether the given platform and user ID belong to the bot itself

        This unifies bot identification across all platforms (QQ, Telegram, WebUI, etc.).

        Args:
            platform: message platform (e.g. "qq", "telegram", "webui")
            user_id: user ID

        Returns:
            bool: True if this is the bot itself, otherwise False
        """
        if not platform or not user_id:
            return False

        # Compare user_id as a string
        user_id_str = str(user_id)

        # The bot's QQ account (primary account)
        qq_account = str(global_config.bot.qq_account or "")

        # QQ platform: compare QQ accounts directly
        if platform == "qq":
            return user_id_str == qq_account

        # WebUI platform: the bot replies under its QQ account, so compare that too
        if platform == "webui":
            return user_id_str == qq_account

        # Build the per-platform account mapping
        platforms_list = getattr(global_config.bot, "platforms", []) or []
        platform_accounts = {}
        for platform_entry in platforms_list:
            if ":" in platform_entry:
                platform_name, account = platform_entry.split(":", 1)
                platform_accounts[platform_name.strip()] = account.strip()

        # Telegram platform
        if platform == "telegram":
            tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
            return user_id_str == tg_account if tg_account else False

        # Other platforms: look up the platforms config
        platform_account = platform_accounts.get(platform, "")
        if platform_account:
            return user_id_str == platform_account

        # Fallback: compare against the primary QQ account (for compatibility)
        return user_id_str == qq_account

    def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
        if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
        # Use the unified bot-identification helper (multi-platform, including WebUI)
        if self._is_bot_self(platform, user_id):
            self.is_known = True
            self.person_id = get_person_id(platform, user_id)
            self.user_id = user_id
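`_is_bot_self` above resolves per-platform bot accounts from config entries shaped like `"platform:account"`. That lookup, extracted into a minimal sketch — the config values below are hypothetical, for illustration only:

```python
def build_platform_accounts(platforms_list: list[str]) -> dict[str, str]:
    """Map platform name -> bot account from "platform:account" entries."""
    accounts = {}
    for entry in platforms_list:
        if ":" in entry:
            # split only on the first colon: account ids may themselves
            # contain colons, so split(":", 1) keeps the rest intact
            name, account = entry.split(":", 1)
            accounts[name.strip()] = account.strip()
    return accounts


# hypothetical config list; malformed entries without a colon are skipped
accounts = build_platform_accounts(["qq:123456", "tg:7890", "bad-entry"])
```

Using `split(":", 1)` rather than a bare `split(":")` is what keeps an account id like `"mat:user:1"` from being truncated.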
@@ -56,7 +56,6 @@ from .apis import (
    person_api,
    plugin_manage_api,
    send_api,
    auto_talk_api,
    register_plugin,
    get_logger,
)


@@ -19,7 +19,6 @@ from src.plugin_system.apis import (
    send_api,
    tool_api,
    frequency_api,
    auto_talk_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin

@@ -41,5 +40,4 @@ __all__ = [
    "register_plugin",
    "tool_api",
    "frequency_api",
    "auto_talk_api",
]
@@ -1,56 +0,0 @@
from src.common.logger import get_logger

logger = get_logger("auto_talk_api")


def set_question_probability_multiplier(chat_id: str, multiplier: float) -> bool:
    """
    Set the proactive-speaking probability multiplier for the given chat_id.

    Returns:
        bool: whether the update succeeded. True only when the target chat is a
        group chat (HeartFChatting) and exists.
    """
    try:
        if not isinstance(chat_id, str):
            raise TypeError("chat_id 必须是 str")
        if not isinstance(multiplier, (int, float)):
            raise TypeError("multiplier 必须是数值类型")

        # Deferred import to avoid circular dependencies
        from src.chat.heart_flow.heartflow import heartflow as _heartflow

        chat = _heartflow.heartflow_chat_list.get(chat_id)
        if chat is None:
            logger.warning(f"未找到 chat_id={chat_id} 的心流实例,无法设置乘数")
            return False

        # Only applies to group-chat heartflows exposing this attribute (duck typing avoids importing the class)
        if not hasattr(chat, "question_probability_multiplier"):
            logger.warning(f"chat_id={chat_id} 实例不支持主动发言乘数设置")
            return False

        # Constraint: negative values are not allowed
        value = float(multiplier)
        if value < 0:
            value = 0.0

        chat.question_probability_multiplier = value
        logger.info(f"[auto_talk_api] chat_id={chat_id} 主动发言乘数已设为 {value}")
        return True
    except Exception as e:
        logger.error(f"设置主动发言乘数失败: {e}")
        return False


def get_question_probability_multiplier(chat_id: str) -> float:
    """Get the proactive-speaking probability multiplier for the given chat_id; returns 0 if not found."""
    try:
        # Deferred import to avoid circular dependencies
        from src.chat.heart_flow.heartflow import heartflow as _heartflow

        chat = _heartflow.heartflow_chat_list.get(chat_id)
        if chat is None:
            return 0.0
        return float(getattr(chat, "question_probability_multiplier", 0.0))
    except Exception:
        return 0.0
@@ -9,6 +9,7 @@
"""

import traceback
import time
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger

@@ -19,6 +20,7 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
from src.chat.logger.plan_reply_logger import PlanReplyLogger

if TYPE_CHECKING:
    from src.common.data_models.info_data_model import ActionPlannerInfo

@@ -118,6 +120,10 @@ async def generate_reply(
        Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (success flag, reply set, prompt)
    """
    try:
        # If reply_time_point was not provided, default it to the current timestamp
        if reply_time_point is None:
            reply_time_point = time.time()

        # Get the replyer
        logger.debug("[GeneratorAPI] 开始生成回复")
        replyer = get_replyer(chat_stream, chat_id, request_type=request_type)

@@ -152,21 +158,42 @@
            enable_tool=enable_tool,
            reply_message=reply_message,
            reply_reason=reply_reason,
            unknown_words=unknown_words,
            unknown_words=unknown_words,
            think_level=think_level,
            from_plugin=from_plugin,
            stream_id=chat_stream.stream_id if chat_stream else chat_id,
            reply_time_point=reply_time_point,
            log_reply=False,
        )
        if not success:
            logger.warning("[GeneratorAPI] 回复生成失败")
            return False, None
        reply_set: Optional[ReplySetModel] = None
        if content := llm_response.content:
            reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
            processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
            llm_response.processed_output = processed_response
            reply_set = ReplySetModel()
            for text in processed_response:
                reply_set.add_text_content(text)
        llm_response.reply_set = reply_set
        logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")

        # Record the final reply log here, in one place (including the split processed_output)
        try:
            PlanReplyLogger.log_reply(
                chat_id=chat_stream.stream_id if chat_stream else (chat_id or ""),
                prompt=llm_response.prompt or "",
                output=llm_response.content,
                processed_output=llm_response.processed_output,
                model=llm_response.model,
                timing=llm_response.timing,
                reasoning=llm_response.reasoning,
                think_level=think_level,
                success=True,
            )
        except Exception:
            logger.exception("[GeneratorAPI] 记录reply日志失败")

        return success, llm_response

    except ValueError as ve:
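The `PlanReplyLogger.log_reply` call above is deliberately wrapped in its own try/except so that a failure while recording the reply log can never break reply delivery. That best-effort pattern, sketched generically with a stand-in logger (the names here are illustrative, not from the diff):

```python
import logging

logger = logging.getLogger("generator_api")


def log_best_effort(record_fn, **payload) -> None:
    """Invoke a logging callback; report (but never propagate) its errors."""
    try:
        record_fn(**payload)
    except Exception:
        # logger.exception records the traceback without re-raising,
        # so the caller's main flow continues untouched
        logger.exception("recording reply log failed")
```

The key design choice is a bare `except Exception` plus `logger.exception`: the traceback is preserved for debugging, while the happy path returns the reply regardless of logging health.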
@@ -12,7 +12,7 @@ import time
from typing import List, Dict, Any, Tuple, Optional
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Images
from src.config.config import global_config
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import (
    get_raw_msg_by_timestamp,
    get_raw_msg_by_timestamp_with_chat,

@@ -511,7 +511,8 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
    Returns:
        the filtered message list
    """
    return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
    # Use the unified is_bot_self helper (multi-platform, including WebUI)
    return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]


def translate_pid_to_description(pid: str) -> str:
@@ -28,7 +28,6 @@ class BaseAction(ABC):
        - keyword_case_sensitive: whether keywords are case-sensitive
        - parallel_action: whether parallel execution is allowed
        - random_activation_probability: random activation probability
        - llm_judge_prompt: prompt for LLM-based judging
    """

    def __init__(

@@ -81,8 +80,6 @@ class BaseAction(ABC):
        """Activation type"""
        self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
        """Probability used when the activation type is RANDOM"""
        self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")  # deprecated
        """Prompt that assists the LLM's judgement"""
        self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
        """Keyword list used when the activation type is KEYWORD"""
        self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)

@@ -504,7 +501,6 @@ class BaseAction(ABC):
            keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
            parallel_action=getattr(cls, "parallel_action", True),
            random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
            llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
            # use the correct field names
            action_parameters=getattr(cls, "action_parameters", {}).copy(),
            action_require=getattr(cls, "action_require", []).copy(),


@@ -33,7 +33,6 @@ class ActionActivationType(Enum):

    NEVER = "never"  # never activated (off by default)
    ALWAYS = "always"  # always offered to the planner
    LLM_JUDGE = "llm_judge"  # the LLM decides whether to offer this action to the planner
    RANDOM = "random"  # randomly offered to the planner
    KEYWORD = "keyword"  # offered to the planner on keyword triggers


@@ -128,7 +127,6 @@ class ActionInfo(ComponentInfo):
    normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS  # deprecated
    activation_type: ActionActivationType = ActionActivationType.ALWAYS
    random_activation_probability: float = 0.0
    llm_judge_prompt: str = ""
    activation_keywords: List[str] = field(default_factory=list)  # activation keyword list
    keyword_case_sensitive: bool = False
    # mode and parallelism settings


@@ -28,7 +28,7 @@
      "type": "action",
      "name": "tts_action",
      "description": "将文本转换为语音进行播放",
      "activation_modes": ["llm_judge", "keyword"],
      "activation_modes": ["keyword"],
      "keywords": ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
    }
  ],