Merge branch 'main-fix0' of https://github.com/Dax233/MaiMBot into main-fix2

pull/588/head
Bakadax 2025-03-28 09:15:57 +08:00
commit 98720b6e00
46 changed files with 2875 additions and 3079 deletions

View File

@ -130,7 +130,7 @@ MaiMBot是一个开源项目我们非常欢迎你的参与。你的贡献
### 💬交流群 ### 💬交流群
- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779开发和建议相关讨论不一定有空回复会优先写文档和代码 - [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779开发和建议相关讨论不一定有空回复会优先写文档和代码
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 【已满】(开发和建议相关讨论)不一定有空回复,会优先写文档和代码 - [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 【已满】(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722 【已满】(开发和建议相关讨论)不一定有空回复,会优先写文档和代码 - [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722开发和建议相关讨论不一定有空回复会优先写文档和代码
- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475【已满】开发和建议相关讨论不一定有空回复会优先写文档和代码 - [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475【已满】开发和建议相关讨论不一定有空回复会优先写文档和代码
- [四群](https://qm.qq.com/q/wlH5eT8OmQ) 729957033【已满】开发和建议相关讨论不一定有空回复会优先写文档和代码 - [四群](https://qm.qq.com/q/wlH5eT8OmQ) 729957033【已满】开发和建议相关讨论不一定有空回复会优先写文档和代码
@ -143,7 +143,7 @@ MaiMBot是一个开源项目我们非常欢迎你的参与。你的贡献
- 📦 **Windows 一键傻瓜式部署**:请运行项目根目录中的 `run.bat`,部署完成后请参照后续配置指南进行配置 - 📦 **Windows 一键傻瓜式部署**:请运行项目根目录中的 `run.bat`,部署完成后请参照后续配置指南进行配置
- 📦 Linux 自动部署(实验 :请下载并运行项目根目录中的`run.sh`并按照提示安装,部署完成后请参照后续配置指南进行配置 - 📦 Linux 自动部署(Arch/CentOS9/Debian12/Ubuntu24.10 :请下载并运行项目根目录中的`run.sh`并按照提示安装,部署完成后请参照后续配置指南进行配置
- [📦 Windows 手动部署指南 ](docs/manual_deploy_windows.md) - [📦 Windows 手动部署指南 ](docs/manual_deploy_windows.md)

View File

@ -1,6 +1,6 @@
# 🐳 Docker 部署指南 # 🐳 Docker 部署指南
## 部署步骤 (推荐,但不一定是最新) ## 部署步骤 (不一定是最新)
**"更新镜像与容器"部分在本文档 [Part 6](#6-更新镜像与容器)** **"更新镜像与容器"部分在本文档 [Part 6](#6-更新镜像与容器)**

View File

@ -1,9 +1,10 @@
#!/bin/bash #!/bin/bash
# 麦麦Bot一键安装脚本 by Cookie_987 # 麦麦Bot一键安装脚本 by Cookie_987
# 适用于Debian12 # 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
# 请小心使用任何一键脚本! # 请小心使用任何一键脚本!
INSTALLER_VERSION="0.0.3"
LANG=C.UTF-8 LANG=C.UTF-8
# 如无法访问GitHub请修改此处镜像地址 # 如无法访问GitHub请修改此处镜像地址
@ -15,7 +16,14 @@ RED="\e[31m"
RESET="\e[0m" RESET="\e[0m"
# 需要的基本软件包 # 需要的基本软件包
REQUIRED_PACKAGES=("git" "sudo" "python3" "python3-venv" "curl" "gnupg" "python3-pip")
declare -A REQUIRED_PACKAGES=(
["common"]="git sudo python3 curl gnupg"
["debian"]="python3-venv python3-pip"
["ubuntu"]="python3-venv python3-pip"
["centos"]="python3-pip"
["arch"]="python-virtualenv python-pip"
)
# 默认项目目录 # 默认项目目录
DEFAULT_INSTALL_DIR="/opt/maimbot" DEFAULT_INSTALL_DIR="/opt/maimbot"
@ -28,8 +36,6 @@ IS_INSTALL_MONGODB=false
IS_INSTALL_NAPCAT=false IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false IS_INSTALL_DEPENDENCIES=false
INSTALLER_VERSION="0.0.1"
# 检查是否已安装 # 检查是否已安装
check_installed() { check_installed() {
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]] [[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
@ -193,6 +199,11 @@ check_eula() {
# 首先计算当前隐私条款文件的哈希值 # 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "${INSTALL_DIR}/repo/PRIVACY.md" | awk '{print $1}') current_md5_privacy=$(md5sum "${INSTALL_DIR}/repo/PRIVACY.md" | awk '{print $1}')
# 如果当前的md5值为空则直接返回
if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
fi
# 检查eula.confirmed文件是否存在 # 检查eula.confirmed文件是否存在
if [[ -f ${INSTALL_DIR}/repo/eula.confirmed ]]; then if [[ -f ${INSTALL_DIR}/repo/eula.confirmed ]]; then
# 如果存在则检查其中包含的md5与current_md5是否一致 # 如果存在则检查其中包含的md5与current_md5是否一致
@ -213,8 +224,8 @@ check_eula() {
if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
whiptail --title "📜 使用协议更新" --yesno "检测到麦麦Bot EULA或隐私条款已更新。\nhttps://github.com/SengokuCola/MaiMBot/blob/main/EULA.md\nhttps://github.com/SengokuCola/MaiMBot/blob/main/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70 whiptail --title "📜 使用协议更新" --yesno "检测到麦麦Bot EULA或隐私条款已更新。\nhttps://github.com/SengokuCola/MaiMBot/blob/main/EULA.md\nhttps://github.com/SengokuCola/MaiMBot/blob/main/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70
if [[ $? -eq 0 ]]; then if [[ $? -eq 0 ]]; then
echo $current_md5 > ${INSTALL_DIR}/repo/eula.confirmed echo -n $current_md5 > ${INSTALL_DIR}/repo/eula.confirmed
echo $current_md5_privacy > ${INSTALL_DIR}/repo/privacy.confirmed echo -n $current_md5_privacy > ${INSTALL_DIR}/repo/privacy.confirmed
else else
exit 1 exit 1
fi fi
@ -227,7 +238,14 @@ run_installation() {
# 1/6: 检测是否安装 whiptail # 1/6: 检测是否安装 whiptail
if ! command -v whiptail &>/dev/null; then if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}" echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
# 这里的多系统适配很神人,但是能用()
apt update && apt install -y whiptail apt update && apt install -y whiptail
pacman -S --noconfirm libnewt
yum install -y newt
fi fi
# 协议确认 # 协议确认
@ -247,8 +265,18 @@ run_installation() {
if [[ -f /etc/os-release ]]; then if [[ -f /etc/os-release ]]; then
source /etc/os-release source /etc/os-release
if [[ "$ID" != "debian" || "$VERSION_ID" != "12" ]]; then if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Debian 12 (Bookworm)\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60 return
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
return
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
return
elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
whiptail --title "⚠️ 兼容性警告" --msgbox "MongoDB无可用的 Arch Linux 官方安装方法将无法自动安装MongoDB。\n\n您可尝试在AUR中搜索相关包。" 10 60
return
else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
exit 1 exit 1
fi fi
else else
@ -258,6 +286,20 @@ run_installation() {
} }
check_system check_system
# 设置包管理器
case "$ID" in
debian|ubuntu)
PKG_MANAGER="apt"
;;
centos)
PKG_MANAGER="yum"
;;
arch)
# 添加arch包管理器
PKG_MANAGER="pacman"
;;
esac
# 检查MongoDB # 检查MongoDB
check_mongodb() { check_mongodb() {
if command -v mongod &>/dev/null; then if command -v mongod &>/dev/null; then
@ -281,18 +323,27 @@ run_installation() {
# 安装必要软件包 # 安装必要软件包
install_packages() { install_packages() {
missing_packages=() missing_packages=()
for package in "${REQUIRED_PACKAGES[@]}"; do # 检查 common 及当前系统专属依赖
if ! dpkg -s "$package" &>/dev/null; then for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
missing_packages+=("$package") case "$PKG_MANAGER" in
fi apt)
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
;;
yum)
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
;;
pacman)
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
;;
esac
done done
if [[ ${#missing_packages[@]} -gt 0 ]]; then if [[ ${#missing_packages[@]} -gt 0 ]]; then
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到以下必须的依赖项目缺失:\n${missing_packages[*]}\n\n是否要自动安装" 12 60 whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装" 10 60
if [[ $? -eq 0 ]]; then if [[ $? -eq 0 ]]; then
IS_INSTALL_DEPENDENCIES=true IS_INSTALL_DEPENDENCIES=true
else else
whiptail --title "⚠️ 注意" --yesno "某些必要的依赖项未安装,可能会影响运行!\n是否继续" 10 60 || exit 1 whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续" 10 60 || exit 1
fi fi
fi fi
} }
@ -302,27 +353,24 @@ run_installation() {
install_mongodb() { install_mongodb() {
[[ $MONGO_INSTALLED == true ]] && return [[ $MONGO_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装MongoDB是否安装\n如果您想使用远程数据库请跳过此步。" 10 60 && { whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装MongoDB是否安装\n如果您想使用远程数据库请跳过此步。" 10 60 && {
echo -e "${GREEN}安装 MongoDB...${RESET}"
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
apt update
apt install -y mongodb-org
systemctl enable --now mongod
IS_INSTALL_MONGODB=true IS_INSTALL_MONGODB=true
} }
} }
install_mongodb
# 仅在非Arch系统上安装MongoDB
[[ "$ID" != "arch" ]] && install_mongodb
# 安装NapCat # 安装NapCat
install_napcat() { install_napcat() {
[[ $NAPCAT_INSTALLED == true ]] && return [[ $NAPCAT_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60 && { whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60 && {
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
IS_INSTALL_NAPCAT=true IS_INSTALL_NAPCAT=true
} }
} }
install_napcat
# 仅在非Arch系统上安装NapCat
[[ "$ID" != "arch" ]] && install_napcat
# Python版本检查 # Python版本检查
check_python() { check_python() {
@ -332,7 +380,12 @@ run_installation() {
exit 1 exit 1
fi fi
} }
check_python
# 如果没安装python则不检查python版本
if command -v python3 &>/dev/null; then
check_python
fi
# 选择分支 # 选择分支
choose_branch() { choose_branch() {
@ -358,20 +411,71 @@ run_installation() {
local confirm_msg="请确认以下信息:\n\n" local confirm_msg="请确认以下信息:\n\n"
confirm_msg+="📂 安装麦麦Bot到: $INSTALL_DIR\n" confirm_msg+="📂 安装麦麦Bot到: $INSTALL_DIR\n"
confirm_msg+="🔀 分支: $BRANCH\n" confirm_msg+="🔀 分支: $BRANCH\n"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages}\n" [[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
[[ $IS_INSTALL_MONGODB == true || $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n" [[ $IS_INSTALL_MONGODB == true || $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
[[ $IS_INSTALL_MONGODB == true ]] && confirm_msg+=" - MongoDB\n" [[ $IS_INSTALL_MONGODB == true ]] && confirm_msg+=" - MongoDB\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n" [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。" confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 16 60 || exit 1 whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
} }
confirm_install confirm_install
# 开始安装 # 开始安装
echo -e "${GREEN}安装依赖...${RESET}" echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && apt update && apt install -y "${missing_packages[@]}"
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
case "$PKG_MANAGER" in
apt)
apt update && apt install -y "${missing_packages[@]}"
;;
yum)
yum install -y "${missing_packages[@]}" --nobest
;;
pacman)
pacman -S --noconfirm "${missing_packages[@]}"
;;
esac
fi
if [[ $IS_INSTALL_MONGODB == true ]]; then
echo -e "${GREEN}安装 MongoDB...${RESET}"
case "$ID" in
debian)
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
apt update
apt install -y mongodb-org
systemctl enable --now mongod
;;
ubuntu)
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
apt update
apt install -y mongodb-org
systemctl enable --now mongod
;;
centos)
cat > /etc/yum.repos.d/mongodb-org-8.0.repo <<EOF
[mongodb-org-8.0]
name=MongoDB Repository
baseurl=https://repo.mongodb.org/yum/redhat/9/mongodb-org/8.0/x86_64/
gpgcheck=1
enabled=1
gpgkey=https://pgp.mongodb.com/server-8.0.asc
EOF
yum install -y mongodb-org
systemctl enable --now mongod
;;
esac
fi
if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
fi
echo -e "${GREEN}创建安装目录...${RESET}" echo -e "${GREEN}创建安装目录...${RESET}"
mkdir -p "$INSTALL_DIR" mkdir -p "$INSTALL_DIR"
@ -398,8 +502,8 @@ run_installation() {
# 首先计算当前隐私条款文件的哈希值 # 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "repo/PRIVACY.md" | awk '{print $1}') current_md5_privacy=$(md5sum "repo/PRIVACY.md" | awk '{print $1}')
echo $current_md5 > repo/eula.confirmed echo -n $current_md5 > repo/eula.confirmed
echo $current_md5_privacy > repo/privacy.confirmed echo -n $current_md5_privacy > repo/privacy.confirmed
echo -e "${GREEN}创建系统服务...${RESET}" echo -e "${GREEN}创建系统服务...${RESET}"
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF

View File

@ -105,6 +105,24 @@ MOOD_STYLE_CONFIG = {
}, },
} }
# relationship
RELATION_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{extra[module]: <12}</cyan> | "
"<light-magenta>关系</light-magenta> | "
"<level>{message}</level>"
),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"),
},
"simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-magenta>关系</light-magenta> | {message}"),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"),
},
}
SENDER_STYLE_CONFIG = { SENDER_STYLE_CONFIG = {
"advanced": { "advanced": {
"console_format": ( "console_format": (
@ -122,6 +140,40 @@ SENDER_STYLE_CONFIG = {
}, },
} }
HEARTFLOW_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{extra[module]: <12}</cyan> | "
"<light-yellow>麦麦大脑袋</light-yellow> | "
"<level>{message}</level>"
),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
},
"simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
},
}
SCHEDULE_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{extra[module]: <12}</cyan> | "
"<light-yellow>在干嘛</light-yellow> | "
"<level>{message}</level>"
),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"),
},
"simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <cyan>在干嘛</cyan> | <cyan>{message}</cyan>"),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"),
},
}
LLM_STYLE_CONFIG = { LLM_STYLE_CONFIG = {
"advanced": { "advanced": {
"console_format": ( "console_format": (
@ -176,6 +228,26 @@ CHAT_STYLE_CONFIG = {
}, },
} }
SUB_HEARTFLOW_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{extra[module]: <12}</cyan> | "
"<light-blue>麦麦小脑袋</light-blue> | "
"<level>{message}</level>"
),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
},
"simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <green>{message}</green>"), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
},
}
# 根据SIMPLE_OUTPUT选择配置 # 根据SIMPLE_OUTPUT选择配置
MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"] MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"]
TOPIC_STYLE_CONFIG = TOPIC_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else TOPIC_STYLE_CONFIG["advanced"] TOPIC_STYLE_CONFIG = TOPIC_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else TOPIC_STYLE_CONFIG["advanced"]
@ -183,6 +255,10 @@ SENDER_STYLE_CONFIG = SENDER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SENDER
LLM_STYLE_CONFIG = LLM_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else LLM_STYLE_CONFIG["advanced"] LLM_STYLE_CONFIG = LLM_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else LLM_STYLE_CONFIG["advanced"]
CHAT_STYLE_CONFIG = CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_STYLE_CONFIG["advanced"] CHAT_STYLE_CONFIG = CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_STYLE_CONFIG["advanced"]
MOOD_STYLE_CONFIG = MOOD_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MOOD_STYLE_CONFIG["advanced"] MOOD_STYLE_CONFIG = MOOD_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MOOD_STYLE_CONFIG["advanced"]
RELATION_STYLE_CONFIG = RELATION_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else RELATION_STYLE_CONFIG["advanced"]
SCHEDULE_STYLE_CONFIG = SCHEDULE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SCHEDULE_STYLE_CONFIG["advanced"]
HEARTFLOW_STYLE_CONFIG = HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else HEARTFLOW_STYLE_CONFIG["advanced"]
SUB_HEARTFLOW_STYLE_CONFIG = SUB_HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SUB_HEARTFLOW_STYLE_CONFIG["advanced"] # noqa: E501
def is_registered_module(record: dict) -> bool: def is_registered_module(record: dict) -> bool:
"""检查是否为已注册的模块""" """检查是否为已注册的模块"""

View File

@ -9,12 +9,13 @@ from ..moods.moods import MoodManager # 导入情绪管理器
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics from ..utils.statistic import LLMStatistics
from .bot import chat_bot from .bot import chat_bot
from .config import global_config from ..config.config import global_config
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from ..willing.willing_manager import willing_manager from ..willing.willing_manager import willing_manager
from .chat_stream import chat_manager from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus # from ..memory_system.memory import hippocampus
from src.plugins.memory_system.Hippocampus import HippocampusManager
from .message_sender import message_manager, message_sender from .message_sender import message_manager, message_sender
from .storage import MessageStorage from .storage import MessageStorage
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@ -59,6 +60,22 @@ async def start_think_flow():
logger.error(f"启动大脑和外部世界失败: {e}") logger.error(f"启动大脑和外部世界失败: {e}")
raise raise
async def start_memory():
"""启动记忆系统"""
try:
start_time = time.time()
logger.info("开始初始化记忆系统...")
# 使用HippocampusManager初始化海马体
hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config)
end_time = time.time()
logger.success(f"记忆系统初始化完成,耗时: {end_time - start_time:.2f}")
except Exception as e:
logger.error(f"记忆系统初始化失败: {e}")
raise
@driver.on_startup @driver.on_startup
async def start_background_tasks(): async def start_background_tasks():
@ -79,9 +96,19 @@ async def start_background_tasks():
# 只启动表情包管理任务 # 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check()) asyncio.create_task(emoji_manager.start_periodic_check())
await bot_schedule.initialize()
bot_schedule.print_schedule()
asyncio.create_task(start_memory())
@driver.on_startup
async def init_schedule():
"""在 NoneBot2 启动时初始化日程系统"""
bot_schedule.initialize(
name=global_config.BOT_NICKNAME,
personality=global_config.PROMPT_PERSONALITY,
behavior=global_config.PROMPT_SCHEDULE_GEN,
interval=global_config.SCHEDULE_DOING_UPDATE_INTERVAL)
asyncio.create_task(bot_schedule.mai_schedule_start())
@driver.on_startup @driver.on_startup
async def init_relationships(): async def init_relationships():
@ -131,14 +158,14 @@ async def _(bot: Bot, event: NoticeEvent, state: T_State):
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task(): async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建""" """每build_memory_interval秒执行一次记忆构建"""
await hippocampus.operation_build_memory() await HippocampusManager.get_instance().build_memory()
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") @scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task(): async def forget_memory_task():
"""每30秒执行一次记忆构建""" """每30秒执行一次记忆构建"""
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage) await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@ -157,13 +184,13 @@ async def print_mood_task():
mood_manager.print_mood_status() mood_manager.print_mood_status()
@scheduler.scheduled_job("interval", seconds=7200, id="generate_schedule") # @scheduler.scheduled_job("interval", seconds=7200, id="generate_schedule")
async def generate_schedule_task(): # async def generate_schedule_task():
"""每2小时尝试生成一次日程""" # """每2小时尝试生成一次日程"""
logger.debug("尝试生成日程") # logger.debug("尝试生成日程")
await bot_schedule.initialize() # await bot_schedule.initialize()
if not bot_schedule.enable_output: # if not bot_schedule.enable_output:
bot_schedule.print_schedule() # bot_schedule.print_schedule()
@scheduler.scheduled_job("interval", seconds=3600, id="remove_recalled_message") @scheduler.scheduled_job("interval", seconds=3600, id="remove_recalled_message")

View File

@ -12,9 +12,9 @@ from nonebot.adapters.onebot.v11 import (
FriendRecallNoticeEvent, FriendRecallNoticeEvent,
) )
from ..memory_system.memory import hippocampus from ..memory_system.Hippocampus import HippocampusManager
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config from ..config.config import global_config
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
@ -129,8 +129,10 @@ class ChatBot:
# 根据话题计算激活度 # 根据话题计算激活度
topic = "" topic = ""
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100 interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
logger.debug(f"{message.processed_plain_text}的激活度:{interested_rate}") message.processed_plain_text,fast_retrieval=True)
# interested_rate = 0.1
# logger.info(f"对{message.processed_plain_text}的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, chat, topic[0] if topic else None) await self.storage.store_message(message, chat, topic[0] if topic else None)
@ -311,7 +313,7 @@ class ChatBot:
) )
# 使用情绪管理器更新情绪 # 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None: async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None:
"""处理收到的通知""" """处理收到的通知"""

View File

@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
from nonebot import get_driver from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from ..config.config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .message_base import Seg from .message_base import Seg
from .utils_user import get_user_nickname, get_groupname from .utils_user import get_user_nickname, get_groupname

View File

@ -12,7 +12,7 @@ import io
from nonebot import get_driver from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..config.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request

View File

@ -6,7 +6,7 @@ from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from ..config.config import global_config
from .message import MessageRecv, MessageThinking, Message from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .utils import process_llm_response from .utils import process_llm_response
@ -51,13 +51,13 @@ class ResponseGenerator:
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
rand = random.random() rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY: if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = "r1" self.current_model_type = "深深地"
current_model = self.model_r1 current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY: elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = "v3" self.current_model_type = "浅浅的"
current_model = self.model_v3 current_model = self.model_v3
else: else:
self.current_model_type = "r1_distill" self.current_model_type = "又浅又浅的"
current_model = self.model_r1_distill current_model = self.model_r1_distill
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
@ -144,18 +144,25 @@ class ResponseGenerator:
try: try:
# 构建提示词,结合回复内容、被回复的内容以及立场分析 # 构建提示词,结合回复内容、被回复的内容以及立场分析
prompt = f""" prompt = f"""
请根据以下对话内容完成以下任务 请严格根据以下对话内容完成以下任务
1. 判断回复者的立场是"supportive"支持"opposed"反对还是"neutrality"中立 1. 判断回复者对被回复者观点的直接立场
2. "happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签 - "支持"明确同意或强化被回复者观点
3. 按照"立场-情绪"的格式输出结果例如"supportive-happy" - "反对"明确反驳或否定被回复者观点
- "中立"不表达明确立场或无关回应
2. "开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果例如"反对-愤怒"
被回复的内容 对话示例
{processed_plain_text} 被回复A就是笨
回复A明明很聪明 反对-愤怒
回复内容 当前对话
{content} 被回复{processed_plain_text}
回复{content}
请分析回复者的立场和情感倾向并输出结果 输出要求
- 只需输出"立场-情绪"结果不要解释
- 严格基于文字直接表达的对立关系判断
""" """
# 调用模型生成结果 # 调用模型生成结果
@ -165,18 +172,20 @@ class ResponseGenerator:
# 解析模型输出的结果 # 解析模型输出的结果
if "-" in result: if "-" in result:
stance, emotion = result.split("-", 1) stance, emotion = result.split("-", 1)
valid_stances = ["supportive", "opposed", "neutrality"] valid_stances = ["支持", "反对", "中立"]
valid_emotions = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"] valid_emotions = ["开心", "愤怒", "悲伤", "惊讶", "害羞", "平静", "恐惧", "厌恶", "困惑"]
if stance in valid_stances and emotion in valid_emotions: if stance in valid_stances and emotion in valid_emotions:
return stance, emotion # 返回有效的立场-情绪组合 return stance, emotion # 返回有效的立场-情绪组合
else: else:
return "neutrality", "neutral" # 默认返回中立-中性 logger.debug(f"无效立场-情感组合:{result}")
return "中立", "平静" # 默认返回中立-平静
else: else:
return "neutrality", "neutral" # 格式错误时返回默认值 logger.debug(f"立场-情感格式错误:{result}")
return "中立", "平静" # 格式错误时返回默认值
except Exception as e: except Exception as e:
print(f"获取情感标签时出错: {e}") logger.debug(f"获取情感标签时出错: {e}")
return "neutrality", "neutral" # 出错时返回默认值 return "中立", "平静" # 出错时返回默认值
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签""" """处理响应内容,返回处理后的内容和情感标签"""

View File

@ -9,7 +9,7 @@ from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageSet from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage from .storage import MessageStorage
from .config import global_config from ..config.config import global_config
from .utils import truncate_message, calculate_typing_time from .utils import truncate_message, calculate_typing_time
from src.common.logger import LogConfig, SENDER_STYLE_CONFIG from src.common.logger import LogConfig, SENDER_STYLE_CONFIG
@ -61,6 +61,7 @@ class Message_Sender:
if not is_recalled: if not is_recalled:
typing_time = calculate_typing_time(message.processed_plain_text) typing_time = calculate_typing_time(message.processed_plain_text)
logger.info(f"麦麦正在打字,预计需要{typing_time}")
await asyncio.sleep(typing_time) await asyncio.sleep(typing_time)
message_json = message.to_dict() message_json = message.to_dict()
@ -99,7 +100,7 @@ class MessageContainer:
self.max_size = max_size self.max_size = max_size
self.messages = [] self.messages = []
self.last_send_time = 0 self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒) self.thinking_timeout = 10 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[MessageSending]: def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序""" """获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
@ -208,7 +209,7 @@ class MessageManager:
# print(thinking_time) # print(thinking_time)
if ( if (
message_earliest.is_head message_earliest.is_head
and message_earliest.update_thinking_time() > 15 and message_earliest.update_thinking_time() > 20
and not message_earliest.is_private_message() # 避免在私聊时插入reply and not message_earliest.is_private_message() # 避免在私聊时插入reply
): ):
logger.debug(f"设置回复消息{message_earliest.processed_plain_text}") logger.debug(f"设置回复消息{message_earliest.processed_plain_text}")
@ -235,7 +236,7 @@ class MessageManager:
# print(msg.is_private_message()) # print(msg.is_private_message())
if ( if (
msg.is_head msg.is_head
and msg.update_thinking_time() > 15 and msg.update_thinking_time() > 25
and not msg.is_private_message() # 避免在私聊时插入reply and not msg.is_private_message() # 避免在私聊时插入reply
): ):
logger.debug(f"设置回复消息{msg.processed_plain_text}") logger.debug(f"设置回复消息{msg.processed_plain_text}")

View File

@ -3,10 +3,10 @@ import time
from typing import Optional from typing import Optional
from ...common.database import db from ...common.database import db
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.Hippocampus import HippocampusManager
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from .config import global_config from ..config.config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker from .utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
from .chat_stream import chat_manager from .chat_stream import chat_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
@ -57,9 +57,7 @@ class PromptBuilder:
mood_prompt = mood_manager.get_prompt() mood_prompt = mood_manager.get_prompt()
# 日程构建 # 日程构建
# current_date = time.strftime("%Y-%m-%d", time.localtime()) # schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
# current_time = time.strftime("%H:%M:%S", time.localtime())
# bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
# 获取聊天上下文 # 获取聊天上下文
chat_in_group = True chat_in_group = True
@ -81,19 +79,26 @@ class PromptBuilder:
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_topics=3, similarity_threshold=0.5, max_memory_num=4 text=message_txt,
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
) )
memory_str = ""
for _topic, memories in relevant_memories:
memory_str += f"{memories}\n"
# print(f"memory_str: {memory_str}")
if relevant_memories: if relevant_memories:
# 格式化记忆内容 # 格式化记忆内容
memory_str = "\n".join(m["content"] for m in relevant_memories)
memory_prompt = f"你回忆起:\n{memory_str}\n" memory_prompt = f"你回忆起:\n{memory_str}\n"
# 打印调试信息 # 打印调试信息
logger.debug("[记忆检索]找到以下相关记忆:") logger.debug("[记忆检索]找到以下相关记忆:")
for memory in relevant_memories: # for topic, memory_items, similarity in relevant_memories:
logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}") # logger.debug(f"- 主题「{topic}」[相似度: {similarity:.2f}]: {memory_items}")
end_time = time.time() end_time = time.time()
logger.info(f"回忆耗时: {(end_time - start_time):.3f}") logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
@ -173,8 +178,6 @@ class PromptBuilder:
prompt_check_if_response = "" prompt_check_if_response = ""
# print(prompt)
return prompt, prompt_check_if_response return prompt, prompt_check_if_response
def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1): def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
@ -196,7 +199,7 @@ class PromptBuilder:
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题 # 获取主动发言的话题
all_nodes = memory_graph.dots all_nodes = HippocampusManager.get_instance().memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes) all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5) nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select] topics = [info[0] for info in nodes_for_select]
@ -249,7 +252,7 @@ class PromptBuilder:
related_info = "" related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message, request_type="prompt_build") embedding = await get_embedding(message, request_type="prompt_build")
related_info += self.get_info_from_db(embedding, threshold=threshold) related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
return related_info return related_info

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
from typing import Optional from typing import Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger, LogConfig, RELATION_STYLE_CONFIG
from ...common.database import db from ...common.database import db
from .message_base import UserInfo from .message_base import UserInfo
@ -8,7 +8,12 @@ from .chat_stream import ChatStream
import math import math
from bson.decimal128 import Decimal128 from bson.decimal128 import Decimal128
logger = get_module_logger("rel_manager") relationship_config = LogConfig(
# 使用关系专用样式
console_format=RELATION_STYLE_CONFIG["console_format"],
file_format=RELATION_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("rel_manager", config=relationship_config)
class Impression: class Impression:
@ -270,19 +275,21 @@ class RelationshipManager:
3.人维护关系的精力往往有限所以当高关系值用户越多对于中高关系值用户增长越慢 3.人维护关系的精力往往有限所以当高关系值用户越多对于中高关系值用户增长越慢
""" """
stancedict = { stancedict = {
"supportive": 0, "支持": 0,
"neutrality": 1, "中立": 1,
"opposed": 2, "反对": 2,
} }
valuedict = { valuedict = {
"happy": 1.5, "开心": 1.5,
"angry": -3.0, "愤怒": -3.5,
"sad": -1.5, "悲伤": -1.5,
"surprised": 0.6, "惊讶": 0.6,
"disgusted": -4.5, "害羞": 2.0,
"fearful": -2.1, "平静": 0.3,
"neutral": 0.3, "恐惧": -2,
"厌恶": -2.5,
"困惑": 0.5,
} }
if self.get_relationship(chat_stream): if self.get_relationship(chat_stream):
old_value = self.get_relationship(chat_stream).relationship_value old_value = self.get_relationship(chat_stream).relationship_value
@ -301,9 +308,12 @@ class RelationshipManager:
if old_value > 500: if old_value > 500:
high_value_count = 0 high_value_count = 0
for _, relationship in self.relationships.items(): for _, relationship in self.relationships.items():
if relationship.relationship_value >= 850: if relationship.relationship_value >= 700:
high_value_count += 1 high_value_count += 1
value *= 3 / (high_value_count + 3) if old_value >= 700:
value *= 3 / (high_value_count + 2) # 排除自己
else:
value *= 3 / (high_value_count + 3)
elif valuedict[label] < 0 and stancedict[stance] != 0: elif valuedict[label] < 0 and stancedict[stance] != 0:
value = value * math.exp(old_value / 1000) value = value * math.exp(old_value / 1000)
else: else:
@ -316,27 +326,20 @@ class RelationshipManager:
else: else:
value = 0 value = 0
logger.info(f"[关系变更] 立场:{stance} 标签:{label} 关系值:{value}") level_num = self.calculate_level_num(old_value+value)
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
logger.info(
f"当前关系: {relationship_level[level_num]}, "
f"关系值: {old_value:.2f}, "
f"当前立场情感: {stance}-{label}, "
f"变更: {value:+.5f}"
)
await self.update_relationship_value(chat_stream=chat_stream, relationship_value=value) await self.update_relationship_value(chat_stream=chat_stream, relationship_value=value)
def build_relationship_info(self, person) -> str: def build_relationship_info(self, person) -> str:
relationship_value = relationship_manager.get_relationship(person).relationship_value relationship_value = relationship_manager.get_relationship(person).relationship_value
if -1000 <= relationship_value < -227: level_num = self.calculate_level_num(relationship_value)
level_num = 0
elif -227 <= relationship_value < -73:
level_num = 1
elif -73 <= relationship_value < 227:
level_num = 2
elif 227 <= relationship_value < 587:
level_num = 3
elif 587 <= relationship_value < 900:
level_num = 4
elif 900 <= relationship_value <= 1000:
level_num = 5
else:
level_num = 5 if relationship_value > 1000 else 0
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"] relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
relation_prompt2_list = [ relation_prompt2_list = [
"冷漠回应", "冷漠回应",
@ -357,5 +360,23 @@ class RelationshipManager:
f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}" f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}"
) )
def calculate_level_num(self, relationship_value) -> int:
"""关系等级计算"""
if -1000 <= relationship_value < -227:
level_num = 0
elif -227 <= relationship_value < -73:
level_num = 1
elif -73 <= relationship_value < 227:
level_num = 2
elif 227 <= relationship_value < 587:
level_num = 3
elif 587 <= relationship_value < 900:
level_num = 4
elif 900 <= relationship_value <= 1000:
level_num = 5
else:
level_num = 5 if relationship_value > 1000 else 0
return level_num
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from nonebot import get_driver from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from ..config.config import global_config
from src.common.logger import get_module_logger, LogConfig, TOPIC_STYLE_CONFIG from src.common.logger import get_module_logger, LogConfig, TOPIC_STYLE_CONFIG
# 定义日志配置 # 定义日志配置

View File

@ -1,4 +1,3 @@
import math
import random import random
import time import time
import re import re
@ -12,7 +11,7 @@ from src.common.logger import get_module_logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config from ..config.config import global_config
from .message import MessageRecv, Message from .message import MessageRecv, Message
from .message_base import UserInfo from .message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
@ -66,60 +65,6 @@ async def get_embedding(text, request_type="embedding"):
return await llm.get_embedding(text) return await llm.get_embedding(text)
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_closest_chat_from_db(length: int, timestamp: str):
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
chat_records = []
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
# print(f"最接近的记录: {closest_record}")
if closest_record:
closest_time = closest_record["time"]
chat_id = closest_record["chat_id"] # 获取chat_id
# 获取该时间戳之后的length条消息保持相同的chat_id
chat_records = list(
db.messages.find(
{
"time": {"$gt": closest_time},
"chat_id": chat_id, # 添加chat_id过滤
}
)
.sort("time", 1)
.limit(length)
)
# print(f"获取到的记录: {chat_records}")
length = len(chat_records)
# print(f"获取到的记录长度: {length}")
# 转换记录格式
formatted_records = []
for record in chat_records:
# 兼容行为,前向兼容老数据
formatted_records.append(
{
"_id": record["_id"],
"time": record["time"],
"chat_id": record["chat_id"],
"detailed_plain_text": record.get("detailed_plain_text", ""), # 添加文本内容
"memorized_times": record.get("memorized_times", 0), # 添加记忆次数
}
)
return formatted_records
return []
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list: async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录

View File

@ -9,7 +9,7 @@ import io
from nonebot import get_driver from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..config.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger from src.common.logger import get_module_logger

View File

@ -1,4 +1,4 @@
from .config import global_config from ..config.config import global_config
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager

View File

@ -42,6 +42,7 @@ class BotConfig:
# schedule # schedule
ENABLE_SCHEDULE_GEN: bool = False # 是否启用日程生成 ENABLE_SCHEDULE_GEN: bool = False # 是否启用日程生成
PROMPT_SCHEDULE_GEN = "无日程" PROMPT_SCHEDULE_GEN = "无日程"
SCHEDULE_DOING_UPDATE_INTERVAL: int = 300 # 日程表更新间隔 单位秒
# message # message
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数 MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
@ -221,6 +222,8 @@ class BotConfig:
schedule_config = parent["schedule"] schedule_config = parent["schedule"]
config.ENABLE_SCHEDULE_GEN = schedule_config.get("enable_schedule_gen", config.ENABLE_SCHEDULE_GEN) config.ENABLE_SCHEDULE_GEN = schedule_config.get("enable_schedule_gen", config.ENABLE_SCHEDULE_GEN)
config.PROMPT_SCHEDULE_GEN = schedule_config.get("prompt_schedule_gen", config.PROMPT_SCHEDULE_GEN) config.PROMPT_SCHEDULE_GEN = schedule_config.get("prompt_schedule_gen", config.PROMPT_SCHEDULE_GEN)
config.SCHEDULE_DOING_UPDATE_INTERVAL = schedule_config.get(
"schedule_doing_update_interval", config.SCHEDULE_DOING_UPDATE_INTERVAL)
logger.info( logger.info(
f"载入自定义日程prompt:{schedule_config.get('prompt_schedule_gen', config.PROMPT_SCHEDULE_GEN)}") f"载入自定义日程prompt:{schedule_config.get('prompt_schedule_gen', config.PROMPT_SCHEDULE_GEN)}")

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
from dotenv import load_dotenv
class EnvConfig:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(EnvConfig, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self.ROOT_DIR = Path(__file__).parent.parent.parent.parent
self.load_env()
def load_env(self):
env_file = self.ROOT_DIR / '.env'
if env_file.exists():
load_dotenv(env_file)
# 根据ENVIRONMENT变量加载对应的环境文件
env_type = os.getenv('ENVIRONMENT', 'prod')
if env_type == 'dev':
env_file = self.ROOT_DIR / '.env.dev'
elif env_type == 'prod':
env_file = self.ROOT_DIR / '.env.prod'
if env_file.exists():
load_dotenv(env_file, override=True)
def get(self, key, default=None):
return os.getenv(key, default)
def get_all(self):
return dict(os.environ)
def __getattr__(self, name):
return self.get(name)
# 创建全局实例
env_config = EnvConfig()
# 导出环境变量
def get_env(key, default=None):
return os.getenv(key, default)
# 导出所有环境变量
def get_all_env():
return dict(os.environ)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
import asyncio
import time
import sys
import os
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.plugins.config.config import global_config
async def test_memory_system():
"""测试记忆系统的主要功能"""
try:
# 初始化记忆系统
print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config)
print("记忆系统初始化完成")
# 测试记忆构建
# print("开始测试记忆构建...")
# await hippocampus_manager.build_memory()
# print("记忆构建完成")
# 测试记忆检索
test_text = "千石可乐在群里聊天"
test_text = '''[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
[03-24 10:39:37] 麦麦(ta的id:2814567326): [回复变量] 变量就像今天计划总变
[03-24 10:39:44] 状态异常(ta的id:535554838): 要把本地文件改成弹出来的路径吗
[03-24 10:40:35] 状态异常(ta的id:535554838): [图片这张图片显示的是Windows系统的环境变量设置界面界面左侧列出了多个环境变量的值包括Intel Dev RedistWindowsWindows PowerShellOpenSSHNVIDIA Corporation的目录等右侧有新建编辑浏览删除上移下移和编辑文本等操作按钮图片下方有一个错误提示框显示"Windows找不到文件'mongodb\\bin\\mongod.exe'。请确定文件名是否正确后,再试一次。"这意味着用户试图运行MongoDB的mongod.exe程序时系统找不到该文件这可能是因为MongoDB的安装路径未正确添加到系统环境变量中或者文件路径有误
图片的含义可能是用户正在尝试设置MongoDB的环境变量以便在命令行或其他程序中使用MongoDB如果用户正确设置了环境变量那么他们应该能够通过命令行或其他方式启动MongoDB服务]
[03-24 10:41:08] 一根猫(ta的id:108886006): [回复 麦麦 的消息: [回复某人消息] 改系统变量或者删库重配 ] [@麦麦] 我中途修改人格需要重配吗
[03-24 10:41:54] 麦麦(ta的id:2814567326): [回复[回复 麦麦 的消息: [回复某人消息] 改系统变量或者删库重配 ] [@麦麦] 我中途修改人格需要重配吗] 看情况
[03-24 10:41:54] 麦麦(ta的id:2814567326):
[03-24 10:41:54] 麦麦(ta的id:2814567326): 小改变量就行大动骨安排重配像游戏副本南度改太大会崩
[03-24 10:45:33] 霖泷(ta的id:1967075066): 话说现在思考高达一分钟
[03-24 10:45:38] 霖泷(ta的id:1967075066): 是不是哪里出问题了
[03-24 10:45:39] 艾卡(ta的id:1786525298): [表情包这张表情包展示了一个动漫角色她有着紫色的头发和大大的眼睛表情显得有些困惑或不解她的头上有一个问号进一步强调了她的疑惑整体情感表达的是困惑或不解]
[03-24 10:46:12] (ta的id:3229291803): [表情包这张表情包显示了一只手正在做"点赞"的动作通常表示赞同喜欢或支持这个表情包所表达的情感是积极的赞同的或支持的]
[03-24 10:46:37] 星野風禾(ta的id:2890165435): 还能思考高达
[03-24 10:46:39] 星野風禾(ta的id:2890165435): 什么知识库
[03-24 10:46:49] 幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复大佬们''' # noqa: E501
# test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?'''
print(f"开始测试记忆检索,测试文本: {test_text}\n")
memories = await hippocampus_manager.get_memory_from_text(
text=test_text,
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
)
await asyncio.sleep(1)
print("检索到的记忆:")
for topic, memory_items in memories:
print(f"主题: {topic}")
print(f"- {memory_items}")
# 测试记忆遗忘
# forget_start_time = time.time()
# # print("开始测试记忆遗忘...")
# await hippocampus_manager.forget_memory(percentage=0.005)
# # print("记忆遗忘完成")
# forget_end_time = time.time()
# print(f"记忆遗忘耗时: {forget_end_time - forget_start_time:.2f} 秒")
# 获取所有节点
# nodes = hippocampus_manager.get_all_node_names()
# print(f"当前记忆系统中的节点数量: {len(nodes)}")
# print("节点列表:")
# for node in nodes:
# print(f"- {node}")
except Exception as e:
print(f"测试过程中出现错误: {e}")
raise
async def main():
"""主函数"""
try:
start_time = time.time()
await test_memory_system()
end_time = time.time()
print(f"测试完成,总耗时: {end_time - start_time:.2f}")
except Exception as e:
print(f"程序执行出错: {e}")
raise
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,298 +0,0 @@
# -*- coding: utf-8 -*-
import os
import sys
import time
import jieba
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
# from src.common.logger import get_module_logger
# logger = get_module_logger("draw_memory")
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
print(root_path)
from src.common.database import db # noqa: E402
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".env.dev")
load_dotenv(env_path)
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {"concept": node}
db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ""
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)]) # 调试输出
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}"
)
if closest_record:
closest_time = closest_record["time"]
group_id = closest_record["group_id"] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
)
for record in chat_record:
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(record["time"])))
try:
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except (KeyError, TypeError):
# 处理缺少键或类型错误的情况
displayname = record.get("user_nickname", "") or "用户" + str(record.get("user_id", "未知"))
chat_text += f"[{time_str}] {displayname}: {record['processed_plain_text']}\n" # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self):
# 清空现有的图数据
db.graph_data.delete_many({})
# 保存节点
for node in self.G.nodes(data=True):
node_data = {
"concept": node[0],
"memory_items": node[1].get("memory_items", []), # 默认为空列表
}
db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
edge_data = {"source": edge[0], "target": edge[1]}
db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node["concept"], memory_items=memory_items)
# 加载边
edges = db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge["source"], edge["target"])
def main():
memory_graph = Memory_graph()
memory_graph.load_graph_from_db()
# 只显示一次优化后的图形
visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == "退出":
break
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
logger.debug("第一层记忆:")
for item in first_layer_items:
logger.debug(item)
logger.debug("第二层记忆:")
for item in second_layer_items:
logger.debug(item)
else:
logger.debug("未找到相关记忆。")
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。"
f"只需要列举{topic_num}个话题就好,不要告诉我其他内容。"
)
return prompt
def topic_what(text, topic):
prompt = (
f"这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。"
f"只输出这句话就好"
)
return prompt
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
logger.debug("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数和最大度数用于归一化
max_memories = 1
max_degree = 1
for node in nodes:
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
# 红色分量随着度数增加而增加
r = (degree / max_degree) ** 0.3
red = min(1.0, r)
# 蓝色分量随着度数减少而增加
blue = max(0.0, 1 - red)
# blue = 1
color = (red, 0.1, blue)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(
H,
pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family="SimHei",
font_weight="bold",
edge_color="gray",
width=0.5,
alpha=0.9,
)
title = "记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数"
plt.title(title, fontsize=16, fontfamily="SimHei")
plt.show()
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import List
@dataclass
class MemoryConfig:
"""记忆系统配置类"""
# 记忆构建相关配置
memory_build_distribution: List[float] # 记忆构建的时间分布参数
build_memory_sample_num: int # 每次构建记忆的样本数量
build_memory_sample_length: int # 每个样本的消息长度
memory_compress_rate: float # 记忆压缩率
# 记忆遗忘相关配置
memory_forget_time: int # 记忆遗忘时间(小时)
# 记忆过滤相关配置
memory_ban_words: List[str] # 记忆过滤词列表
llm_topic_judge: str # 话题判断模型
llm_summary_by_topic: str # 话题总结模型
@classmethod
def from_global_config(cls, global_config):
"""从全局配置创建记忆系统配置"""
return cls(
memory_build_distribution=global_config.memory_build_distribution,
build_memory_sample_num=global_config.build_memory_sample_num,
build_memory_sample_length=global_config.build_memory_sample_length,
memory_compress_rate=global_config.memory_compress_rate,
memory_forget_time=global_config.memory_forget_time,
memory_ban_words=global_config.memory_ban_words,
llm_topic_judge=global_config.llm_topic_judge,
llm_summary_by_topic=global_config.llm_summary_by_topic
)

View File

@ -1,992 +0,0 @@
# -*- coding: utf-8 -*-
import datetime
import math
import os
import random
import sys
import time
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
from src.common.logger import get_module_logger
import jieba
# from chat.config import global_config
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import db # noqa E402
from src.plugins.memory_system.offline_llm import LLMModel # noqa E402
# 获取当前文件的目录
current_dir = Path(__file__).resolve().parent
# 获取项目根目录(上三层目录)
project_root = current_dir.parent.parent.parent
# env.dev文件路径
env_path = project_root / ".env.dev"
logger = get_module_logger("mem_manual_bd")
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
list: 消息记录字典列表每个字典包含消息内容和时间信息
"""
chat_records = []
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
if closest_record and closest_record.get("memorized", 0) < 4:
closest_time = closest_record["time"]
group_id = closest_record["group_id"]
# 获取该时间戳之后的length条消息且groupid相同
records = list(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
)
# 更新每条消息的memorized属性
for record in records:
current_memorized = record.get("memorized", 0)
if current_memorized > 3:
print("消息已读取3次跳过")
return ""
# 更新memorized值
db.messages.update_one({"_id": record["_id"]}, {"$set": {"memorized": current_memorized + 1}})
# 添加到记录列表中
chat_records.append(
{"text": record["detailed_plain_text"], "time": record["time"], "group_id": record["group_id"]}
)
return chat_records
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
def get_memory_sample(self, chat_size=20, time_frequency=None):
"""获取记忆样本
Returns:
list: 消息记录列表每个元素是一个消息记录字典列表
"""
if time_frequency is None:
time_frequency = {"near": 2, "mid": 4, "far": 3}
current_timestamp = datetime.datetime.now().timestamp()
chat_samples = []
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get("near")):
random_time = current_timestamp - random.randint(1, 3600 * 4)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get("mid")):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get("far")):
random_time = current_timestamp - random.randint(3600 * 24, 3600 * 24 * 7)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
return chat_samples
def calculate_topic_num(self, text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content) / 2)
print(
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
f"topic_num: {topic_num}"
)
return topic_num
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩消息记录为记忆
Args:
messages: 消息记录字典列表每个字典包含text和time字段
compress_rate: 压缩率
Returns:
set: (话题, 记忆) 元组集合
"""
if not messages:
return set()
# 合并消息文本,同时保留时间信息
input_text = ""
time_info = ""
# 计算最早和最晚时间
earliest_time = min(msg["time"] for msg in messages)
latest_time = max(msg["time"] for msg in messages)
earliest_dt = datetime.datetime.fromtimestamp(earliest_time)
latest_dt = datetime.datetime.fromtimestamp(latest_time)
# 如果是同一年
if earliest_dt.year == latest_dt.year:
earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%m-%d %H:%M:%S")
time_info += f"是在{earliest_dt.year}年,{earliest_str}{latest_str} 的对话:\n"
else:
earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S")
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages:
input_text += f"{msg['text']}\n"
print(input_text)
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
# 过滤topics
filter_keywords = ["表情包", "图片", "回复", "聊天记录"]
topics = [
topic.strip()
for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip()
]
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
# print(f"原始话题: {topics}")
print(f"过滤后话题: {filtered_topics}")
# 创建所有话题的请求任务
tasks = []
for topic in filtered_topics:
topic_what_prompt = self.topic_what(input_text, topic, time_info)
# 创建异步任务
task = self.llm_model_small.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# 等待所有任务完成
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
async def operation_build_memory(self, chat_size=12):
# 最近消息获取频率
time_frequency = {"near": 3, "mid": 8, "far": 5}
memory_samples = self.get_memory_sample(chat_size, time_frequency)
all_topics = [] # 用于存储所有话题
for i, messages in enumerate(memory_samples, 1):
# 加载进度可视化
all_topics = []
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
bar = "" * filled_length + "-" * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 生成压缩后记忆
compress_rate = 0.1
compressed_memory = await self.memory_compress(messages, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic)
# 连接相关话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_from_db(self):
"""
从数据库同步数据到内存中的图结构
将清空当前内存中的图并从数据库重新加载所有节点和边
"""
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = db.graph_data.nodes.find()
for node in nodes:
concept = node["concept"]
memory_items = node.get("memory_items", [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = db.graph_data.edges.find()
for edge in edges:
source = edge["source"]
target = edge["target"]
strength = edge.get("strength", 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
logger.success("从数据库同步记忆图谱完成")
def calculate_node_hash(self, concept, memory_items):
"""
计算节点的特征值
"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 将记忆项排序以确保相同内容生成相同的哈希值
sorted_items = sorted(memory_items)
# 组合概念和记忆项生成特征值
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""
计算边的特征值
"""
# 对源节点和目标节点排序以确保相同的边生成相同的哈希值
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def sync_memory_to_db(self):
"""
检查并同步内存中的图结构与数据库
使用特征值(哈希值)快速判断是否需要更新
"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node["concept"]: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
# logger.info(f"添加新节点: {concept}")
node_data = {"concept": concept, "memory_items": memory_items, "hash": memory_hash}
db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get("hash", None)
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
# logger.info(f"更新节点内容: {concept}")
db.graph_data.nodes.update_one(
{"concept": concept}, {"$set": {"memory_items": memory_items, "hash": memory_hash}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node["concept"] not in memory_concepts:
# logger.info(f"删除多余节点: {db_node['concept']}")
db.graph_data.nodes.delete_one({"concept": db_node["concept"]})
# 处理边的信息
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge["source"], edge["target"])
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "num": edge.get("num", 1)}
# 检查并更新边
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
if edge_key not in db_edge_dict:
# 添加新边
logger.info(f"添加新边: {source} - {target}")
edge_data = {"source": source, "target": target, "num": 1, "hash": edge_hash}
db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
logger.info(f"更新边: {source} - {target}")
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": {"hash": edge_hash}})
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
logger.info(f"删除多余边: {source} - {target}")
db.graph_data.edges.delete_one({"source": source, "target": target})
logger.success("完成记忆图谱与数据库的差异同步")
def find_topic_llm(self, text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
)
return prompt
def topic_what(self, text, topic, time_info):
# 获取当前时间
prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
)
return prompt
def remove_node_from_db(self, topic):
"""
从数据库中删除指定节点及其相关的边
Args:
topic: 要删除的节点概念
"""
# 删除节点
db.graph_data.nodes.delete_one({"concept": topic})
# 删除所有涉及该节点的边
db.graph_data.edges.delete_many({"$or": [{"source": topic}, {"target": topic}]})
def forget_topic(self, topic):
"""
随机删除指定话题中的一条记忆如果话题没有记忆则移除该话题节点
只在内存中的图上操作不直接与数据库交互
Args:
topic: 要删除记忆的话题
Returns:
removed_item: 被删除的记忆项如果没有删除任何记忆则返回 None
"""
if topic not in self.memory_graph.G:
return None
# 获取话题节点数据
node_data = self.memory_graph.G.nodes[topic]
# 如果节点存在memory_items
if "memory_items" in node_data:
memory_items = node_data["memory_items"]
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果有记忆项可以删除
if memory_items:
# 随机选择一个记忆项删除
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 更新节点的记忆项
if memory_items:
self.memory_graph.G.nodes[topic]["memory_items"] = memory_items
else:
# 如果没有记忆项了,删除整个节点
self.memory_graph.G.remove_node(topic)
return removed_item
return None
async def operation_forget_topic(self, percentage=0.1):
"""
随机选择图中一定比例的节点进行检查根据条件决定是否遗忘
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get("strength", 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
logger.info("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(selected_memories, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]["memory_items"] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
async def _identify_topics(self, text: str) -> list:
"""从文本中识别可能的主题"""
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(text, 5))
topics = [
topic.strip()
for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip()
]
return topics
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
"""查找与给定主题相似的记忆主题"""
all_memory_topics = list(self.memory_graph.G.nodes())
all_similar_topics = []
for topic in topics:
if debug_info:
pass
topic_vector = text_to_vector(topic)
for memory_topic in all_memory_topics:
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= similarity_threshold:
all_similar_topics.append((memory_topic, similarity))
return all_similar_topics
def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list:
"""获取相似度最高的主题"""
seen_topics = set()
top_topics = []
for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True):
if topic not in seen_topics and len(top_topics) < max_topics:
seen_topics.add(topic)
top_topics.append((topic, score))
return top_topics
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度"""
logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
identified_topics = await self._identify_topics(text)
if not identified_topics:
return 0
all_similar_topics = self._find_similar_topics(
identified_topics, similarity_threshold=similarity_threshold, debug_info="记忆激活"
)
if not all_similar_topics:
return 0
top_topics = self._get_top_topics(all_similar_topics, max_topics)
if len(top_topics) == 1:
topic, score = top_topics[0]
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
activation = int(score * 50 * penalty)
print(
f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, "
f"激活值: {activation}"
)
return activation
matched_topics = set()
topic_similarities = {}
for memory_topic, _similarity in top_topics:
memory_items = self.memory_graph.G.nodes[memory_topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
for input_topic in identified_topics:
topic_vector = text_to_vector(input_topic)
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
sim = cosine_similarity(v1, v2)
if sim >= similarity_threshold:
matched_topics.add(input_topic)
adjusted_sim = sim * penalty
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
print(
f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> "
f"{memory_topic}」(内容数: {content_count}, "
f"相似度: {adjusted_sim:.3f})"
)
topic_match = len(matched_topics) / len(identified_topics)
average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0
activation = int((topic_match + average_similarities) / 2 * 100)
print(
f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, "
f"激活值: {activation}"
)
return activation
async def get_relevant_memories(
self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5
) -> list:
"""根据输入文本获取相关的记忆内容"""
identified_topics = await self._identify_topics(text)
all_similar_topics = self._find_similar_topics(
identified_topics, similarity_threshold=similarity_threshold, debug_info="记忆检索"
)
relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
relevant_memories = []
for topic, score in relevant_topics:
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer:
if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num // 2)
for memory in first_layer:
relevant_memories.append({"topic": topic, "similarity": score, "content": memory})
relevant_memories.sort(key=lambda x: x["similarity"], reverse=True)
if len(relevant_memories) > max_memory_num:
relevant_memories = random.sample(relevant_memories, max_memory_num)
return relevant_memories
def segment_text(text):
"""使用jieba进行文本分词"""
seg_text = list(jieba.cut(text))
return seg_text
def text_to_vector(text):
"""将文本转换为词频向量"""
words = segment_text(text)
vector = {}
for word in words:
vector[word] = vector.get(word, 0) + 1
return vector
def cosine_similarity(v1, v2):
"""计算两个向量的余弦相似度"""
dot_product = sum(a * b for a, b in zip(v1, v2))
norm1 = math.sqrt(sum(a * a for a in v1))
norm2 = math.sqrt(sum(b * b for b in v2))
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 过滤掉内容数量小于2的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
if memory_count < 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果没有符合条件的节点,直接返回
if len(H.nodes()) == 0:
print("没有找到内容数量大于等于2的节点")
return
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数用于归一化节点大小
max_memories = 1
for node in nodes:
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
max_memories = max(max_memories, memory_count)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 400 + 2000 * (ratio**2) # 增大节点大小
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
if degree >= 30:
node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000)
else:
# 将1-10映射到0-1的范围
color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
# 使用蓝到红的渐变
red = min(0.9, color_ratio)
blue = max(0.0, 1.0 - color_ratio)
node_colors.append((red, 0, blue))
# 绘制图形
plt.figure(figsize=(16, 12)) # 减小图形尺寸
pos = nx.spring_layout(
H,
k=1, # 调整节点间斥力
iterations=100, # 增加迭代次数
scale=1.5, # 减小布局尺寸
weight="strength",
) # 使用边的strength属性作为权重
nx.draw(
H,
pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=12, # 保持增大的字体大小
font_family="SimHei",
font_weight="bold",
edge_color="gray",
width=1.5,
) # 统一的边宽度
title = """记忆图谱可视化仅显示内容≥2的节点
节点大小表示记忆数量
节点颜色(弱连接)到红(强连接)渐变边的透明度表示连接强度
连接强度越大的节点距离越近"""
plt.title(title, fontsize=16, fontfamily="SimHei")
plt.show()
async def main():
start_time = time.time()
test_pare = {
"do_build_memory": False,
"do_forget_topic": False,
"do_visualize_graph": True,
"do_query": False,
"do_merge_memory": False,
}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare["do_build_memory"]:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(
f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m"
)
if test_pare["do_forget_topic"]:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare["do_merge_memory"]:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare["do_visualize_graph"]:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare["do_query"]:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == "退出":
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
if first_layer:
print("\n直接相关的记忆:")
for item in first_layer:
print(f"- {item}")
if second_layer:
print("\n间接相关的记忆:")
for item in second_layer:
print(f"- {item}")
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")
class LLMModel: class LLM_request_off:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name self.model_name = model_name
self.params = kwargs self.params = kwargs

View File

@ -6,15 +6,13 @@ from typing import Tuple, Union
import aiohttp import aiohttp
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from nonebot import get_driver
import base64 import base64
from PIL import Image from PIL import Image
import io import io
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..config.config import global_config
from ..config.config_env import env_config
driver = get_driver()
config = driver.config
logger = get_module_logger("model_utils") logger = get_module_logger("model_utils")
@ -34,8 +32,9 @@ class LLM_request:
def __init__(self, model, **kwargs): def __init__(self, model, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值 # 将大写的配置键转换为小写并从config中获取实际值
try: try:
self.api_key = getattr(config, model["key"]) self.api_key = getattr(env_config, model["key"])
self.base_url = getattr(config, model["base_url"]) self.base_url = getattr(env_config, model["base_url"])
# print(self.api_key, self.base_url)
except AttributeError as e: except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")

View File

@ -3,7 +3,7 @@ import threading
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from ..chat.config import global_config from ..config.config import global_config
from src.common.logger import get_module_logger, LogConfig, MOOD_STYLE_CONFIG from src.common.logger import get_module_logger, LogConfig, MOOD_STYLE_CONFIG
mood_config = LogConfig( mood_config = LogConfig(
@ -55,13 +55,15 @@ class MoodManager:
# 情绪词映射表 (valence, arousal) # 情绪词映射表 (valence, arousal)
self.emotion_map = { self.emotion_map = {
"happy": (0.8, 0.6), # 高愉悦度,中等唤醒度 "开心": (0.8, 0.6), # 高愉悦度,中等唤醒度
"angry": (-0.7, 0.7), # 负愉悦度,高唤醒度 "愤怒": (-0.7, 0.7), # 负愉悦度,高唤醒度
"sad": (-0.6, 0.3), # 负愉悦度,低唤醒度 "悲伤": (-0.6, 0.3), # 负愉悦度,低唤醒度
"surprised": (0.4, 0.8), # 中等愉悦度,高唤醒度 "惊讶": (0.2, 0.8), # 中等愉悦度,高唤醒度
"disgusted": (-0.8, 0.5), # 高负愉悦度,中等唤醒度 "害羞": (0.5, 0.2), # 中等愉悦度,低唤醒度
"fearful": (-0.7, 0.6), # 负愉悦度,高唤醒度 "平静": (0.0, 0.5), # 中性愉悦度,中等唤醒度
"neutral": (0.0, 0.5), # 中性愉悦度,中等唤醒度 "恐惧": (-0.7, 0.6), # 负愉悦度,高唤醒度
"厌恶": (-0.4, 0.4), # 负愉悦度,低唤醒度
"困惑": (0.0, 0.6), # 中性愉悦度,高唤醒度
} }
# 情绪文本映射表 # 情绪文本映射表

View File

@ -6,7 +6,7 @@ import os
import json import json
import threading import threading
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.plugins.chat.config import global_config from src.plugins.config.config import global_config
logger = get_module_logger("remote") logger = get_module_logger("remote")

View File

@ -1,10 +1,7 @@
import asyncio import asyncio
import os import os
import time
from typing import Tuple, Union
import aiohttp import aiohttp
import requests
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")
@ -22,57 +19,7 @@ class LLMModel:
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]: async def generate_response_async(self, prompt: str) -> str:
"""根据输入的提示生成模型的响应"""
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params,
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15 # 基础等待时间(秒)
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
wait_time = base_wait_time * (2**retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2**retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应""" """异步方式根据输入的提示生成模型的响应"""
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
@ -80,7 +27,7 @@ class LLMModel:
data = { data = {
"model": self.model_name, "model": self.model_name,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.5, "temperature": 0.7,
**self.params, **self.params,
} }

View File

@ -1,159 +1,154 @@
import datetime import datetime
import json import os
import re import sys
from typing import Dict, Union from typing import Dict
import asyncio
from nonebot import get_driver
# 添加项目根目录到 Python 路径 # 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.plugins.chat.config import global_config from src.common.database import db # noqa: E402
from ...common.database import db # 使用正确的导入语法 from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfig # noqa: E402
from ..models.utils_model import LLM_request from src.plugins.models.utils_model import LLM_request # noqa: E402
from src.common.logger import get_module_logger from src.plugins.config.config import global_config # noqa: E402
logger = get_module_logger("scheduler")
driver = get_driver() schedule_config = LogConfig(
config = driver.config # 使用海马体专用样式
console_format=SCHEDULE_STYLE_CONFIG["console_format"],
file_format=SCHEDULE_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("scheduler", config=schedule_config)
class ScheduleGenerator: class ScheduleGenerator:
enable_output: bool = True # enable_output: bool = True
def __init__(self): def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型 # 使用离线LLM模型
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9) self.llm_scheduler_all = LLM_request(
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9, request_type="scheduler") model= global_config.llm_reasoning, temperature=0.9, max_tokens=7000,request_type="schedule")
self.llm_scheduler_doing = LLM_request(
model= global_config.llm_normal, temperature=0.9, max_tokens=2048,request_type="schedule")
self.today_schedule_text = "" self.today_schedule_text = ""
self.today_schedule = {} self.today_done_list = []
self.tomorrow_schedule_text = ""
self.tomorrow_schedule = {}
self.yesterday_schedule_text = "" self.yesterday_schedule_text = ""
self.yesterday_schedule = {} self.yesterday_done_list = []
async def initialize(self): self.name = ""
self.personality = ""
self.behavior = ""
self.start_time = datetime.datetime.now()
self.schedule_doing_update_interval = 300 #最好大于60
def initialize(
self,name: str = "bot_name",
personality: str = "你是一个爱国爱党的新时代青年",
behavior: str = "你非常外向,喜欢尝试新事物和人交流",
interval: int = 60):
"""初始化日程系统"""
self.name = name
self.behavior = behavior
self.schedule_doing_update_interval = interval
for pers in personality:
self.personality += pers + "\n"
async def mai_schedule_start(self):
"""启动日程系统每5分钟执行一次move_doing并在日期变化时重新检查日程"""
try:
logger.info(f"日程系统启动/刷新时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
# 初始化日程
await self.check_and_create_today_schedule()
self.print_schedule()
while True:
print(self.get_current_num_task(1, True))
current_time = datetime.datetime.now()
# 检查是否需要重新生成日程(日期变化)
if current_time.date() != self.start_time.date():
logger.info("检测到日期变化,重新生成日程")
self.start_time = current_time
await self.check_and_create_today_schedule()
self.print_schedule()
# 执行当前活动
# mind_thinking = subheartflow_manager.current_state.current_mind
await self.move_doing()
await asyncio.sleep(self.schedule_doing_update_interval)
except Exception as e:
logger.error(f"日程系统运行时出错: {str(e)}")
logger.exception("详细错误信息:")
async def check_and_create_today_schedule(self):
"""检查昨天的日程,并确保今天有日程安排
Returns:
tuple: (today_schedule_text, today_schedule) 今天的日程文本和解析后的日程字典
"""
today = datetime.datetime.now() today = datetime.datetime.now()
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) yesterday = today - datetime.timedelta(days=1)
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today) # 先检查昨天的日程
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule( self.yesterday_schedule_text, self.yesterday_done_list = self.load_schedule_from_db(yesterday)
target_date=tomorrow, read_only=True if self.yesterday_schedule_text:
) logger.debug(f"已加载{yesterday.strftime('%Y-%m-%d')}的日程")
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(
target_date=yesterday, read_only=True
)
async def generate_daily_schedule( # 检查今天的日程
self, target_date: datetime.datetime = None, read_only: bool = False self.today_schedule_text, self.today_done_list = self.load_schedule_from_db(today)
) -> Dict[str, str]: if not self.today_done_list:
self.today_done_list = []
if not self.today_schedule_text:
logger.info(f"{today.strftime('%Y-%m-%d')}的日程不存在,准备生成新的日程")
self.today_schedule_text = await self.generate_daily_schedule(target_date=today)
self.save_today_schedule_to_db()
def construct_daytime_prompt(self, target_date: datetime.datetime):
date_str = target_date.strftime("%Y-%m-%d") date_str = target_date.strftime("%Y-%m-%d")
weekday = target_date.strftime("%A") weekday = target_date.strftime("%A")
schedule_text = str prompt = f"你是{self.name}{self.personality}{self.behavior}"
prompt += f"你昨天的日程是:{self.yesterday_schedule_text}\n"
prompt += f"请为你生成{date_str}{weekday})的日程安排,结合你的个人特点和行为习惯\n"
prompt += "推测你的日程安排包括你一天都在做什么从起床到睡眠有什么发现和思考具体一些详细一些需要1500字以上精确到每半个小时记得写明时间\n" #noqa: E501
prompt += "直接返回你的日程,从起床到睡觉,不要输出其他内容:"
return prompt
existing_schedule = db.schedule.find_one({"date": date_str}) def construct_doing_prompt(self,time: datetime.datetime,mind_thinking: str = ""):
if existing_schedule: now_time = time.strftime("%H:%M")
if self.enable_output: if self.today_done_list:
logger.debug(f"{date_str}的日程已存在:") previous_doings = self.get_current_num_task(5, True)
schedule_text = existing_schedule["schedule"] # print(previous_doings)
# print(self.schedule_text)
elif not read_only:
logger.debug(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = (
f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:"""
+ """
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表
仅返回内容不要返回注释不要添加任何markdown或代码块样式时间采用24小时制
格式为{"时间": "活动","时间": "活动",...}"""
)
try:
schedule_text, _, _ = await self.llm_scheduler.generate_response(prompt)
db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
self.enable_output = True
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
# print(self.schedule_text)
else: else:
if self.enable_output: previous_doings = "你没做什么事情"
logger.debug(f"{date_str}的日程不存在。")
schedule_text = "忘了"
return schedule_text, None
schedule_form = self._parse_schedule(schedule_text) prompt = f"你是{self.name}{self.personality}{self.behavior}"
return schedule_text, schedule_form prompt += f"你今天的日程是:{self.today_schedule_text}\n"
prompt += f"你之前做了的事情是:{previous_doings},从之前到现在已经过去了{self.schedule_doing_update_interval/60}分钟了\n" #noqa: E501
if mind_thinking:
prompt += f"你脑子里在想:{mind_thinking}\n"
prompt += f"现在是{now_time},结合你的个人特点和行为习惯,注意关注你今天的日程安排和想法,这很重要,"
prompt += "推测你现在在做什么,具体一些,详细一些\n"
prompt += "直接返回你在做的事情,注意是当前时间,不要输出其他内容:"
return prompt
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]: async def generate_daily_schedule(
"""解析日程文本,转换为时间和活动的字典""" self, target_date: datetime.datetime = None,) -> Dict[str, str]:
try: daytime_prompt = self.construct_daytime_prompt(target_date)
reg = r"\{(.|\r|\n)+\}" daytime_response,_ = await self.llm_scheduler_all.generate_response_async(daytime_prompt)
matched = re.search(reg, schedule_text)[0] return daytime_response
schedule_dict = json.loads(matched)
self._check_schedule_validity(schedule_dict)
return schedule_dict
except json.JSONDecodeError:
logger.exception("解析日程失败: {}".format(schedule_text))
return False
except ValueError as e:
logger.exception(f"解析日程失败: {str(e)}")
return False
except Exception as e:
logger.exception(f"解析日程发生错误:{str(e)}")
return False
def _check_schedule_validity(self, schedule_dict: Dict[str, str]):
"""检查日程是否合法"""
if not schedule_dict:
return
for time_str in schedule_dict.keys():
try:
self._parse_time(time_str)
except ValueError:
raise ValueError("日程时间格式不正确") from None
def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间"""
return datetime.datetime.strptime(time_str, "%H:%M")
def get_current_task(self) -> str:
"""获取当前时间应该进行的任务"""
current_time = datetime.datetime.now().strftime("%H:%M")
# 找到最接近当前时间的任务
closest_time = None
min_diff = float("inf")
# 检查今天的日程
if not self.today_schedule:
return "摸鱼"
for time_str in self.today_schedule.keys():
diff = abs(self._time_diff(current_time, time_str))
if closest_time is None or diff < min_diff:
closest_time = time_str
min_diff = diff
# 检查昨天的日程中的晚间任务
if self.yesterday_schedule:
for time_str in self.yesterday_schedule.keys():
if time_str >= "20:00": # 只考虑晚上8点之后的任务
# 计算与昨天这个时间点的差异需要加24小时
diff = abs(self._time_diff(current_time, time_str))
if diff < min_diff:
closest_time = time_str
min_diff = diff
return closest_time, self.yesterday_schedule[closest_time]
if closest_time:
return closest_time, self.today_schedule[closest_time]
return "摸鱼"
def _time_diff(self, time1: str, time2: str) -> int: def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差""" """计算两个时间字符串之间的分钟差"""
@ -174,14 +169,138 @@ class ScheduleGenerator:
def print_schedule(self): def print_schedule(self):
"""打印完整的日程安排""" """打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text): if not self.today_schedule_text:
logger.warning("今日日程有误,将在两小时后重新生成") logger.warning("今日日程有误,将在下次运行时重新生成")
db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
logger.info("=== 今日日程安排 ===") logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items(): logger.info(self.today_schedule_text)
logger.info(f"时间[{time_str}]: 活动[{activity}]")
logger.info("==================") logger.info("==================")
self.enable_output = False self.enable_output = False
async def update_today_done_list(self):
# 更新数据库中的 today_done_list
today_str = datetime.datetime.now().strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": today_str})
if existing_schedule:
# 更新数据库中的 today_done_list
db.schedule.update_one(
{"date": today_str},
{"$set": {"today_done_list": self.today_done_list}}
)
logger.debug(f"已更新{today_str}的已完成活动列表")
else:
logger.warning(f"未找到{today_str}的日程记录")
async def move_doing(self,mind_thinking: str = ""):
current_time = datetime.datetime.now()
if mind_thinking:
doing_prompt = self.construct_doing_prompt(current_time,mind_thinking)
else:
doing_prompt = self.construct_doing_prompt(current_time)
# print(doing_prompt)
doing_response,_ = await self.llm_scheduler_doing.generate_response_async(doing_prompt)
self.today_done_list.append((current_time,doing_response))
await self.update_today_done_list()
logger.info(f"当前活动: {doing_response}")
return doing_response
async def get_task_from_time_to_time(self, start_time: str, end_time: str):
"""获取指定时间范围内的任务列表
Args:
start_time (str): 开始时间格式为"HH:MM"
end_time (str): 结束时间格式为"HH:MM"
Returns:
list: 时间范围内的任务列表
"""
result = []
for task in self.today_done_list:
task_time = task[0] # 获取任务的时间戳
task_time_str = task_time.strftime("%H:%M")
# 检查任务时间是否在指定范围内
if self._time_diff(start_time, task_time_str) >= 0 and self._time_diff(task_time_str, end_time) >= 0:
result.append(task)
return result
def get_current_num_task(self, num=1, time_info = False):
"""获取最新加入的指定数量的日程
Args:
num (int): 需要获取的日程数量默认为1
Returns:
list: 最新加入的日程列表
"""
if not self.today_done_list:
return []
# 确保num不超过列表长度
num = min(num, len(self.today_done_list))
pre_doings = ""
for doing in self.today_done_list[-num:]:
if time_info:
time_str = doing[0].strftime("%H:%M")
pre_doings += time_str + "时," + doing[1] + "\n"
else:
pre_doings += doing[1] + "\n"
# 返回最新的num条日程
return pre_doings
def save_today_schedule_to_db(self):
"""保存日程到数据库,同时初始化 today_done_list"""
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
schedule_data = {
"date": date_str,
"schedule": self.today_schedule_text,
"today_done_list": self.today_done_list if hasattr(self, 'today_done_list') else []
}
# 使用 upsert 操作,如果存在则更新,不存在则插入
db.schedule.update_one(
{"date": date_str},
{"$set": schedule_data},
upsert=True
)
logger.debug(f"已保存{date_str}的日程到数据库")
def load_schedule_from_db(self, date: datetime.datetime):
"""从数据库加载日程,同时加载 today_done_list"""
date_str = date.strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": date_str})
if existing_schedule:
schedule_text = existing_schedule["schedule"]
return schedule_text, existing_schedule.get("today_done_list", [])
else:
logger.debug(f"{date_str}的日程不存在")
return None, None
async def main():
# 使用示例
scheduler = ScheduleGenerator()
scheduler.initialize(
name="麦麦",
personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学",
behavior="你比较内向一般熬夜比较晚然后第二天早上10点起床吃早午饭",
interval=60)
await scheduler.mai_schedule_start()
# 当作为组件导入时使用的实例 # 当作为组件导入时使用的实例
bot_schedule = ScheduleGenerator() bot_schedule = ScheduleGenerator()
if __name__ == "__main__":
import asyncio
# 当直接运行此文件时执行
asyncio.run(main())

View File

@ -1,222 +0,0 @@
import datetime
import json
import re
import os
import sys
from typing import Dict, Union
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import db # noqa: E402
from src.common.logger import get_module_logger # noqa: E402
from src.plugins.schedule.offline_llm import LLMModel # noqa: E402
logger = get_module_logger("scheduler")
class ScheduleGenerator:
enable_output: bool = True
def __init__(self, name: str = "bot_name", personality: str = "你是一个爱国爱党的新时代青年", behavior: str = "你非常外向,喜欢尝试新事物和人交流"):
# 使用离线LLM模型
self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3", temperature=0.9)
self.today_schedule_text = ""
self.today_done_list = []
self.yesterday_schedule_text = ""
self.yesterday_done_list = []
self.name = name
self.personality = personality
self.behavior = behavior
self.start_time = datetime.datetime.now()
async def mai_schedule_start(self):
"""启动日程系统每5分钟执行一次move_doing并在日期变化时重新检查日程"""
try:
logger.info(f"日程系统启动/刷新时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
# 初始化日程
await self.check_and_create_today_schedule()
self.print_schedule()
while True:
current_time = datetime.datetime.now()
# 检查是否需要重新生成日程(日期变化)
if current_time.date() != self.start_time.date():
logger.info("检测到日期变化,重新生成日程")
self.start_time = current_time
await self.check_and_create_today_schedule()
self.print_schedule()
# 执行当前活动
current_activity = await self.move_doing()
logger.info(f"当前活动: {current_activity}")
# 等待5分钟
await asyncio.sleep(300) # 300秒 = 5分钟
except Exception as e:
logger.error(f"日程系统运行时出错: {str(e)}")
logger.exception("详细错误信息:")
async def check_and_create_today_schedule(self):
"""检查昨天的日程,并确保今天有日程安排
Returns:
tuple: (today_schedule_text, today_schedule) 今天的日程文本和解析后的日程字典
"""
today = datetime.datetime.now()
yesterday = today - datetime.timedelta(days=1)
# 先检查昨天的日程
self.yesterday_schedule_text, self.yesterday_done_list = self.load_schedule_from_db(yesterday)
if self.yesterday_schedule_text:
logger.debug(f"已加载{yesterday.strftime('%Y-%m-%d')}的日程")
# 检查今天的日程
self.today_schedule_text, self.today_done_list = self.load_schedule_from_db(today)
if not self.today_schedule_text:
logger.info(f"{today.strftime('%Y-%m-%d')}的日程不存在,准备生成新的日程")
self.today_schedule_text = await self.generate_daily_schedule(target_date=today)
self.save_today_schedule_to_db()
def construct_daytime_prompt(self, target_date: datetime.datetime):
date_str = target_date.strftime("%Y-%m-%d")
weekday = target_date.strftime("%A")
prompt = f"我是{self.name}{self.personality}{self.behavior}"
prompt += f"我昨天的日程是:{self.yesterday_schedule_text}\n"
prompt += f"请为我生成{date_str}{weekday})的日程安排,结合我的个人特点和行为习惯\n"
prompt += "推测我的日程安排,包括我一天都在做什么,有什么发现和思考,具体一些,详细一些,记得写明时间\n"
prompt += "直接返回我的日程,不要输出其他内容:"
return prompt
def construct_doing_prompt(self,time: datetime.datetime):
now_time = time.strftime("%H:%M")
previous_doing = self.today_done_list[-20:] if len(self.today_done_list) > 20 else self.today_done_list
prompt = f"我是{self.name}{self.personality}{self.behavior}"
prompt += f"我今天的日程是:{self.today_schedule_text}\n"
prompt += f"我之前做了的事情是:{previous_doing}\n"
prompt += f"现在是{now_time},结合我的个人特点和行为习惯,"
prompt += "推测我现在做什么,具体一些,详细一些\n"
prompt += "直接返回我在做的事情,不要输出其他内容:"
return prompt
async def generate_daily_schedule(
self, target_date: datetime.datetime = None,) -> Dict[str, str]:
daytime_prompt = self.construct_daytime_prompt(target_date)
daytime_response, _ = await self.llm_scheduler.generate_response(daytime_prompt)
return daytime_response
def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差"""
if time1 == "24:00":
time1 = "23:59"
if time2 == "24:00":
time2 = "23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
# 考虑时间的循环性
if diff < -720:
diff += 1440 # 加一天的分钟
elif diff > 720:
diff -= 1440 # 减一天的分钟
# print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟")
return diff
def print_schedule(self):
"""打印完整的日程安排"""
if not self.today_schedule_text:
logger.warning("今日日程有误,将在下次运行时重新生成")
db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
logger.info("=== 今日日程安排 ===")
logger.info(self.today_schedule_text)
logger.info("==================")
self.enable_output = False
async def update_today_done_list(self):
# 更新数据库中的 today_done_list
today_str = datetime.datetime.now().strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": today_str})
if existing_schedule:
# 更新数据库中的 today_done_list
db.schedule.update_one(
{"date": today_str},
{"$set": {"today_done_list": self.today_done_list}}
)
logger.debug(f"已更新{today_str}的已完成活动列表")
else:
logger.warning(f"未找到{today_str}的日程记录")
async def move_doing(self):
current_time = datetime.datetime.now()
time_str = current_time.strftime("%H:%M")
doing_prompt = self.construct_doing_prompt(current_time)
doing_response, _ = await self.llm_scheduler.generate_response(doing_prompt)
self.today_done_list.append(current_time,time_str + "" + doing_response)
await self.update_today_done_list()
return doing_response
def save_today_schedule_to_db(self):
"""保存日程到数据库,同时初始化 today_done_list"""
date_str = datetime.datetime.now().strftime("%Y-%m-%d")
schedule_data = {
"date": date_str,
"schedule": self.today_schedule_text,
"today_done_list": self.today_done_list if hasattr(self, 'today_done_list') else []
}
# 使用 upsert 操作,如果存在则更新,不存在则插入
db.schedule.update_one(
{"date": date_str},
{"$set": schedule_data},
upsert=True
)
logger.debug(f"已保存{date_str}的日程到数据库")
def load_schedule_from_db(self, date: datetime.datetime):
"""从数据库加载日程,同时加载 today_done_list"""
date_str = date.strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": date_str})
if existing_schedule:
schedule_text = existing_schedule["schedule"]
return schedule_text, existing_schedule.get("today_done_list", [])
else:
logger.debug(f"{date_str}的日程不存在")
return None, None
async def main():
# 使用示例
scheduler = ScheduleGenerator(name="麦麦", personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学", behavior="你比较内向")
await scheduler.check_and_create_today_schedule()
scheduler.print_schedule()
print("\n当前任务:")
print(await scheduler.get_current_task())
print("昨天日程:")
print(scheduler.yesterday_schedule)
print("今天日程:")
print(scheduler.today_schedule)
print("明天日程:")
print(scheduler.tomorrow_schedule)
# 当作为组件导入时使用的实例
bot_schedule = ScheduleGenerator()
if __name__ == "__main__":
import asyncio
# 当直接运行此文件时执行
asyncio.run(main())

View File

@ -20,6 +20,13 @@ class LLMStatistics:
self.output_file = output_file self.output_file = output_file
self.running = False self.running = False
self.stats_thread = None self.stats_thread = None
self._init_database()
def _init_database(self):
"""初始化数据库集合"""
if "online_time" not in db.list_collection_names():
db.create_collection("online_time")
db.online_time.create_index([("timestamp", 1)])
def start(self): def start(self):
"""启动统计线程""" """启动统计线程"""
@ -35,6 +42,16 @@ class LLMStatistics:
if self.stats_thread: if self.stats_thread:
self.stats_thread.join() self.stats_thread.join()
def _record_online_time(self):
"""记录在线时间"""
try:
db.online_time.insert_one({
"timestamp": datetime.now(),
"duration": 5 # 5分钟
})
except Exception:
logger.exception("记录在线时间失败")
def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]: def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
"""收集指定时间段的LLM请求统计数据 """收集指定时间段的LLM请求统计数据
@ -56,10 +73,11 @@ class LLMStatistics:
"tokens_by_type": defaultdict(int), "tokens_by_type": defaultdict(int),
"tokens_by_user": defaultdict(int), "tokens_by_user": defaultdict(int),
"tokens_by_model": defaultdict(int), "tokens_by_model": defaultdict(int),
# 新增在线时间统计
"online_time_minutes": 0,
} }
cursor = db.llm_usage.find({"timestamp": {"$gte": start_time}}) cursor = db.llm_usage.find({"timestamp": {"$gte": start_time}})
total_requests = 0 total_requests = 0
for doc in cursor: for doc in cursor:
@ -74,7 +92,7 @@ class LLMStatistics:
prompt_tokens = doc.get("prompt_tokens", 0) prompt_tokens = doc.get("prompt_tokens", 0)
completion_tokens = doc.get("completion_tokens", 0) completion_tokens = doc.get("completion_tokens", 0)
total_tokens = prompt_tokens + completion_tokens # 根据数据库字段调整 total_tokens = prompt_tokens + completion_tokens
stats["tokens_by_type"][request_type] += total_tokens stats["tokens_by_type"][request_type] += total_tokens
stats["tokens_by_user"][user_id] += total_tokens stats["tokens_by_user"][user_id] += total_tokens
stats["tokens_by_model"][model_name] += total_tokens stats["tokens_by_model"][model_name] += total_tokens
@ -91,6 +109,11 @@ class LLMStatistics:
if total_requests > 0: if total_requests > 0:
stats["average_tokens"] = stats["total_tokens"] / total_requests stats["average_tokens"] = stats["total_tokens"] / total_requests
# 统计在线时间
online_time_cursor = db.online_time.find({"timestamp": {"$gte": start_time}})
for doc in online_time_cursor:
stats["online_time_minutes"] += doc.get("duration", 0)
return stats return stats
def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
@ -115,7 +138,8 @@ class LLMStatistics:
output.append(f"总请求数: {stats['total_requests']}") output.append(f"总请求数: {stats['total_requests']}")
if stats["total_requests"] > 0: if stats["total_requests"] > 0:
output.append(f"总Token数: {stats['total_tokens']}") output.append(f"总Token数: {stats['total_tokens']}")
output.append(f"总花费: {stats['total_cost']:.4f}¥\n") output.append(f"总花费: {stats['total_cost']:.4f}¥")
output.append(f"在线时间: {stats['online_time_minutes']}分钟\n")
data_fmt = "{:<32} {:>10} {:>14} {:>13.4f} ¥" data_fmt = "{:<32} {:>10} {:>14} {:>13.4f} ¥"
@ -184,13 +208,16 @@ class LLMStatistics:
"""统计循环每1分钟运行一次""" """统计循环每1分钟运行一次"""
while self.running: while self.running:
try: try:
# 记录在线时间
self._record_online_time()
# 收集并保存统计数据
all_stats = self._collect_all_statistics() all_stats = self._collect_all_statistics()
self._save_statistics(all_stats) self._save_statistics(all_stats)
except Exception: except Exception:
logger.exception("统计数据处理失败") logger.exception("统计数据处理失败")
# 等待1分钟 # 等待5分钟
for _ in range(60): for _ in range(300): # 5分钟 = 300秒
if not self.running: if not self.running:
break break
time.sleep(1) time.sleep(1)

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
from typing import Dict from typing import Dict
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
from ..chat.config import global_config from ..config.config import global_config
class WillingManager: class WillingManager:

View File

@ -3,7 +3,7 @@ import random
import time import time
from typing import Dict from typing import Dict
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..chat.config import global_config from ..config.config import global_config
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
logger = get_module_logger("mode_dynamic") logger = get_module_logger("mode_dynamic")

View File

@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..chat.config import global_config from ..config.config import global_config
from .mode_classical import WillingManager as ClassicalWillingManager from .mode_classical import WillingManager as ClassicalWillingManager
from .mode_dynamic import WillingManager as DynamicWillingManager from .mode_dynamic import WillingManager as DynamicWillingManager
from .mode_custom import WillingManager as CustomWillingManager from .mode_custom import WillingManager as CustomWillingManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

View File

@ -1,8 +1,17 @@
from .current_mind import SubHeartflow from .sub_heartflow import SubHeartflow
from src.plugins.moods.moods import MoodManager from src.plugins.moods.moods import MoodManager
from src.plugins.models.utils_model import LLM_request from src.plugins.models.utils_model import LLM_request
from src.plugins.chat.config import global_config from src.plugins.config.config import global_config, BotConfig
from src.plugins.schedule.schedule_generator import bot_schedule
import asyncio import asyncio
from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONFIG # noqa: E402
heartflow_config = LogConfig(
# 使用海马体专用样式
console_format=HEARTFLOW_STYLE_CONFIG["console_format"],
file_format=HEARTFLOW_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("heartflow", config=heartflow_config)
class CuttentState: class CuttentState:
def __init__(self): def __init__(self):
@ -30,22 +39,24 @@ class Heartflow:
async def heartflow_start_working(self): async def heartflow_start_working(self):
while True: while True:
# await self.do_a_thinking() await self.do_a_thinking()
await asyncio.sleep(60) await asyncio.sleep(600)
async def do_a_thinking(self): async def do_a_thinking(self):
print("麦麦大脑袋转起来了") logger.info("麦麦大脑袋转起来了")
self.current_state.update_current_state_info() self.current_state.update_current_state_info()
personality_info = open("src/think_flow_demo/personality_info.txt", "r", encoding="utf-8").read() personality_info = " ".join(global_config.PROMPT_PERSONALITY)
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
related_memory_info = 'memory' related_memory_info = 'memory'
sub_flows_info = await self.get_all_subheartflows_minds() sub_flows_info = await self.get_all_subheartflows_minds()
schedule_info = bot_schedule.get_current_num_task(num = 5,time_info = True)
prompt = "" prompt = ""
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
prompt += f"{personality_info}\n" prompt += f"{personality_info}\n"
# prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{message_stream_info}\n"
prompt += f"你想起来{related_memory_info}" prompt += f"你想起来{related_memory_info}"
prompt += f"刚刚你的主要想法是{current_thinking_info}" prompt += f"刚刚你的主要想法是{current_thinking_info}"
prompt += f"你还有一些小想法,因为你在参加不同的群聊天,是你正在做的事情:{sub_flows_info}\n" prompt += f"你还有一些小想法,因为你在参加不同的群聊天,是你正在做的事情:{sub_flows_info}\n"
@ -58,7 +69,10 @@ class Heartflow:
self.update_current_mind(reponse) self.update_current_mind(reponse)
self.current_mind = reponse self.current_mind = reponse
print(f"麦麦的总体脑内状态:{self.current_mind}") logger.info(f"麦麦的总体脑内状态:{self.current_mind}")
logger.info("麦麦想了想,当前活动:")
await bot_schedule.move_doing(self.current_mind)
for _, subheartflow in self._subheartflows.items(): for _, subheartflow in self._subheartflows.items():
subheartflow.main_heartflow_info = reponse subheartflow.main_heartflow_info = reponse
@ -77,13 +91,13 @@ class Heartflow:
return await self.minds_summary(sub_minds) return await self.minds_summary(sub_minds)
async def minds_summary(self,minds_str): async def minds_summary(self,minds_str):
personality_info = open("src/think_flow_demo/personality_info.txt", "r", encoding="utf-8").read() personality_info = " ".join(BotConfig.PROMPT_PERSONALITY)
mood_info = self.current_state.mood mood_info = self.current_state.mood
prompt = "" prompt = ""
prompt += f"{personality_info}\n" prompt += f"{personality_info}\n"
prompt += f"现在{global_config.BOT_NICKNAME}的想法是:{self.current_mind}\n" prompt += f"现在{global_config.BOT_NICKNAME}的想法是:{self.current_mind}\n"
prompt += f"现在麦麦在qq群里进行聊天聊天的话题如下{minds_str}\n" prompt += f"现在{global_config.BOT_NICKNAME}在qq群里进行聊天聊天的话题如下{minds_str}\n"
prompt += f"你现在{mood_info}\n" prompt += f"你现在{mood_info}\n"
prompt += '''现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白 prompt += '''现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
不要太长但是记得结合上述的消息要记得你的人设关注新内容:''' 不要太长但是记得结合上述的消息要记得你的人设关注新内容:'''

View File

@ -2,7 +2,7 @@
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from src.plugins.models.utils_model import LLM_request from src.plugins.models.utils_model import LLM_request
from src.plugins.chat.config import global_config from src.plugins.config.config import global_config
from src.common.database import db from src.common.database import db
#存储一段聊天的大致内容 #存储一段聊天的大致内容
@ -16,6 +16,10 @@ class Talking_info:
self.observe_times = 0 self.observe_times = 0
self.activate = 360 self.activate = 360
self.last_summary_time = int(datetime.now().timestamp()) # 上次更新summary的时间
self.summary_count = 0 # 30秒内的更新次数
self.max_update_in_30s = 2
self.oberve_interval = 3 self.oberve_interval = 3
self.llm_summary = LLM_request( self.llm_summary = LLM_request(
@ -60,16 +64,22 @@ class Talking_info:
if len(self.talking_message) > 20: if len(self.talking_message) > 20:
self.talking_message = self.talking_message[-20:] # 只保留最新的20条 self.talking_message = self.talking_message[-20:] # 只保留最新的20条
self.translate_message_list_to_str() self.translate_message_list_to_str()
# print(self.talking_message_str)
self.observe_times += 1 self.observe_times += 1
self.last_observe_time = new_messages[-1]["time"] self.last_observe_time = new_messages[-1]["time"]
if self.observe_times > 3: # 检查是否需要更新summary
current_time = int(datetime.now().timestamp())
if current_time - self.last_summary_time >= 30: # 如果超过30秒重置计数
self.summary_count = 0
self.last_summary_time = current_time
if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次
await self.update_talking_summary() await self.update_talking_summary()
# print(f"更新了聊天总结:{self.talking_summary}") self.summary_count += 1
async def update_talking_summary(self): async def update_talking_summary(self):
#基于已经有的talking_summary和新的talking_message生成一个summary #基于已经有的talking_summary和新的talking_message生成一个summary
# print(f"更新聊天总结:{self.talking_summary}")
prompt = "" prompt = ""
prompt = f"你正在参与一个qq群聊的讨论这个群之前在聊的内容是{self.talking_summary}\n" prompt = f"你正在参与一个qq群聊的讨论这个群之前在聊的内容是{self.talking_summary}\n"
prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{self.talking_message_str}\n" prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{self.talking_message_str}\n"

View File

@ -1 +0,0 @@
你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧

View File

@ -2,9 +2,21 @@ from .outer_world import outer_world
import asyncio import asyncio
from src.plugins.moods.moods import MoodManager from src.plugins.moods.moods import MoodManager
from src.plugins.models.utils_model import LLM_request from src.plugins.models.utils_model import LLM_request
from src.plugins.chat.config import global_config from src.plugins.config.config import global_config, BotConfig
import re import re
import time import time
from src.plugins.schedule.schedule_generator import bot_schedule
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
subheartflow_config = LogConfig(
# 使用海马体专用样式
console_format=SUB_HEARTFLOW_STYLE_CONFIG["console_format"],
file_format=SUB_HEARTFLOW_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("subheartflow", config=subheartflow_config)
class CuttentState: class CuttentState:
def __init__(self): def __init__(self):
self.willing = 0 self.willing = 0
@ -35,6 +47,8 @@ class SubHeartflow:
if not self.current_mind: if not self.current_mind:
self.current_mind = "你什么也没想" self.current_mind = "你什么也没想"
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
def assign_observe(self,stream_id): def assign_observe(self,stream_id):
self.outer_world = outer_world.get_world_by_stream_id(stream_id) self.outer_world = outer_world.get_world_by_stream_id(stream_id)
self.observe_chat_id = stream_id self.observe_chat_id = stream_id
@ -44,29 +58,49 @@ class SubHeartflow:
current_time = time.time() current_time = time.time()
if current_time - self.last_reply_time > 180: # 3分钟 = 180秒 if current_time - self.last_reply_time > 180: # 3分钟 = 180秒
# print(f"{self.observe_chat_id}麦麦已经3分钟没有回复了暂时停止思考") # print(f"{self.observe_chat_id}麦麦已经3分钟没有回复了暂时停止思考")
await asyncio.sleep(25) # 每30秒检查一次 await asyncio.sleep(60) # 每30秒检查一次
else: else:
await self.do_a_thinking() await self.do_a_thinking()
await self.judge_willing() await self.judge_willing()
await asyncio.sleep(25) await asyncio.sleep(60)
async def do_a_thinking(self): async def do_a_thinking(self):
print("麦麦小脑袋转起来了")
self.current_state.update_current_state_info() self.current_state.update_current_state_info()
personality_info = open("src/think_flow_demo/personality_info.txt", "r", encoding="utf-8").read()
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
related_memory_info = 'memory'
message_stream_info = self.outer_world.talking_summary message_stream_info = self.outer_world.talking_summary
print(f"message_stream_info{message_stream_info}")
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_stream_info,
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
)
# print(f"相关记忆:{related_memory}")
if related_memory:
related_memory_info = ""
for memory in related_memory:
related_memory_info += memory[1]
else:
related_memory_info = ''
print(f"相关记忆:{related_memory_info}")
schedule_info = bot_schedule.get_current_num_task(num = 1,time_info = False)
prompt = "" prompt = ""
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
prompt += f"{personality_info}\n" prompt += f"{self.personality_info}\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{message_stream_info}\n" prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{message_stream_info}\n"
prompt += f"你想起来{related_memory_info}" if related_memory_info:
prompt += f"你想起来{related_memory_info}"
prompt += f"刚刚你的想法是{current_thinking_info}" prompt += f"刚刚你的想法是{current_thinking_info}"
prompt += f"你现在{mood_info}" prompt += f"你现在{mood_info}\n"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:" prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:"
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
@ -74,48 +108,52 @@ class SubHeartflow:
self.update_current_mind(reponse) self.update_current_mind(reponse)
self.current_mind = reponse self.current_mind = reponse
print(f"麦麦的脑内状态:{self.current_mind}") print(prompt)
logger.info(f"麦麦的脑内状态:{self.current_mind}")
async def do_after_reply(self,reply_content,chat_talking_prompt): async def do_after_reply(self,reply_content,chat_talking_prompt):
# print("麦麦脑袋转起来了") # print("麦麦脑袋转起来了")
self.current_state.update_current_state_info() self.current_state.update_current_state_info()
personality_info = open("src/think_flow_demo/personality_info.txt", "r", encoding="utf-8").read()
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
related_memory_info = 'memory' related_memory_info = 'memory'
message_stream_info = self.outer_world.talking_summary message_stream_info = self.outer_world.talking_summary
message_new_info = chat_talking_prompt message_new_info = chat_talking_prompt
reply_info = reply_content reply_info = reply_content
schedule_info = bot_schedule.get_current_num_task(num = 1,time_info = False)
prompt = "" prompt = ""
prompt += f"{personality_info}\n" prompt += f"你刚刚在做的事情是:{schedule_info}\n"
prompt += f"{self.personality_info}\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{message_stream_info}\n" prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{message_stream_info}\n"
prompt += f"你想起来{related_memory_info}" if related_memory_info:
prompt += f"你想起来{related_memory_info}"
prompt += f"刚刚你的想法是{current_thinking_info}" prompt += f"刚刚你的想法是{current_thinking_info}"
prompt += f"你现在看到了网友们发的新消息:{message_new_info}\n" prompt += f"你现在看到了网友们发的新消息:{message_new_info}\n"
prompt += f"你刚刚回复了群友们:{reply_info}" prompt += f"你刚刚回复了群友们:{reply_info}"
prompt += f"你现在{mood_info}" prompt += f"你现在{mood_info}"
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,以及你回复的内容,不要思考太多:" prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
self.update_current_mind(reponse) self.update_current_mind(reponse)
self.current_mind = reponse self.current_mind = reponse
print(f"{self.observe_chat_id}麦麦的脑内状态:{self.current_mind}") logger.info(f"麦麦回复后的脑内状态:{self.current_mind}")
self.last_reply_time = time.time() self.last_reply_time = time.time()
async def judge_willing(self): async def judge_willing(self):
# print("麦麦闹情绪了1") # print("麦麦闹情绪了1")
personality_info = open("src/think_flow_demo/personality_info.txt", "r", encoding="utf-8").read()
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
# print("麦麦闹情绪了2") # print("麦麦闹情绪了2")
prompt = "" prompt = ""
prompt += f"{personality_info}\n" prompt += f"{self.personality_info}\n"
prompt += "现在你正在上网和qq群里的网友们聊天" prompt += "现在你正在上网和qq群里的网友们聊天"
prompt += f"你现在的想法是{current_thinking_info}" prompt += f"你现在的想法是{current_thinking_info}"
prompt += f"你现在{mood_info}" prompt += f"你现在{mood_info}"
@ -130,7 +168,7 @@ class SubHeartflow:
else: else:
self.current_state.willing = 0 self.current_state.willing = 0
print(f"{self.observe_chat_id}麦麦的回复意愿:{self.current_state.willing}") logger.info(f"{self.observe_chat_id}麦麦的回复意愿:{self.current_state.willing}")
return self.current_state.willing return self.current_state.willing

View File

@ -3,7 +3,7 @@ version = "0.0.11"
[mai_version] [mai_version]
version = "0.6.0" version = "0.6.0"
version-fix = "snapshot-1" version-fix = "snapshot-2"
#以下是给开发人员阅读的,一般用户不需要阅读 #以下是给开发人员阅读的,一般用户不需要阅读
#如果你想要修改配置文件请在修改后将version的值进行变更 #如果你想要修改配置文件请在修改后将version的值进行变更
@ -43,6 +43,7 @@ personality_3_probability = 0.1 # 第三种人格出现概率,请确保三个
[schedule] [schedule]
enable_schedule_gen = true # 是否启用日程表(尚未完成) enable_schedule_gen = true # 是否启用日程表(尚未完成)
prompt_schedule_gen = "用几句话描述描述性格特点或行动规律,这个特征会用来生成日程表" prompt_schedule_gen = "用几句话描述描述性格特点或行动规律,这个特征会用来生成日程表"
schedule_doing_update_interval = 900 # 日程表更新间隔 单位秒
[message] [message]
max_context_size = 15 # 麦麦获得的上文数量建议15太短太长都会导致脑袋尖尖 max_context_size = 15 # 麦麦获得的上文数量建议15太短太长都会导致脑袋尖尖
@ -85,7 +86,7 @@ check_prompt = "符合公序良俗" # 表情包过滤要求
[memory] [memory]
build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
build_memory_distribution = [4,2,0.6,24,8,0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重 build_memory_distribution = [4.0,2.0,0.6,24.0,8.0,0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num = 10 # 采样数量,数值越高记忆采样次数越多 build_memory_sample_num = 10 # 采样数量,数值越高记忆采样次数越多
build_memory_sample_length = 20 # 采样长度,数值越高一段记忆内容越丰富 build_memory_sample_length = 20 # 采样长度,数值越高一段记忆内容越丰富
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
@ -135,6 +136,7 @@ enable = true
[experimental] [experimental]
enable_friend_chat = false # 是否启用好友聊天 enable_friend_chat = false # 是否启用好友聊天
enable_think_flow = false # 是否启用思维流 注意可能会消耗大量token请谨慎开启 enable_think_flow = false # 是否启用思维流 注意可能会消耗大量token请谨慎开启
#思维流适合搭配低能耗普通模型使用例如qwen2.5 32b
#下面的模型若使用硅基流动则不需要更改使用ds官方则改成.env.prod自定义的宏使用自定义模型则选择定位相似的模型自己填写 #下面的模型若使用硅基流动则不需要更改使用ds官方则改成.env.prod自定义的宏使用自定义模型则选择定位相似的模型自己填写
#推理模型 #推理模型

655
webui.py
View File

@ -5,6 +5,7 @@ import toml
import signal import signal
import sys import sys
import requests import requests
import socket
try: try:
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@ -39,50 +40,35 @@ def signal_handler(signum, frame):
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
is_share = False is_share = False
debug = True debug = False
# 检查配置文件是否存在
if not os.path.exists("config/bot_config.toml"):
logger.error("配置文件 bot_config.toml 不存在,请检查配置文件路径")
raise FileNotFoundError("配置文件 bot_config.toml 不存在,请检查配置文件路径")
if not os.path.exists(".env.prod"): def init_model_pricing():
logger.error("环境配置文件 .env.prod 不存在,请检查配置文件路径") """初始化模型价格配置"""
raise FileNotFoundError("环境配置文件 .env.prod 不存在,请检查配置文件路径") model_list = [
"llm_reasoning",
"llm_reasoning_minor",
"llm_normal",
"llm_topic_judge",
"llm_summary_by_topic",
"llm_emotion_judge",
"vlm",
"embedding",
"moderation"
]
config_data = toml.load("config/bot_config.toml") for model in model_list:
# 增加对老版本配置文件支持 if model in config_data["model"]:
LEGACY_CONFIG_VERSION = version.parse("0.0.1") # 检查是否已有pri_in和pri_out配置
has_pri_in = "pri_in" in config_data["model"][model]
# 增加最低支持版本 has_pri_out = "pri_out" in config_data["model"][model]
MIN_SUPPORT_VERSION = version.parse("0.0.8")
MIN_SUPPORT_MAIMAI_VERSION = version.parse("0.5.13")
if "inner" in config_data:
CONFIG_VERSION = config_data["inner"]["version"]
PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION)
if PARSED_CONFIG_VERSION < MIN_SUPPORT_VERSION:
logger.error("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
logger.error("最低支持的麦麦版本:" + str(MIN_SUPPORT_MAIMAI_VERSION))
raise Exception("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
else:
logger.error("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
logger.error("最低支持的麦麦版本:" + str(MIN_SUPPORT_MAIMAI_VERSION))
raise Exception("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9")
# 定义意愿模式可选项
WILLING_MODE_CHOICES = [
"classical",
"dynamic",
"custom",
]
# 添加WebUI配置文件版本
WEBUI_VERSION = version.parse("0.0.10")
# 只在缺少配置时添加默认值
if not has_pri_in:
config_data["model"][model]["pri_in"] = 0
logger.info(f"为模型 {model} 添加默认输入价格配置")
if not has_pri_out:
config_data["model"][model]["pri_out"] = 0
logger.info(f"为模型 {model} 添加默认输出价格配置")
# ============================================== # ==============================================
# env环境配置文件读取部分 # env环境配置文件读取部分
@ -124,6 +110,68 @@ def parse_env_config(config_file):
return env_variables return env_variables
# 检查配置文件是否存在
if not os.path.exists("config/bot_config.toml"):
logger.error("配置文件 bot_config.toml 不存在,请检查配置文件路径")
raise FileNotFoundError("配置文件 bot_config.toml 不存在,请检查配置文件路径")
else:
config_data = toml.load("config/bot_config.toml")
init_model_pricing()
if not os.path.exists(".env.prod"):
logger.error("环境配置文件 .env.prod 不存在,请检查配置文件路径")
raise FileNotFoundError("环境配置文件 .env.prod 不存在,请检查配置文件路径")
else:
# 载入env文件并解析
env_config_file = ".env.prod" # 配置文件路径
env_config_data = parse_env_config(env_config_file)
# 增加最低支持版本
MIN_SUPPORT_VERSION = version.parse("0.0.8")
MIN_SUPPORT_MAIMAI_VERSION = version.parse("0.5.13")
if "inner" in config_data:
CONFIG_VERSION = config_data["inner"]["version"]
PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION)
if PARSED_CONFIG_VERSION < MIN_SUPPORT_VERSION:
logger.error("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
logger.error("最低支持的麦麦版本:" + str(MIN_SUPPORT_MAIMAI_VERSION))
raise Exception("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
else:
logger.error("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
logger.error("最低支持的麦麦版本:" + str(MIN_SUPPORT_MAIMAI_VERSION))
raise Exception("您的麦麦版本过低!!已经不再支持,请更新到最新版本!!")
# 添加麦麦版本
if "mai_version" in config_data:
MAI_VERSION = version.parse(str(config_data["mai_version"]["version"]))
logger.info("您的麦麦版本为:" + str(MAI_VERSION))
else:
logger.info("检测到配置文件中并没有定义麦麦版本,将使用默认版本")
MAI_VERSION = version.parse("0.5.15")
logger.info("您的麦麦版本为:" + str(MAI_VERSION))
# 增加在线状态更新版本
HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9")
# 增加日程设置重构版本
SCHEDULE_CHANGED_VERSION = version.parse("0.0.11")
# 定义意愿模式可选项
WILLING_MODE_CHOICES = [
"classical",
"dynamic",
"custom",
]
# 添加WebUI配置文件版本
WEBUI_VERSION = version.parse("0.0.11")
# env环境配置文件保存函数 # env环境配置文件保存函数
def save_to_env_file(env_variables, filename=".env.prod"): def save_to_env_file(env_variables, filename=".env.prod"):
""" """
@ -482,7 +530,9 @@ def save_personality_config(
t_prompt_personality_1, t_prompt_personality_1,
t_prompt_personality_2, t_prompt_personality_2,
t_prompt_personality_3, t_prompt_personality_3,
t_prompt_schedule, t_enable_schedule_gen,
t_prompt_schedule_gen,
t_schedule_doing_update_interval,
t_personality_1_probability, t_personality_1_probability,
t_personality_2_probability, t_personality_2_probability,
t_personality_3_probability, t_personality_3_probability,
@ -492,8 +542,13 @@ def save_personality_config(
config_data["personality"]["prompt_personality"][1] = t_prompt_personality_2 config_data["personality"]["prompt_personality"][1] = t_prompt_personality_2
config_data["personality"]["prompt_personality"][2] = t_prompt_personality_3 config_data["personality"]["prompt_personality"][2] = t_prompt_personality_3
# 保存日程生成提示词 # 保存日程生成部分
config_data["personality"]["prompt_schedule"] = t_prompt_schedule if PARSED_CONFIG_VERSION >= SCHEDULE_CHANGED_VERSION:
config_data["schedule"]["enable_schedule_gen"] = t_enable_schedule_gen
config_data["schedule"]["prompt_schedule_gen"] = t_prompt_schedule_gen
config_data["schedule"]["schedule_doing_update_interval"] = t_schedule_doing_update_interval
else:
config_data["personality"]["prompt_schedule"] = t_prompt_schedule_gen
# 保存三个人格的概率 # 保存三个人格的概率
config_data["personality"]["personality_1_probability"] = t_personality_1_probability config_data["personality"]["personality_1_probability"] = t_personality_1_probability
@ -521,13 +576,15 @@ def save_message_and_emoji_config(
t_enable_check, t_enable_check,
t_check_prompt, t_check_prompt,
): ):
config_data["message"]["min_text_length"] = t_min_text_length if PARSED_CONFIG_VERSION < version.parse("0.0.11"):
config_data["message"]["min_text_length"] = t_min_text_length
config_data["message"]["max_context_size"] = t_max_context_size config_data["message"]["max_context_size"] = t_max_context_size
config_data["message"]["emoji_chance"] = t_emoji_chance config_data["message"]["emoji_chance"] = t_emoji_chance
config_data["message"]["thinking_timeout"] = t_thinking_timeout config_data["message"]["thinking_timeout"] = t_thinking_timeout
config_data["message"]["response_willing_amplifier"] = t_response_willing_amplifier if PARSED_CONFIG_VERSION < version.parse("0.0.11"):
config_data["message"]["response_interested_rate_amplifier"] = t_response_interested_rate_amplifier config_data["message"]["response_willing_amplifier"] = t_response_willing_amplifier
config_data["message"]["down_frequency_rate"] = t_down_frequency_rate config_data["message"]["response_interested_rate_amplifier"] = t_response_interested_rate_amplifier
config_data["message"]["down_frequency_rate"] = t_down_frequency_rate
config_data["message"]["ban_words"] = t_ban_words_final_result config_data["message"]["ban_words"] = t_ban_words_final_result
config_data["message"]["ban_msgs_regex"] = t_ban_msgs_regex_final_result config_data["message"]["ban_msgs_regex"] = t_ban_msgs_regex_final_result
config_data["emoji"]["check_interval"] = t_check_interval config_data["emoji"]["check_interval"] = t_check_interval
@ -539,6 +596,21 @@ def save_message_and_emoji_config(
logger.info("消息和表情配置已保存到 bot_config.toml 文件中") logger.info("消息和表情配置已保存到 bot_config.toml 文件中")
return "消息和表情配置已保存" return "消息和表情配置已保存"
def save_willing_config(
t_willing_mode,
t_response_willing_amplifier,
t_response_interested_rate_amplifier,
t_down_frequency_rate,
t_emoji_response_penalty,
):
config_data["willing"]["willing_mode"] = t_willing_mode
config_data["willing"]["response_willing_amplifier"] = t_response_willing_amplifier
config_data["willing"]["response_interested_rate_amplifier"] = t_response_interested_rate_amplifier
config_data["willing"]["down_frequency_rate"] = t_down_frequency_rate
config_data["willing"]["emoji_response_penalty"] = t_emoji_response_penalty
save_config_to_file(config_data)
logger.info("willinng配置已保存到 bot_config.toml 文件中")
return "willinng配置已保存"
def save_response_model_config( def save_response_model_config(
t_willing_mode, t_willing_mode,
@ -552,39 +624,79 @@ def save_response_model_config(
t_model1_pri_out, t_model1_pri_out,
t_model2_name, t_model2_name,
t_model2_provider, t_model2_provider,
t_model2_pri_in,
t_model2_pri_out,
t_model3_name, t_model3_name,
t_model3_provider, t_model3_provider,
t_model3_pri_in,
t_model3_pri_out,
t_emotion_model_name, t_emotion_model_name,
t_emotion_model_provider, t_emotion_model_provider,
t_emotion_model_pri_in,
t_emotion_model_pri_out,
t_topic_judge_model_name, t_topic_judge_model_name,
t_topic_judge_model_provider, t_topic_judge_model_provider,
t_topic_judge_model_pri_in,
t_topic_judge_model_pri_out,
t_summary_by_topic_model_name, t_summary_by_topic_model_name,
t_summary_by_topic_model_provider, t_summary_by_topic_model_provider,
t_summary_by_topic_model_pri_in,
t_summary_by_topic_model_pri_out,
t_vlm_model_name, t_vlm_model_name,
t_vlm_model_provider, t_vlm_model_provider,
t_vlm_model_pri_in,
t_vlm_model_pri_out,
): ):
if PARSED_CONFIG_VERSION >= version.parse("0.0.10"): if PARSED_CONFIG_VERSION >= version.parse("0.0.10"):
config_data["willing"]["willing_mode"] = t_willing_mode config_data["willing"]["willing_mode"] = t_willing_mode
config_data["response"]["model_r1_probability"] = t_model_r1_probability config_data["response"]["model_r1_probability"] = t_model_r1_probability
config_data["response"]["model_v3_probability"] = t_model_r2_probability config_data["response"]["model_v3_probability"] = t_model_r2_probability
config_data["response"]["model_r1_distill_probability"] = t_model_r3_probability config_data["response"]["model_r1_distill_probability"] = t_model_r3_probability
config_data["response"]["max_response_length"] = t_max_response_length if PARSED_CONFIG_VERSION <= version.parse("0.0.10"):
config_data["response"]["max_response_length"] = t_max_response_length
# 保存模型1配置
config_data["model"]["llm_reasoning"]["name"] = t_model1_name config_data["model"]["llm_reasoning"]["name"] = t_model1_name
config_data["model"]["llm_reasoning"]["provider"] = t_model1_provider config_data["model"]["llm_reasoning"]["provider"] = t_model1_provider
config_data["model"]["llm_reasoning"]["pri_in"] = t_model1_pri_in config_data["model"]["llm_reasoning"]["pri_in"] = t_model1_pri_in
config_data["model"]["llm_reasoning"]["pri_out"] = t_model1_pri_out config_data["model"]["llm_reasoning"]["pri_out"] = t_model1_pri_out
# 保存模型2配置
config_data["model"]["llm_normal"]["name"] = t_model2_name config_data["model"]["llm_normal"]["name"] = t_model2_name
config_data["model"]["llm_normal"]["provider"] = t_model2_provider config_data["model"]["llm_normal"]["provider"] = t_model2_provider
config_data["model"]["llm_normal"]["pri_in"] = t_model2_pri_in
config_data["model"]["llm_normal"]["pri_out"] = t_model2_pri_out
# 保存模型3配置
config_data["model"]["llm_reasoning_minor"]["name"] = t_model3_name config_data["model"]["llm_reasoning_minor"]["name"] = t_model3_name
config_data["model"]["llm_normal"]["provider"] = t_model3_provider config_data["model"]["llm_reasoning_minor"]["provider"] = t_model3_provider
config_data["model"]["llm_reasoning_minor"]["pri_in"] = t_model3_pri_in
config_data["model"]["llm_reasoning_minor"]["pri_out"] = t_model3_pri_out
# 保存情感模型配置
config_data["model"]["llm_emotion_judge"]["name"] = t_emotion_model_name config_data["model"]["llm_emotion_judge"]["name"] = t_emotion_model_name
config_data["model"]["llm_emotion_judge"]["provider"] = t_emotion_model_provider config_data["model"]["llm_emotion_judge"]["provider"] = t_emotion_model_provider
config_data["model"]["llm_emotion_judge"]["pri_in"] = t_emotion_model_pri_in
config_data["model"]["llm_emotion_judge"]["pri_out"] = t_emotion_model_pri_out
# 保存主题判断模型配置
config_data["model"]["llm_topic_judge"]["name"] = t_topic_judge_model_name config_data["model"]["llm_topic_judge"]["name"] = t_topic_judge_model_name
config_data["model"]["llm_topic_judge"]["provider"] = t_topic_judge_model_provider config_data["model"]["llm_topic_judge"]["provider"] = t_topic_judge_model_provider
config_data["model"]["llm_topic_judge"]["pri_in"] = t_topic_judge_model_pri_in
config_data["model"]["llm_topic_judge"]["pri_out"] = t_topic_judge_model_pri_out
# 保存主题总结模型配置
config_data["model"]["llm_summary_by_topic"]["name"] = t_summary_by_topic_model_name config_data["model"]["llm_summary_by_topic"]["name"] = t_summary_by_topic_model_name
config_data["model"]["llm_summary_by_topic"]["provider"] = t_summary_by_topic_model_provider config_data["model"]["llm_summary_by_topic"]["provider"] = t_summary_by_topic_model_provider
config_data["model"]["llm_summary_by_topic"]["pri_in"] = t_summary_by_topic_model_pri_in
config_data["model"]["llm_summary_by_topic"]["pri_out"] = t_summary_by_topic_model_pri_out
# 保存识图模型配置
config_data["model"]["vlm"]["name"] = t_vlm_model_name config_data["model"]["vlm"]["name"] = t_vlm_model_name
config_data["model"]["vlm"]["provider"] = t_vlm_model_provider config_data["model"]["vlm"]["provider"] = t_vlm_model_provider
config_data["model"]["vlm"]["pri_in"] = t_vlm_model_pri_in
config_data["model"]["vlm"]["pri_out"] = t_vlm_model_pri_out
save_config_to_file(config_data) save_config_to_file(config_data)
logger.info("回复&模型设置已保存到 bot_config.toml 文件中") logger.info("回复&模型设置已保存到 bot_config.toml 文件中")
return "回复&模型设置已保存" return "回复&模型设置已保存"
@ -600,6 +712,12 @@ def save_memory_mood_config(
t_mood_update_interval, t_mood_update_interval,
t_mood_decay_rate, t_mood_decay_rate,
t_mood_intensity_factor, t_mood_intensity_factor,
t_build_memory_dist1_mean,
t_build_memory_dist1_std,
t_build_memory_dist1_weight,
t_build_memory_dist2_mean,
t_build_memory_dist2_std,
t_build_memory_dist2_weight,
): ):
config_data["memory"]["build_memory_interval"] = t_build_memory_interval config_data["memory"]["build_memory_interval"] = t_build_memory_interval
config_data["memory"]["memory_compress_rate"] = t_memory_compress_rate config_data["memory"]["memory_compress_rate"] = t_memory_compress_rate
@ -607,6 +725,15 @@ def save_memory_mood_config(
config_data["memory"]["memory_forget_time"] = t_memory_forget_time config_data["memory"]["memory_forget_time"] = t_memory_forget_time
config_data["memory"]["memory_forget_percentage"] = t_memory_forget_percentage config_data["memory"]["memory_forget_percentage"] = t_memory_forget_percentage
config_data["memory"]["memory_ban_words"] = t_memory_ban_words_final_result config_data["memory"]["memory_ban_words"] = t_memory_ban_words_final_result
if PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
config_data["memory"]["build_memory_distribution"] = [
t_build_memory_dist1_mean,
t_build_memory_dist1_std,
t_build_memory_dist1_weight,
t_build_memory_dist2_mean,
t_build_memory_dist2_std,
t_build_memory_dist2_weight,
]
config_data["mood"]["update_interval"] = t_mood_update_interval config_data["mood"]["update_interval"] = t_mood_update_interval
config_data["mood"]["decay_rate"] = t_mood_decay_rate config_data["mood"]["decay_rate"] = t_mood_decay_rate
config_data["mood"]["intensity_factor"] = t_mood_intensity_factor config_data["mood"]["intensity_factor"] = t_mood_intensity_factor
@ -627,6 +754,9 @@ def save_other_config(
t_tone_error_rate, t_tone_error_rate,
t_word_replace_rate, t_word_replace_rate,
t_remote_status, t_remote_status,
t_enable_response_spliter,
t_max_response_length,
t_max_sentence_num,
): ):
config_data["keywords_reaction"]["enable"] = t_keywords_reaction_enabled config_data["keywords_reaction"]["enable"] = t_keywords_reaction_enabled
config_data["others"]["enable_advance_output"] = t_enable_advance_output config_data["others"]["enable_advance_output"] = t_enable_advance_output
@ -640,6 +770,10 @@ def save_other_config(
config_data["chinese_typo"]["word_replace_rate"] = t_word_replace_rate config_data["chinese_typo"]["word_replace_rate"] = t_word_replace_rate
if PARSED_CONFIG_VERSION > HAVE_ONLINE_STATUS_VERSION: if PARSED_CONFIG_VERSION > HAVE_ONLINE_STATUS_VERSION:
config_data["remote"]["enable"] = t_remote_status config_data["remote"]["enable"] = t_remote_status
if PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
config_data["response_spliter"]["enable_response_spliter"] = t_enable_response_spliter
config_data["response_spliter"]["response_max_length"] = t_max_response_length
config_data["response_spliter"]["response_max_sentence_num"] = t_max_sentence_num
save_config_to_file(config_data) save_config_to_file(config_data)
logger.info("其他设置已保存到 bot_config.toml 文件中") logger.info("其他设置已保存到 bot_config.toml 文件中")
return "其他设置已保存" return "其他设置已保存"
@ -657,7 +791,6 @@ def save_group_config(
logger.info("群聊设置已保存到 bot_config.toml 文件中") logger.info("群聊设置已保存到 bot_config.toml 文件中")
return "群聊设置已保存" return "群聊设置已保存"
with gr.Blocks(title="MaimBot配置文件编辑") as app: with gr.Blocks(title="MaimBot配置文件编辑") as app:
gr.Markdown( gr.Markdown(
value=""" value="""
@ -997,11 +1130,33 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
inputs=personality_probability_change_inputs, inputs=personality_probability_change_inputs,
outputs=[warning_less_text], outputs=[warning_less_text],
) )
with gr.Row(): with gr.Row():
prompt_schedule = gr.Textbox( gr.Markdown("---")
label="日程生成提示词", value=config_data["personality"]["prompt_schedule"], interactive=True with gr.Row():
) gr.Markdown("麦麦提示词设置")
if PARSED_CONFIG_VERSION >= SCHEDULE_CHANGED_VERSION:
with gr.Row():
enable_schedule_gen = gr.Checkbox(value=config_data["schedule"]["enable_schedule_gen"],
label="是否开启麦麦日程生成(尚未完成)",
interactive=True
)
with gr.Row():
prompt_schedule_gen = gr.Textbox(
label="日程生成提示词", value=config_data["schedule"]["prompt_schedule_gen"], interactive=True
)
with gr.Row():
schedule_doing_update_interval = gr.Number(
value=config_data["schedule"]["schedule_doing_update_interval"],
label="日程表更新间隔 单位秒",
interactive=True
)
else:
with gr.Row():
prompt_schedule_gen = gr.Textbox(
label="日程生成提示词", value=config_data["personality"]["prompt_schedule"], interactive=True
)
enable_schedule_gen = gr.Checkbox(value=False,visible=False,interactive=False)
schedule_doing_update_interval = gr.Number(value=0,visible=False,interactive=False)
with gr.Row(): with gr.Row():
personal_save_btn = gr.Button( personal_save_btn = gr.Button(
"保存人格配置", "保存人格配置",
@ -1017,7 +1172,9 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
prompt_personality_1, prompt_personality_1,
prompt_personality_2, prompt_personality_2,
prompt_personality_3, prompt_personality_3,
prompt_schedule, enable_schedule_gen,
prompt_schedule_gen,
schedule_doing_update_interval,
personality_1_probability, personality_1_probability,
personality_2_probability, personality_2_probability,
personality_3_probability, personality_3_probability,
@ -1027,11 +1184,14 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
with gr.TabItem("3-消息&表情包设置"): with gr.TabItem("3-消息&表情包设置"):
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
with gr.Row(): if PARSED_CONFIG_VERSION < version.parse("0.0.11"):
min_text_length = gr.Number( with gr.Row():
value=config_data["message"]["min_text_length"], min_text_length = gr.Number(
label="与麦麦聊天时麦麦只会回答文本大于等于此数的消息", value=config_data["message"]["min_text_length"],
) label="与麦麦聊天时麦麦只会回答文本大于等于此数的消息",
)
else:
min_text_length = gr.Number(visible=False,value=0,interactive=False)
with gr.Row(): with gr.Row():
max_context_size = gr.Number( max_context_size = gr.Number(
value=config_data["message"]["max_context_size"], label="麦麦获得的上文数量" value=config_data["message"]["max_context_size"], label="麦麦获得的上文数量"
@ -1049,21 +1209,27 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["message"]["thinking_timeout"], value=config_data["message"]["thinking_timeout"],
label="麦麦正在思考时,如果超过此秒数,则停止思考", label="麦麦正在思考时,如果超过此秒数,则停止思考",
) )
with gr.Row(): if PARSED_CONFIG_VERSION < version.parse("0.0.11"):
response_willing_amplifier = gr.Number( with gr.Row():
value=config_data["message"]["response_willing_amplifier"], response_willing_amplifier = gr.Number(
label="麦麦回复意愿放大系数一般为1", value=config_data["message"]["response_willing_amplifier"],
) label="麦麦回复意愿放大系数一般为1",
with gr.Row(): )
response_interested_rate_amplifier = gr.Number( with gr.Row():
value=config_data["message"]["response_interested_rate_amplifier"], response_interested_rate_amplifier = gr.Number(
label="麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数", value=config_data["message"]["response_interested_rate_amplifier"],
) label="麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数",
with gr.Row(): )
down_frequency_rate = gr.Number( with gr.Row():
value=config_data["message"]["down_frequency_rate"], down_frequency_rate = gr.Number(
label="降低回复频率的群组回复意愿降低系数", value=config_data["message"]["down_frequency_rate"],
) label="降低回复频率的群组回复意愿降低系数",
)
else:
response_willing_amplifier = gr.Number(visible=False,value=0,interactive=False)
response_interested_rate_amplifier = gr.Number(visible=False,value=0,interactive=False)
down_frequency_rate = gr.Number(visible=False,value=0,interactive=False)
with gr.Row(): with gr.Row():
gr.Markdown("### 违禁词列表") gr.Markdown("### 违禁词列表")
with gr.Row(): with gr.Row():
@ -1207,7 +1373,7 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
], ],
outputs=[emoji_save_message], outputs=[emoji_save_message],
) )
with gr.TabItem("4-回复&模型设置"): with gr.TabItem("4-意愿设置"):
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
with gr.Row(): with gr.Row():
@ -1229,6 +1395,55 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
) )
else: else:
willing_mode = gr.Textbox(visible=False, value="disabled") willing_mode = gr.Textbox(visible=False, value="disabled")
if PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
with gr.Row():
response_willing_amplifier = gr.Number(
value=config_data["willing"]["response_willing_amplifier"],
label="麦麦回复意愿放大系数一般为1",
)
with gr.Row():
response_interested_rate_amplifier = gr.Number(
value=config_data["willing"]["response_interested_rate_amplifier"],
label="麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数",
)
with gr.Row():
down_frequency_rate = gr.Number(
value=config_data["willing"]["down_frequency_rate"],
label="降低回复频率的群组回复意愿降低系数",
)
with gr.Row():
emoji_response_penalty = gr.Number(
value=config_data["willing"]["emoji_response_penalty"],
label="表情包回复惩罚系数设为0为不回复单个表情包减少单独回复表情包的概率",
)
else:
response_willing_amplifier = gr.Number(visible=False, value=1.0)
response_interested_rate_amplifier = gr.Number(visible=False, value=1.0)
down_frequency_rate = gr.Number(visible=False, value=1.0)
emoji_response_penalty = gr.Number(visible=False, value=1.0)
with gr.Row():
willing_save_btn = gr.Button(
"保存意愿设置设置",
variant="primary",
elem_id="save_personality_btn",
elem_classes="save_personality_btn",
)
with gr.Row():
willing_save_message = gr.Textbox(label="意愿设置保存结果")
willing_save_btn.click(
save_willing_config,
inputs=[
willing_mode,
response_willing_amplifier,
response_interested_rate_amplifier,
down_frequency_rate,
emoji_response_penalty,
],
outputs=[emoji_save_message],
)
with gr.TabItem("4-回复&模型设置"):
with gr.Row():
with gr.Column(scale=3):
with gr.Row(): with gr.Row():
model_r1_probability = gr.Slider( model_r1_probability = gr.Slider(
minimum=0, minimum=0,
@ -1289,10 +1504,13 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
inputs=[model_r1_probability, model_r2_probability, model_r3_probability], inputs=[model_r1_probability, model_r2_probability, model_r3_probability],
outputs=[model_warning_less_text], outputs=[model_warning_less_text],
) )
with gr.Row(): if PARSED_CONFIG_VERSION <= version.parse("0.0.10"):
max_response_length = gr.Number( with gr.Row():
value=config_data["response"]["max_response_length"], label="麦麦回答的最大token数" max_response_length = gr.Number(
) value=config_data["response"]["max_response_length"], label="麦麦回答的最大token数"
)
else:
max_response_length = gr.Number(visible=False,value=0)
with gr.Row(): with gr.Row():
gr.Markdown("""### 模型设置""") gr.Markdown("""### 模型设置""")
with gr.Row(): with gr.Row():
@ -1336,6 +1554,16 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["llm_normal"]["provider"], value=config_data["model"]["llm_normal"]["provider"],
label="模型2提供商", label="模型2提供商",
) )
with gr.Row():
model2_pri_in = gr.Number(
value=config_data["model"]["llm_normal"]["pri_in"],
label="模型2次要回复模型的输入价格非必填可以记录消耗",
)
with gr.Row():
model2_pri_out = gr.Number(
value=config_data["model"]["llm_normal"]["pri_out"],
label="模型2次要回复模型的输出价格非必填可以记录消耗",
)
with gr.TabItem("3-次要模型"): with gr.TabItem("3-次要模型"):
with gr.Row(): with gr.Row():
model3_name = gr.Textbox( model3_name = gr.Textbox(
@ -1347,6 +1575,16 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["llm_reasoning_minor"]["provider"], value=config_data["model"]["llm_reasoning_minor"]["provider"],
label="模型3提供商", label="模型3提供商",
) )
with gr.Row():
model3_pri_in = gr.Number(
value=config_data["model"]["llm_reasoning_minor"]["pri_in"],
label="模型3次要回复模型的输入价格非必填可以记录消耗",
)
with gr.Row():
model3_pri_out = gr.Number(
value=config_data["model"]["llm_reasoning_minor"]["pri_out"],
label="模型3次要回复模型的输出价格非必填可以记录消耗",
)
with gr.TabItem("4-情感&主题模型"): with gr.TabItem("4-情感&主题模型"):
with gr.Row(): with gr.Row():
gr.Markdown("""### 情感模型设置""") gr.Markdown("""### 情感模型设置""")
@ -1360,6 +1598,16 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["llm_emotion_judge"]["provider"], value=config_data["model"]["llm_emotion_judge"]["provider"],
label="情感模型提供商", label="情感模型提供商",
) )
with gr.Row():
emotion_model_pri_in = gr.Number(
value=config_data["model"]["llm_emotion_judge"]["pri_in"],
label="情感模型的输入价格(非必填,可以记录消耗)",
)
with gr.Row():
emotion_model_pri_out = gr.Number(
value=config_data["model"]["llm_emotion_judge"]["pri_out"],
label="情感模型的输出价格(非必填,可以记录消耗)",
)
with gr.Row(): with gr.Row():
gr.Markdown("""### 主题模型设置""") gr.Markdown("""### 主题模型设置""")
with gr.Row(): with gr.Row():
@ -1372,6 +1620,18 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["llm_topic_judge"]["provider"], value=config_data["model"]["llm_topic_judge"]["provider"],
label="主题判断模型提供商", label="主题判断模型提供商",
) )
with gr.Row():
topic_judge_model_pri_in = gr.Number(
value=config_data["model"]["llm_topic_judge"]["pri_in"],
label="主题判断模型的输入价格(非必填,可以记录消耗)",
)
with gr.Row():
topic_judge_model_pri_out = gr.Number(
value=config_data["model"]["llm_topic_judge"]["pri_out"],
label="主题判断模型的输出价格(非必填,可以记录消耗)",
)
with gr.Row():
gr.Markdown("""### 主题总结模型设置""")
with gr.Row(): with gr.Row():
summary_by_topic_model_name = gr.Textbox( summary_by_topic_model_name = gr.Textbox(
value=config_data["model"]["llm_summary_by_topic"]["name"], label="主题总结模型名称" value=config_data["model"]["llm_summary_by_topic"]["name"], label="主题总结模型名称"
@ -1382,6 +1642,16 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["llm_summary_by_topic"]["provider"], value=config_data["model"]["llm_summary_by_topic"]["provider"],
label="主题总结模型提供商", label="主题总结模型提供商",
) )
with gr.Row():
summary_by_topic_model_pri_in = gr.Number(
value=config_data["model"]["llm_summary_by_topic"]["pri_in"],
label="主题总结模型的输入价格(非必填,可以记录消耗)",
)
with gr.Row():
summary_by_topic_model_pri_out = gr.Number(
value=config_data["model"]["llm_summary_by_topic"]["pri_out"],
label="主题总结模型的输出价格(非必填,可以记录消耗)",
)
with gr.TabItem("5-识图模型"): with gr.TabItem("5-识图模型"):
with gr.Row(): with gr.Row():
gr.Markdown("""### 识图模型设置""") gr.Markdown("""### 识图模型设置""")
@ -1395,6 +1665,16 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["model"]["vlm"]["provider"], value=config_data["model"]["vlm"]["provider"],
label="识图模型提供商", label="识图模型提供商",
) )
with gr.Row():
vlm_model_pri_in = gr.Number(
value=config_data["model"]["vlm"]["pri_in"],
label="识图模型的输入价格(非必填,可以记录消耗)",
)
with gr.Row():
vlm_model_pri_out = gr.Number(
value=config_data["model"]["vlm"]["pri_out"],
label="识图模型的输出价格(非必填,可以记录消耗)",
)
with gr.Row(): with gr.Row():
save_model_btn = gr.Button("保存回复&模型设置", variant="primary", elem_id="save_model_btn") save_model_btn = gr.Button("保存回复&模型设置", variant="primary", elem_id="save_model_btn")
with gr.Row(): with gr.Row():
@ -1413,16 +1693,28 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
model1_pri_out, model1_pri_out,
model2_name, model2_name,
model2_provider, model2_provider,
model2_pri_in,
model2_pri_out,
model3_name, model3_name,
model3_provider, model3_provider,
model3_pri_in,
model3_pri_out,
emotion_model_name, emotion_model_name,
emotion_model_provider, emotion_model_provider,
emotion_model_pri_in,
emotion_model_pri_out,
topic_judge_model_name, topic_judge_model_name,
topic_judge_model_provider, topic_judge_model_provider,
topic_judge_model_pri_in,
topic_judge_model_pri_out,
summary_by_topic_model_name, summary_by_topic_model_name,
summary_by_topic_model_provider, summary_by_topic_model_provider,
summary_by_topic_model_pri_in,
summary_by_topic_model_pri_out,
vlm_model_name, vlm_model_name,
vlm_model_provider, vlm_model_provider,
vlm_model_pri_in,
vlm_model_pri_out,
], ],
outputs=[save_btn_message], outputs=[save_btn_message],
) )
@ -1436,6 +1728,79 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
value=config_data["memory"]["build_memory_interval"], value=config_data["memory"]["build_memory_interval"],
label="记忆构建间隔 单位秒,间隔越低,麦麦学习越多,但是冗余信息也会增多", label="记忆构建间隔 单位秒,间隔越低,麦麦学习越多,但是冗余信息也会增多",
) )
if PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
with gr.Row():
gr.Markdown("---")
with gr.Row():
gr.Markdown("""### 记忆构建分布设置""")
with gr.Row():
gr.Markdown("""记忆构建分布参数说明:\n
分布1均值第一个正态分布的均值\n
分布1标准差第一个正态分布的标准差\n
分布1权重第一个正态分布的权重\n
分布2均值第二个正态分布的均值\n
分布2标准差第二个正态分布的标准差\n
分布2权重第二个正态分布的权重
""")
with gr.Row():
with gr.Column(scale=1):
build_memory_dist1_mean = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[0],
label="分布1均值",
)
with gr.Column(scale=1):
build_memory_dist1_std = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[1],
label="分布1标准差",
)
with gr.Column(scale=1):
build_memory_dist1_weight = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[2],
label="分布1权重",
)
with gr.Row():
with gr.Column(scale=1):
build_memory_dist2_mean = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[3],
label="分布2均值",
)
with gr.Column(scale=1):
build_memory_dist2_std = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[4],
label="分布2标准差",
)
with gr.Column(scale=1):
build_memory_dist2_weight = gr.Number(
value=config_data["memory"].get(
"build_memory_distribution",
[4.0,2.0,0.6,24.0,8.0,0.4]
)[5],
label="分布2权重",
)
with gr.Row():
gr.Markdown("---")
else:
build_memory_dist1_mean = gr.Number(value=0.0,visible=False,interactive=False)
build_memory_dist1_std = gr.Number(value=0.0,visible=False,interactive=False)
build_memory_dist1_weight = gr.Number(value=0.0,visible=False,interactive=False)
build_memory_dist2_mean = gr.Number(value=0.0,visible=False,interactive=False)
build_memory_dist2_std = gr.Number(value=0.0,visible=False,interactive=False)
build_memory_dist2_weight = gr.Number(value=0.0,visible=False,interactive=False)
with gr.Row(): with gr.Row():
memory_compress_rate = gr.Number( memory_compress_rate = gr.Number(
value=config_data["memory"]["memory_compress_rate"], value=config_data["memory"]["memory_compress_rate"],
@ -1538,6 +1903,12 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
mood_update_interval, mood_update_interval,
mood_decay_rate, mood_decay_rate,
mood_intensity_factor, mood_intensity_factor,
build_memory_dist1_mean,
build_memory_dist1_std,
build_memory_dist1_weight,
build_memory_dist2_mean,
build_memory_dist2_std,
build_memory_dist2_weight,
], ],
outputs=[save_memory_mood_message], outputs=[save_memory_mood_message],
) )
@ -1709,22 +2080,31 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
keywords_reaction_enabled = gr.Checkbox( keywords_reaction_enabled = gr.Checkbox(
value=config_data["keywords_reaction"]["enable"], label="是否针对某个关键词作出反应" value=config_data["keywords_reaction"]["enable"], label="是否针对某个关键词作出反应"
) )
with gr.Row(): if PARSED_CONFIG_VERSION <= version.parse("0.0.10"):
enable_advance_output = gr.Checkbox( with gr.Row():
value=config_data["others"]["enable_advance_output"], label="是否开启高级输出" enable_advance_output = gr.Checkbox(
) value=config_data["others"]["enable_advance_output"], label="是否开启高级输出"
with gr.Row(): )
enable_kuuki_read = gr.Checkbox( with gr.Row():
value=config_data["others"]["enable_kuuki_read"], label="是否启用读空气功能" enable_kuuki_read = gr.Checkbox(
) value=config_data["others"]["enable_kuuki_read"], label="是否启用读空气功能"
with gr.Row(): )
enable_debug_output = gr.Checkbox( with gr.Row():
value=config_data["others"]["enable_debug_output"], label="是否开启调试输出" enable_debug_output = gr.Checkbox(
) value=config_data["others"]["enable_debug_output"], label="是否开启调试输出"
with gr.Row(): )
enable_friend_chat = gr.Checkbox( with gr.Row():
value=config_data["others"]["enable_friend_chat"], label="是否开启好友聊天" enable_friend_chat = gr.Checkbox(
) value=config_data["others"]["enable_friend_chat"], label="是否开启好友聊天"
)
elif PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
with gr.Row():
enable_friend_chat = gr.Checkbox(
value=config_data["experimental"]["enable_friend_chat"], label="是否开启好友聊天"
)
enable_advance_output = gr.Checkbox(value=False,visible=False,interactive=False)
enable_kuuki_read = gr.Checkbox(value=False,visible=False,interactive=False)
enable_debug_output = gr.Checkbox(value=False,visible=False,interactive=False)
if PARSED_CONFIG_VERSION > HAVE_ONLINE_STATUS_VERSION: if PARSED_CONFIG_VERSION > HAVE_ONLINE_STATUS_VERSION:
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
@ -1736,7 +2116,28 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
remote_status = gr.Checkbox( remote_status = gr.Checkbox(
value=config_data["remote"]["enable"], label="是否开启麦麦在线全球统计" value=config_data["remote"]["enable"], label="是否开启麦麦在线全球统计"
) )
if PARSED_CONFIG_VERSION >= version.parse("0.0.11"):
with gr.Row():
gr.Markdown("""### 回复分割器设置""")
with gr.Row():
enable_response_spliter = gr.Checkbox(
value=config_data["response_spliter"]["enable_response_spliter"],
label="是否启用回复分割器"
)
with gr.Row():
response_max_length = gr.Number(
value=config_data["response_spliter"]["response_max_length"],
label="回复允许的最大长度"
)
with gr.Row():
response_max_sentence_num = gr.Number(
value=config_data["response_spliter"]["response_max_sentence_num"],
label="回复允许的最大句子数"
)
else:
enable_response_spliter = gr.Checkbox(value=False,visible=False,interactive=False)
response_max_length = gr.Number(value=0,visible=False,interactive=False)
response_max_sentence_num = gr.Number(value=0,visible=False,interactive=False)
with gr.Row(): with gr.Row():
gr.Markdown("""### 中文错别字设置""") gr.Markdown("""### 中文错别字设置""")
with gr.Row(): with gr.Row():
@ -1790,14 +2191,56 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
tone_error_rate, tone_error_rate,
word_replace_rate, word_replace_rate,
remote_status, remote_status,
enable_response_spliter,
response_max_length,
response_max_sentence_num
], ],
outputs=[save_other_config_message], outputs=[save_other_config_message],
) )
app.queue().launch( # concurrency_count=511, max_size=1022 # 检查端口是否可用
server_name="0.0.0.0", def is_port_available(port, host='0.0.0.0'):
inbrowser=True, """检查指定的端口是否可用"""
share=is_share, try:
server_port=7000, # 创建一个socket对象
debug=debug, sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
quiet=True, # 设置socket重用地址选项
) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# 尝试绑定端口
sock.bind((host, port))
# 如果成功绑定则关闭socket并返回True
sock.close()
return True
except socket.error:
# 如果绑定失败,说明端口已被占用
return False
# 寻找可用端口
def find_available_port(start_port=7000, max_port=8000):
"""
从start_port开始寻找可用的端口
如果端口被占用尝试下一个端口直到找到可用端口或达到max_port
"""
port = start_port
while port <= max_port:
if is_port_available(port):
logger.info(f"找到可用端口: {port}")
return port
logger.warning(f"端口 {port} 已被占用,尝试下一个端口")
port += 1
# 如果所有端口都被占用返回None
logger.error(f"无法找到可用端口 (已尝试 {start_port}-{max_port})")
return None
# 寻找可用端口
launch_port = find_available_port(7000, 8000) or 7000
app.queue().launch( # concurrency_count=511, max_size=1022
server_name="0.0.0.0",
inbrowser=True,
share=is_share,
server_port=launch_port,
debug=debug,
quiet=True,
)

View File

@ -1,8 +1,7 @@
import tomli import tomli
import sys import sys
import re
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List, Set, Tuple from typing import Dict, Any, List, Tuple
def load_toml_file(file_path: str) -> Dict[str, Any]: def load_toml_file(file_path: str) -> Dict[str, Any]:
"""加载TOML文件""" """加载TOML文件"""
@ -184,10 +183,15 @@ def check_model_configurations(config: Dict[str, Any], env_vars: Dict[str, str])
provider = model_config["provider"].upper() provider = model_config["provider"].upper()
# 检查拼写错误 # 检查拼写错误
for known_provider, correct_provider in reverse_mapping.items(): for known_provider, _correct_provider in reverse_mapping.items():
# 使用模糊匹配检测拼写错误 # 使用模糊匹配检测拼写错误
if provider != known_provider and _similar_strings(provider, known_provider) and provider not in reverse_mapping: if (provider != known_provider and
errors.append(f"[model.{model_name}]的provider '{model_config['provider']}' 可能拼写错误,应为 '{known_provider}'") _similar_strings(provider, known_provider) and
provider not in reverse_mapping):
errors.append(
f"[model.{model_name}]的provider '{model_config['provider']}' "
f"可能拼写错误,应为 '{known_provider}'"
)
break break
return errors return errors
@ -223,7 +227,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 检查配置文件中使用的所有提供商 # 检查配置文件中使用的所有提供商
used_providers = set() used_providers = set()
for model_category, model_config in config["model"].items(): for _model_category, model_config in config["model"].items():
if "provider" in model_config: if "provider" in model_config:
provider = model_config["provider"] provider = model_config["provider"]
used_providers.add(provider) used_providers.add(provider)
@ -247,7 +251,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 特别检查常见的拼写错误 # 特别检查常见的拼写错误
for provider in used_providers: for provider in used_providers:
if provider.upper() == "SILICONFOLW": if provider.upper() == "SILICONFOLW":
errors.append(f"提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'") errors.append("提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
return errors return errors
@ -272,7 +276,7 @@ def check_groups_configuration(config: Dict[str, Any]) -> List[str]:
"main": "groups.talk_allowed中存在默认示例值'123',请修改为真实的群号", "main": "groups.talk_allowed中存在默认示例值'123',请修改为真实的群号",
"details": [ "details": [
f" 当前值: {groups['talk_allowed']}", f" 当前值: {groups['talk_allowed']}",
f" '123'为示例值,需要替换为真实群号" " '123'为示例值,需要替换为真实群号"
] ]
}) })
@ -371,7 +375,8 @@ def check_memory_config(config: Dict[str, Any]) -> List[str]:
if "memory_compress_rate" in memory and (memory["memory_compress_rate"] <= 0 or memory["memory_compress_rate"] > 1): if "memory_compress_rate" in memory and (memory["memory_compress_rate"] <= 0 or memory["memory_compress_rate"] > 1):
errors.append(f"memory.memory_compress_rate值无效: {memory['memory_compress_rate']}, 应在0-1之间") errors.append(f"memory.memory_compress_rate值无效: {memory['memory_compress_rate']}, 应在0-1之间")
if "memory_forget_percentage" in memory and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1): if ("memory_forget_percentage" in memory
and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1)):
errors.append(f"memory.memory_forget_percentage值无效: {memory['memory_forget_percentage']}, 应在0-1之间") errors.append(f"memory.memory_forget_percentage值无效: {memory['memory_forget_percentage']}, 应在0-1之间")
return errors return errors
@ -393,7 +398,10 @@ def check_personality_config(config: Dict[str, Any]) -> List[str]:
else: else:
# 检查数组长度 # 检查数组长度
if len(personality["prompt_personality"]) < 1: if len(personality["prompt_personality"]) < 1:
errors.append(f"personality.prompt_personality数组长度不足当前长度: {len(personality['prompt_personality'])}, 需要至少1项") errors.append(
f"personality.prompt_personality至少需要1项"
f"当前长度: {len(personality['prompt_personality'])}"
)
else: else:
# 模板默认值 # 模板默认值
template_values = [ template_values = [
@ -452,10 +460,13 @@ def check_bot_config(config: Dict[str, Any]) -> List[str]:
def format_results(all_errors): def format_results(all_errors):
"""格式化检查结果""" """格式化检查结果"""
sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors # noqa: E501, F821
bot_errors, bot_infos = bot_results bot_errors, bot_infos = bot_results
if not any([sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]): if not any([
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
result = "✅ 配置文件检查通过,未发现问题。" result = "✅ 配置文件检查通过,未发现问题。"
# 添加机器人信息 # 添加机器人信息
@ -574,7 +585,10 @@ def main():
bot_results = check_bot_config(config) bot_results = check_bot_config(config)
# 格式化并打印结果 # 格式化并打印结果
all_errors = (sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results) all_errors = (
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
result = format_results(all_errors) result = format_results(all_errors)
print("📋 机器人配置检查结果:") print("📋 机器人配置检查结果:")
print(result) print(result)
@ -586,7 +600,9 @@ def main():
bot_errors, _ = bot_results bot_errors, _ = bot_results
# 计算普通错误列表的长度 # 计算普通错误列表的长度
for errors in [sections_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]: for errors in [
sections_errors, model_errors, api_errors,
groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
total_errors += len(errors) total_errors += len(errors)
# 计算元组列表的长度(概率相关错误) # 计算元组列表的长度(概率相关错误)