mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/SnowindMe/MaiBot into dev
commit
774e72340d
|
|
@ -1,5 +1,12 @@
|
||||||
name: Ruff
|
name: Ruff
|
||||||
on: [ push ]
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- dev
|
||||||
|
- dev-refactor # 例如:匹配所有以 feature/ 开头的分支
|
||||||
|
# 添加你希望触发此 workflow 的其他分支
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
|
|
@ -7,6 +14,10 @@ permissions:
|
||||||
jobs:
|
jobs:
|
||||||
ruff:
|
ruff:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
# 关键修改:添加条件判断
|
||||||
|
# 确保只有在 event_name 是 'push' 且不是由 Pull Request 引起的 push 时才运行
|
||||||
|
if: github.event_name == 'push' && !startsWith(github.ref, 'refs/pull/')
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
|
@ -25,5 +36,4 @@ jobs:
|
||||||
git config --local user.name "github-actions[bot]"
|
git config --local user.name "github-actions[bot]"
|
||||||
git add -A
|
git add -A
|
||||||
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
|
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
|
||||||
git push
|
git push
|
||||||
|
|
||||||
|
|
@ -65,11 +65,11 @@
|
||||||
|
|
||||||
## 💬 讨论
|
## 💬 讨论
|
||||||
|
|
||||||
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
|
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
|
||||||
|
[四群](https://qm.qq.com/q/wGePTl1UyY) |
|
||||||
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
|
[二群](https://qm.qq.com/q/RzmCiRtHEW) |
|
||||||
[五群](https://qm.qq.com/q/JxvHZnxyec) |
|
[五群](https://qm.qq.com/q/JxvHZnxyec)(已满) |
|
||||||
[三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)|
|
[三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)
|
||||||
[四群](https://qm.qq.com/q/wGePTl1UyY)(已满)
|
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,11 @@
|
||||||
- 表达器:装饰语言风格
|
- 表达器:装饰语言风格
|
||||||
- 可通过插件添加和自定义HFC部件(目前只支持action定义)
|
- 可通过插件添加和自定义HFC部件(目前只支持action定义)
|
||||||
|
|
||||||
|
**插件系统**
|
||||||
|
- 添加示例插件
|
||||||
|
- 示例插件:禁言插件
|
||||||
|
- 示例插件:豆包绘图插件
|
||||||
|
|
||||||
**新增表达方式学习**
|
**新增表达方式学习**
|
||||||
- 自主学习群聊中的表达方式,更贴近群友
|
- 自主学习群聊中的表达方式,更贴近群友
|
||||||
- 可自定义的学习频率和开关
|
- 可自定义的学习频率和开关
|
||||||
|
|
@ -45,7 +50,6 @@
|
||||||
**优化**
|
**优化**
|
||||||
- 移除日程系统,减少幻觉(将会在未来版本回归)
|
- 移除日程系统,减少幻觉(将会在未来版本回归)
|
||||||
- 移除主心流思考和LLM进入聊天判定
|
- 移除主心流思考和LLM进入聊天判定
|
||||||
-
|
|
||||||
|
|
||||||
|
|
||||||
## [0.6.3-fix-4] - 2025-5-18
|
## [0.6.3-fix-4] - 2025-5-18
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
|
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
|
||||||
# 请小心使用任何一键脚本!
|
# 请小心使用任何一键脚本!
|
||||||
|
|
||||||
INSTALLER_VERSION="0.0.4-refactor"
|
INSTALLER_VERSION="0.0.5-refactor"
|
||||||
LANG=C.UTF-8
|
LANG=C.UTF-8
|
||||||
|
|
||||||
# 如无法访问GitHub请修改此处镜像地址
|
# 如无法访问GitHub请修改此处镜像地址
|
||||||
|
|
@ -33,7 +33,6 @@ SERVICE_NAME="maicore"
|
||||||
SERVICE_NAME_WEB="maicore-web"
|
SERVICE_NAME_WEB="maicore-web"
|
||||||
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
|
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
|
||||||
|
|
||||||
IS_INSTALL_MONGODB=false
|
|
||||||
IS_INSTALL_NAPCAT=false
|
IS_INSTALL_NAPCAT=false
|
||||||
IS_INSTALL_DEPENDENCIES=false
|
IS_INSTALL_DEPENDENCIES=false
|
||||||
|
|
||||||
|
|
@ -255,7 +254,6 @@ run_installation() {
|
||||||
return
|
return
|
||||||
elif [[ "$ID" == "arch" ]]; then
|
elif [[ "$ID" == "arch" ]]; then
|
||||||
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
||||||
whiptail --title "⚠️ 兼容性警告" --msgbox "MongoDB无可用的 Arch Linux 官方安装方法,将无法自动安装MongoDB。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
|
||||||
return
|
return
|
||||||
else
|
else
|
||||||
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
|
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
|
||||||
|
|
@ -282,16 +280,6 @@ run_installation() {
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
||||||
# 检查MongoDB
|
|
||||||
check_mongodb() {
|
|
||||||
if command -v mongod &>/dev/null; then
|
|
||||||
MONGO_INSTALLED=true
|
|
||||||
else
|
|
||||||
MONGO_INSTALLED=false
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
check_mongodb
|
|
||||||
|
|
||||||
# 检查NapCat
|
# 检查NapCat
|
||||||
check_napcat() {
|
check_napcat() {
|
||||||
if command -v napcat &>/dev/null; then
|
if command -v napcat &>/dev/null; then
|
||||||
|
|
@ -330,19 +318,7 @@ run_installation() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
install_packages
|
install_packages
|
||||||
|
|
||||||
# 安装MongoDB
|
|
||||||
install_mongodb() {
|
|
||||||
[[ $MONGO_INSTALLED == true ]] && return
|
|
||||||
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装MongoDB,是否安装?\n如果您想使用远程数据库,请跳过此步。" 10 60 && {
|
|
||||||
IS_INSTALL_MONGODB=true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 仅在非Arch系统上安装MongoDB
|
|
||||||
[[ "$ID" != "arch" ]] && install_mongodb
|
|
||||||
|
|
||||||
|
|
||||||
# 安装NapCat
|
# 安装NapCat
|
||||||
install_napcat() {
|
install_napcat() {
|
||||||
[[ $NAPCAT_INSTALLED == true ]] && return
|
[[ $NAPCAT_INSTALLED == true ]] && return
|
||||||
|
|
@ -413,9 +389,8 @@ run_installation() {
|
||||||
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
|
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
|
||||||
confirm_msg+="🔀 分支: $BRANCH\n"
|
confirm_msg+="🔀 分支: $BRANCH\n"
|
||||||
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
|
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
|
||||||
[[ $IS_INSTALL_MONGODB == true || $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
|
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
|
||||||
|
|
||||||
[[ $IS_INSTALL_MONGODB == true ]] && confirm_msg+=" - MongoDB\n"
|
|
||||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
|
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
|
||||||
confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
|
confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
|
||||||
|
|
||||||
|
|
@ -440,39 +415,6 @@ run_installation() {
|
||||||
esac
|
esac
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ $IS_INSTALL_MONGODB == true ]]; then
|
|
||||||
echo -e "${GREEN}安装 MongoDB...${RESET}"
|
|
||||||
case "$ID" in
|
|
||||||
debian)
|
|
||||||
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
|
|
||||||
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
|
|
||||||
apt update
|
|
||||||
apt install -y mongodb-org
|
|
||||||
systemctl enable --now mongod
|
|
||||||
;;
|
|
||||||
ubuntu)
|
|
||||||
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
|
|
||||||
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
|
|
||||||
apt update
|
|
||||||
apt install -y mongodb-org
|
|
||||||
systemctl enable --now mongod
|
|
||||||
;;
|
|
||||||
centos)
|
|
||||||
cat > /etc/yum.repos.d/mongodb-org-8.0.repo <<EOF
|
|
||||||
[mongodb-org-8.0]
|
|
||||||
name=MongoDB Repository
|
|
||||||
baseurl=https://repo.mongodb.org/yum/redhat/9/mongodb-org/8.0/x86_64/
|
|
||||||
gpgcheck=1
|
|
||||||
enabled=1
|
|
||||||
gpgkey=https://pgp.mongodb.com/server-8.0.asc
|
|
||||||
EOF
|
|
||||||
yum install -y mongodb-org
|
|
||||||
systemctl enable --now mongod
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $IS_INSTALL_NAPCAT == true ]]; then
|
if [[ $IS_INSTALL_NAPCAT == true ]]; then
|
||||||
echo -e "${GREEN}安装 NapCat...${RESET}"
|
echo -e "${GREEN}安装 NapCat...${RESET}"
|
||||||
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
|
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
|
||||||
|
|
@ -537,7 +479,7 @@ EOF
|
||||||
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
|
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=MaiCore
|
Description=MaiCore
|
||||||
After=network.target mongod.service ${SERVICE_NAME_NBADAPTER}.service
|
After=network.target ${SERVICE_NAME_NBADAPTER}.service
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=simple
|
Type=simple
|
||||||
|
|
@ -550,21 +492,21 @@ RestartSec=10s
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
|
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
|
||||||
[Unit]
|
# [Unit]
|
||||||
Description=MaiCore WebUI
|
# Description=MaiCore WebUI
|
||||||
After=network.target mongod.service ${SERVICE_NAME}.service
|
# After=network.target ${SERVICE_NAME}.service
|
||||||
|
|
||||||
[Service]
|
# [Service]
|
||||||
Type=simple
|
# Type=simple
|
||||||
WorkingDirectory=${INSTALL_DIR}/MaiBot
|
# WorkingDirectory=${INSTALL_DIR}/MaiBot
|
||||||
ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
|
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
|
||||||
Restart=always
|
# Restart=always
|
||||||
RestartSec=10s
|
# RestartSec=10s
|
||||||
|
|
||||||
[Install]
|
# [Install]
|
||||||
WantedBy=multi-user.target
|
# WantedBy=multi-user.target
|
||||||
EOF
|
# EOF
|
||||||
|
|
||||||
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
|
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
|
||||||
[Unit]
|
[Unit]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ MaiBot模块系统
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
||||||
|
|
||||||
# 导出主要组件供外部使用
|
# 导出主要组件供外部使用
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,11 @@ import re
|
||||||
|
|
||||||
# from gradio_client import file
|
# from gradio_client import file
|
||||||
|
|
||||||
from ...common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from ...common.database.database import db as peewee_db
|
from src.common.database.database import db as peewee_db
|
||||||
from ...config.config import global_config
|
from src.config.config import global_config
|
||||||
from ..utils.utils_image import image_path_to_base64, image_manager
|
from src.chat.utils.utils_image import image_path_to_base64, image_manager
|
||||||
from ..models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from src.chat.message_receive.message import Seg # Local import needed after mo
|
||||||
from src.chat.message_receive.message import UserInfo
|
from src.chat.message_receive.message import UserInfo
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import time
|
||||||
import random
|
import random
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||||
from src.chat.focus_chat.heartflow_prompt_builder import Prompt, global_prompt_manager
|
from src.chat.focus_chat.heartflow_prompt_builder import Prompt, global_prompt_manager
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
|
|
@ -30,13 +31,14 @@ from src.config.config import global_config
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
# 定义处理器映射
|
# 定义处理器映射:键是处理器名称,值是 (处理器类, 可选的配置键名)
|
||||||
|
# 如果配置键名为 None,则该处理器默认启用且不能通过 focus_chat_processor 配置禁用
|
||||||
PROCESSOR_CLASSES = {
|
PROCESSOR_CLASSES = {
|
||||||
"ChattingInfoProcessor": ChattingInfoProcessor,
|
"ChattingInfoProcessor": (ChattingInfoProcessor, None),
|
||||||
"MindProcessor": MindProcessor,
|
"MindProcessor": (MindProcessor, None),
|
||||||
"ToolProcessor": ToolProcessor,
|
"ToolProcessor": (ToolProcessor, "tool_use_processor"),
|
||||||
"WorkingMemoryProcessor": WorkingMemoryProcessor,
|
"WorkingMemoryProcessor": (WorkingMemoryProcessor, "working_memory_processor"),
|
||||||
"SelfProcessor": SelfProcessor,
|
"SelfProcessor": (SelfProcessor, "self_identify_processor"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -101,27 +103,19 @@ class HeartFChatting:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据配置文件和默认规则确定启用的处理器
|
# 根据配置文件和默认规则确定启用的处理器
|
||||||
PROCESSOR_NAME_TO_CONFIG_KEY_MAP = {
|
|
||||||
"SelfProcessor": "self_identify_processor",
|
|
||||||
"ToolProcessor": "tool_use_processor",
|
|
||||||
"WorkingMemoryProcessor": "working_memory_processor",
|
|
||||||
}
|
|
||||||
self.enabled_processor_names: List[str] = []
|
self.enabled_processor_names: List[str] = []
|
||||||
config_processor_settings = global_config.focus_chat_processor # 获取处理器配置,若不存在则为空字典
|
config_processor_settings = global_config.focus_chat_processor
|
||||||
|
|
||||||
for proc_name in PROCESSOR_CLASSES.keys():
|
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
|
||||||
config_key = PROCESSOR_NAME_TO_CONFIG_KEY_MAP.get(proc_name)
|
if config_key: # 此处理器可通过配置控制
|
||||||
if config_key:
|
if getattr(config_processor_settings, config_key, True): # 默认启用 (如果配置中未指定该键)
|
||||||
# 此处理器可通过配置控制
|
|
||||||
# getattr(config_processor_settings, config_key, True)
|
|
||||||
# 如果config_processor_settings是字典,则用 config_processor_settings.get(config_key, True)
|
|
||||||
if getattr(config_processor_settings, config_key, True): # 默认启用,如果配置中未指定
|
|
||||||
self.enabled_processor_names.append(proc_name)
|
self.enabled_processor_names.append(proc_name)
|
||||||
else:
|
else: # 此处理器不在配置映射中 (config_key is None),默认启用
|
||||||
# 此处理器不在配置映射中,默认启用
|
|
||||||
self.enabled_processor_names.append(proc_name)
|
self.enabled_processor_names.append(proc_name)
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
|
logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
|
||||||
|
self.processors: List[BaseProcessor] = []
|
||||||
|
self._register_default_processors()
|
||||||
|
|
||||||
self.expressor = DefaultExpressor(chat_id=self.stream_id)
|
self.expressor = DefaultExpressor(chat_id=self.stream_id)
|
||||||
self.action_manager = ActionManager()
|
self.action_manager = ActionManager()
|
||||||
|
|
@ -130,9 +124,6 @@ class HeartFChatting:
|
||||||
self.hfcloop_observation.set_action_manager(self.action_manager)
|
self.hfcloop_observation.set_action_manager(self.action_manager)
|
||||||
|
|
||||||
self.all_observations = observations
|
self.all_observations = observations
|
||||||
# --- 处理器列表 ---
|
|
||||||
self.processors: List[BaseProcessor] = []
|
|
||||||
self._register_default_processors()
|
|
||||||
|
|
||||||
# 初始化状态控制
|
# 初始化状态控制
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
@ -186,19 +177,20 @@ class HeartFChatting:
|
||||||
"""根据 self.enabled_processor_names 注册信息处理器"""
|
"""根据 self.enabled_processor_names 注册信息处理器"""
|
||||||
self.processors = [] # 清空已有的
|
self.processors = [] # 清空已有的
|
||||||
|
|
||||||
# self.enabled_processor_names 由 __init__ 保证是一个列表
|
|
||||||
for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc.
|
for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc.
|
||||||
processor_class = PROCESSOR_CLASSES.get(name)
|
processor_info = PROCESSOR_CLASSES.get(name) # processor_info is (ProcessorClass, config_key)
|
||||||
if processor_class:
|
if processor_info:
|
||||||
|
processor_actual_class = processor_info[0] # 获取实际的类定义
|
||||||
# 根据处理器类名判断是否需要 subheartflow_id
|
# 根据处理器类名判断是否需要 subheartflow_id
|
||||||
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor"]:
|
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor"]:
|
||||||
self.processors.append(processor_class(subheartflow_id=self.stream_id))
|
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
|
||||||
elif name == "ChattingInfoProcessor":
|
elif name == "ChattingInfoProcessor":
|
||||||
self.processors.append(processor_class())
|
self.processors.append(processor_actual_class())
|
||||||
else:
|
else:
|
||||||
# 对于PROCESSOR_CLASSES中定义但此处未明确处理构造的处理器
|
# 对于PROCESSOR_CLASSES中定义但此处未明确处理构造的处理器
|
||||||
|
# (例如, 新增了一个处理器到PROCESSOR_CLASSES, 它不需要id, 也不叫ChattingInfoProcessor)
|
||||||
try:
|
try:
|
||||||
self.processors.append(processor_class()) # 尝试无参构造
|
self.processors.append(processor_actual_class()) # 尝试无参构造
|
||||||
logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).")
|
logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).")
|
||||||
except TypeError:
|
except TypeError:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -206,7 +198,9 @@ class HeartFChatting:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 这理论上不应该发生,因为 enabled_processor_names 是从 PROCESSOR_CLASSES 的键生成的
|
# 这理论上不应该发生,因为 enabled_processor_names 是从 PROCESSOR_CLASSES 的键生成的
|
||||||
logger.warning(f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器,将跳过注册。")
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器定义,将跳过注册。"
|
||||||
|
)
|
||||||
|
|
||||||
if self.processors:
|
if self.processors:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -286,8 +280,9 @@ class HeartFChatting:
|
||||||
thinking_id = "tid" + str(round(time.time(), 2))
|
thinking_id = "tid" + str(round(time.time(), 2))
|
||||||
self._current_cycle.set_thinking_id(thinking_id)
|
self._current_cycle.set_thinking_id(thinking_id)
|
||||||
# 主循环:思考->决策->执行
|
# 主循环:思考->决策->执行
|
||||||
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
|
logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
|
||||||
|
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
|
||||||
|
|
||||||
self._current_cycle.set_loop_info(loop_info)
|
self._current_cycle.set_loop_info(loop_info)
|
||||||
|
|
||||||
|
|
@ -430,7 +425,10 @@ class HeartFChatting:
|
||||||
self.all_observations = observations
|
self.all_observations = observations
|
||||||
|
|
||||||
with Timer("回忆", cycle_timers):
|
with Timer("回忆", cycle_timers):
|
||||||
|
logger.debug(f"{self.log_prefix} 开始回忆")
|
||||||
running_memorys = await self.memory_activator.activate_memory(observations)
|
running_memorys = await self.memory_activator.activate_memory(observations)
|
||||||
|
logger.debug(f"{self.log_prefix} 回忆完成")
|
||||||
|
print(running_memorys)
|
||||||
|
|
||||||
with Timer("执行 信息处理器", cycle_timers):
|
with Timer("执行 信息处理器", cycle_timers):
|
||||||
all_plan_info = await self._process_processors(observations, running_memorys, cycle_timers)
|
all_plan_info = await self._process_processors(observations, running_memorys, cycle_timers)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from ..message_receive.chat_stream import chat_manager
|
||||||
|
|
||||||
# from ..message_receive.message_buffer import message_buffer
|
# from ..message_receive.message_buffer import message_buffer
|
||||||
from ..utils.timer_calculator import Timer
|
from ..utils.timer_calculator import Timer
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from typing import Optional, Tuple, Dict, Any
|
from typing import Optional, Tuple, Dict, Any
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from src.common.logger_manager import get_logger
|
||||||
from src.individuality.individuality import individuality
|
from src.individuality.individuality import individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from src.chat.utils.utils import get_recent_group_speaker
|
from src.chat.utils.utils import get_recent_group_speaker
|
||||||
|
|
@ -120,7 +120,6 @@ class PromptBuilder:
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
|
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
|
||||||
|
|
||||||
mood_prompt = mood_manager.get_mood_prompt()
|
mood_prompt = mood_manager.get_mood_prompt()
|
||||||
reply_styles1 = [
|
reply_styles1 = [
|
||||||
("然后给出日常且口语化的回复,平淡一些", 0.4),
|
("然后给出日常且口语化的回复,平淡一些", 0.4),
|
||||||
|
|
@ -141,9 +140,11 @@ class PromptBuilder:
|
||||||
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
|
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
|
||||||
)[0]
|
)[0]
|
||||||
memory_prompt = ""
|
memory_prompt = ""
|
||||||
|
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
)
|
)
|
||||||
|
|
||||||
related_memory_info = ""
|
related_memory_info = ""
|
||||||
if related_memory:
|
if related_memory:
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
|
|
|
||||||
|
|
@ -37,4 +37,4 @@ class SelfInfo(InfoBase):
|
||||||
Returns:
|
Returns:
|
||||||
str: 处理后的信息
|
str: 处理后的信息
|
||||||
"""
|
"""
|
||||||
return self.get_self_info()
|
return self.get_self_info() or ""
|
||||||
|
|
|
||||||
|
|
@ -67,3 +67,16 @@ class StructuredInfo:
|
||||||
value: 要设置的属性值
|
value: 要设置的属性值
|
||||||
"""
|
"""
|
||||||
self.data[key] = value
|
self.data[key] = value
|
||||||
|
|
||||||
|
def get_processed_info(self) -> str:
|
||||||
|
"""获取处理后的信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 处理后的信息字符串
|
||||||
|
"""
|
||||||
|
|
||||||
|
info_str = ""
|
||||||
|
for key, value in self.data.items():
|
||||||
|
info_str += f"信息类型:{key},信息内容:{value}\n"
|
||||||
|
|
||||||
|
return info_str
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from .base_processor import BaseProcessor
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
|
||||||
from src.chat.focus_chat.info.cycle_info import CycleInfo
|
from src.chat.focus_chat.info.cycle_info import CycleInfo
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("processor")
|
logger = get_logger("processor")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -9,7 +9,7 @@ from src.individuality.individuality import individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.json_utils import safe_json_dumps
|
from src.chat.utils.json_utils import safe_json_dumps
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -8,7 +8,7 @@ from src.common.logger_manager import get_logger
|
||||||
from src.individuality.individuality import individuality
|
from src.individuality.individuality import individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
|
|
@ -33,12 +33,13 @@ def init_prompt():
|
||||||
|
|
||||||
现在请你根据现有的信息,思考自我认同
|
现在请你根据现有的信息,思考自我认同
|
||||||
1. 你是一个什么样的人,你和群里的人关系如何
|
1. 你是一个什么样的人,你和群里的人关系如何
|
||||||
2. 思考有没有人提到你,或者图片与你有关
|
2. 你的形象是什么
|
||||||
3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
|
3. 思考有没有人提到你,或者图片与你有关
|
||||||
4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
|
4. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十几个字的简短自我认同
|
||||||
|
5. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
|
||||||
|
|
||||||
请思考的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
|
输出内容平淡一些,说中文,不要浮夸,平淡一些。
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出自我认同内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出自我认同内容,记得明确说明这是你的自我认同。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Prompt(indentify_prompt, "indentify_prompt")
|
Prompt(indentify_prompt, "indentify_prompt")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import time
|
import time
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
|
@ -7,7 +7,7 @@ from src.individuality.individuality import individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.tools.tool_use import ToolUser
|
from src.tools.tool_use import ToolUser
|
||||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.heart_flow.observation.structure_observation import StructureObservation
|
from src.chat.heart_flow.observation.structure_observation import StructureObservation
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt
|
from src.chat.utils.prompt_builder import Prompt
|
||||||
|
|
@ -61,6 +61,8 @@ class MemoryActivator:
|
||||||
elif isinstance(observation, HFCloopObservation):
|
elif isinstance(observation, HFCloopObservation):
|
||||||
obs_info_text += observation.get_observe_info()
|
obs_info_text += observation.get_observe_info()
|
||||||
|
|
||||||
|
logger.debug(f"回忆待检索内容:obs_info_text: {obs_info_text}")
|
||||||
|
|
||||||
# prompt = await global_prompt_manager.format_prompt(
|
# prompt = await global_prompt_manager.format_prompt(
|
||||||
# "memory_activator_prompt",
|
# "memory_activator_prompt",
|
||||||
# obs_info_text=obs_info_text,
|
# obs_info_text=obs_info_text,
|
||||||
|
|
@ -81,7 +83,7 @@ class MemoryActivator:
|
||||||
# valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
# valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
||||||
# )
|
# )
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=obs_info_text, max_memory_num=3, max_memory_length=2, max_depth=3, fast_retrieval=True
|
text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# logger.debug(f"获取到的记忆: {related_memory}")
|
# logger.debug(f"获取到的记忆: {related_memory}")
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ class ExitFocusChatAction(BaseAction):
|
||||||
if self.sub_heartflow:
|
if self.sub_heartflow:
|
||||||
try:
|
try:
|
||||||
# 转换为normal_chat状态
|
# 转换为normal_chat状态
|
||||||
await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT)
|
await self.sub_heartflow.change_chat_state(ChatState.CHAT)
|
||||||
status_message = "已成功切换到普通聊天模式"
|
status_message = "已成功切换到普通聊天模式"
|
||||||
logger.info(f"{self.log_prefix} {status_message}")
|
logger.info(f"{self.log_prefix} {status_message}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple, Dict, List, Any, Optional
|
from typing import Tuple, Dict, List, Any, Optional
|
||||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction
|
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action # noqa F401
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.person_info.person_info import person_info_manager
|
from src.person_info.person_info import person_info_manager
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
import toml # 导入 toml 库
|
||||||
|
|
||||||
logger = get_logger("plugin_action")
|
logger = get_logger("plugin_action")
|
||||||
|
|
||||||
|
|
@ -16,12 +19,24 @@ class PluginAction(BaseAction):
|
||||||
封装了主程序内部依赖,提供简化的API接口给插件开发者
|
封装了主程序内部依赖,提供简化的API接口给插件开发者
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
|
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_data: dict,
|
||||||
|
reasoning: str,
|
||||||
|
cycle_timers: dict,
|
||||||
|
thinking_id: str,
|
||||||
|
global_config: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""初始化插件动作基类"""
|
"""初始化插件动作基类"""
|
||||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||||
|
|
||||||
# 存储内部服务和对象引用
|
# 存储内部服务和对象引用
|
||||||
self._services = {}
|
self._services = {}
|
||||||
|
self._global_config = global_config # 存储全局配置的只读引用
|
||||||
|
self.config: Dict[str, Any] = {} # 用于存储插件自身的配置
|
||||||
|
|
||||||
# 从kwargs提取必要的内部服务
|
# 从kwargs提取必要的内部服务
|
||||||
if "observations" in kwargs:
|
if "observations" in kwargs:
|
||||||
|
|
@ -32,6 +47,61 @@ class PluginAction(BaseAction):
|
||||||
self._services["chat_stream"] = kwargs["chat_stream"]
|
self._services["chat_stream"] = kwargs["chat_stream"]
|
||||||
|
|
||||||
self.log_prefix = kwargs.get("log_prefix", "")
|
self.log_prefix = kwargs.get("log_prefix", "")
|
||||||
|
self._load_plugin_config() # 初始化时加载插件配置
|
||||||
|
|
||||||
|
def _load_plugin_config(self):
|
||||||
|
"""
|
||||||
|
加载插件自身的配置文件。
|
||||||
|
配置文件应与插件模块在同一目录下。
|
||||||
|
插件可以通过覆盖 `action_config_file_name` 类属性来指定文件名。
|
||||||
|
如果 `action_config_file_name` 未指定,则不加载配置。
|
||||||
|
仅支持 TOML (.toml) 格式。
|
||||||
|
"""
|
||||||
|
if not self.action_config_file_name:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name,不加载插件配置。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
plugin_module_path = inspect.getfile(self.__class__)
|
||||||
|
plugin_dir = os.path.dirname(plugin_module_path)
|
||||||
|
config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
|
||||||
|
|
||||||
|
if not os.path.exists(config_file_path):
|
||||||
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
|
||||||
|
|
||||||
|
if file_ext == ".toml":
|
||||||
|
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||||
|
self.config = toml.load(f) or {}
|
||||||
|
logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
|
||||||
|
)
|
||||||
|
self.config = {} # 确保未加载时为空字典
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
|
||||||
|
)
|
||||||
|
self.config = {} # 出错时确保 config 是一个空字典
|
||||||
|
|
||||||
|
def get_global_config(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
安全地从全局配置中获取一个值。
|
||||||
|
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||||
|
"""
|
||||||
|
if self._global_config:
|
||||||
|
return self._global_config.get(key, default)
|
||||||
|
logger.debug(f"{self.log_prefix} 尝试访问全局配置项 '{key}',但全局配置未提供。")
|
||||||
|
return default
|
||||||
|
|
||||||
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
|
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
|
||||||
"""根据用户名获取用户ID"""
|
"""根据用户名获取用户ID"""
|
||||||
|
|
@ -41,7 +111,7 @@ class PluginAction(BaseAction):
|
||||||
return platform, user_id
|
return platform, user_id
|
||||||
|
|
||||||
# 提供简化的API方法
|
# 提供简化的API方法
|
||||||
async def send_message(self, text: str, target: Optional[str] = None) -> bool:
|
async def send_message(self, type: str, data: str, target: Optional[str] = "") -> bool:
|
||||||
"""发送消息的简化方法
|
"""发送消息的简化方法
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -60,7 +130,7 @@ class PluginAction(BaseAction):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 构造简化的动作数据
|
# 构造简化的动作数据
|
||||||
reply_data = {"text": text, "target": target or "", "emojis": []}
|
# reply_data = {"text": text, "target": target or "", "emojis": []}
|
||||||
|
|
||||||
# 获取锚定消息(如果有)
|
# 获取锚定消息(如果有)
|
||||||
observations = self._services.get("observations", [])
|
observations = self._services.get("observations", [])
|
||||||
|
|
@ -68,7 +138,8 @@ class PluginAction(BaseAction):
|
||||||
chatting_observation: ChattingObservation = next(
|
chatting_observation: ChattingObservation = next(
|
||||||
obs for obs in observations if isinstance(obs, ChattingObservation)
|
obs for obs in observations if isinstance(obs, ChattingObservation)
|
||||||
)
|
)
|
||||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
|
||||||
|
anchor_message = chatting_observation.search_message_by_text(target)
|
||||||
|
|
||||||
# 如果没有找到锚点消息,创建一个占位符
|
# 如果没有找到锚点消息,创建一个占位符
|
||||||
if not anchor_message:
|
if not anchor_message:
|
||||||
|
|
@ -80,7 +151,7 @@ class PluginAction(BaseAction):
|
||||||
anchor_message.update_chat_stream(chat_stream)
|
anchor_message.update_chat_stream(chat_stream)
|
||||||
|
|
||||||
response_set = [
|
response_set = [
|
||||||
("text", text),
|
(type, data),
|
||||||
]
|
]
|
||||||
|
|
||||||
# 调用内部方法发送消息
|
# 调用内部方法发送消息
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import json # <--- 确保导入 json
|
||||||
import traceback
|
import traceback
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.focus_chat.info.info_base import InfoBase
|
from src.chat.focus_chat.info.info_base import InfoBase
|
||||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||||
|
|
@ -10,6 +10,7 @@ from src.chat.focus_chat.info.cycle_info import CycleInfo
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||||
from src.chat.focus_chat.info.action_info import ActionInfo
|
from src.chat.focus_chat.info.action_info import ActionInfo
|
||||||
from src.chat.focus_chat.info.structured_info import StructuredInfo
|
from src.chat.focus_chat.info.structured_info import StructuredInfo
|
||||||
|
from src.chat.focus_chat.info.self_info import SelfInfo
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.individuality.individuality import individuality
|
from src.individuality.individuality import individuality
|
||||||
|
|
@ -22,7 +23,11 @@ install(extra_lines=3)
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{extra_info_block}
|
"""
|
||||||
|
你的自我认知是:
|
||||||
|
{self_info_block}
|
||||||
|
|
||||||
|
{extra_info_block}
|
||||||
|
|
||||||
你需要基于以下信息决定如何参与对话
|
你需要基于以下信息决定如何参与对话
|
||||||
这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action:
|
这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action:
|
||||||
|
|
@ -127,6 +132,8 @@ class ActionPlanner:
|
||||||
current_mind = info.get_current_mind()
|
current_mind = info.get_current_mind()
|
||||||
elif isinstance(info, CycleInfo):
|
elif isinstance(info, CycleInfo):
|
||||||
cycle_info = info.get_observe_info()
|
cycle_info = info.get_observe_info()
|
||||||
|
elif isinstance(info, SelfInfo):
|
||||||
|
self_info = info.get_processed_info()
|
||||||
elif isinstance(info, StructuredInfo):
|
elif isinstance(info, StructuredInfo):
|
||||||
_structured_info = info.get_data()
|
_structured_info = info.get_data()
|
||||||
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
|
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
|
||||||
|
|
@ -148,6 +155,7 @@ class ActionPlanner:
|
||||||
|
|
||||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||||
prompt = await self.build_planner_prompt(
|
prompt = await self.build_planner_prompt(
|
||||||
|
self_info_block=self_info,
|
||||||
is_group_chat=is_group_chat, # <-- Pass HFC state
|
is_group_chat=is_group_chat, # <-- Pass HFC state
|
||||||
chat_target_info=None,
|
chat_target_info=None,
|
||||||
observed_messages_str=observed_messages_str, # <-- Pass local variable
|
observed_messages_str=observed_messages_str, # <-- Pass local variable
|
||||||
|
|
@ -236,6 +244,7 @@ class ActionPlanner:
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
|
self_info_block: str,
|
||||||
is_group_chat: bool, # Now passed as argument
|
is_group_chat: bool, # Now passed as argument
|
||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
observed_messages_str: str,
|
observed_messages_str: str,
|
||||||
|
|
@ -301,7 +310,8 @@ class ActionPlanner:
|
||||||
|
|
||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
prompt = planner_prompt_template.format(
|
prompt = planner_prompt_template.format(
|
||||||
bot_name=global_config.bot.nickname,
|
self_info_block=self_info_block,
|
||||||
|
# bot_name=global_config.bot.nickname,
|
||||||
prompt_personality=personality_block,
|
prompt_personality=personality_block,
|
||||||
chat_context_description=chat_context_description,
|
chat_context_description=chat_context_description,
|
||||||
chat_content_block=chat_content_block,
|
chat_content_block=chat_content_block,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import traceback
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||||
import json # 添加json模块导入
|
import json # 添加json模块导入
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import traceback
|
import traceback
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
|
|
|
||||||
|
|
@ -88,5 +88,6 @@ class HFCloopObservation:
|
||||||
for action_name, action_info in using_actions.items():
|
for action_name, action_info in using_actions.items():
|
||||||
action_description = action_info["description"]
|
action_description = action_info["description"]
|
||||||
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
|
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
|
||||||
|
cycle_info_block += "注意,除了上述动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"
|
||||||
|
|
||||||
self.observe_info = cycle_info_block
|
self.observe_info = cycle_info_block
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
from typing import Optional, Tuple, Dict
|
from typing import Optional, Tuple, Dict
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from src.chat.person_info.person_info import person_info_manager
|
from src.person_info.person_info import person_info_manager
|
||||||
|
|
||||||
logger = get_logger("heartflow_utils")
|
logger = get_logger("heartflow_utils")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import jieba
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from ...chat.models.utils_model import LLMRequest
|
from ...llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||||
from ..utils.chat_message_builder import (
|
from ..utils.chat_message_builder import (
|
||||||
|
|
@ -338,7 +338,8 @@ class Hippocampus:
|
||||||
# 去重
|
# 去重
|
||||||
keywords = list(set(keywords))
|
keywords = list(set(keywords))
|
||||||
# 限制关键词数量
|
# 限制关键词数量
|
||||||
keywords = keywords[:5]
|
logger.debug(f"提取关键词: {keywords}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用LLM提取关键词
|
# 使用LLM提取关键词
|
||||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||||
|
|
@ -361,7 +362,7 @@ class Hippocampus:
|
||||||
# 过滤掉不存在于记忆图中的关键词
|
# 过滤掉不存在于记忆图中的关键词
|
||||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||||
if not valid_keywords:
|
if not valid_keywords:
|
||||||
# logger.info("没有找到有效的关键词节点")
|
logger.info("没有找到有效的关键词节点")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from ..emoji_system.emoji_manager import emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from ..person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from .chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
from .message_sender import message_manager
|
from src.chat.message_receive.message_sender import message_manager
|
||||||
from .storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ class ChatBot:
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
chat_manager.register_message(message)
|
||||||
|
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
|
|
@ -86,7 +87,7 @@ class ChatBot:
|
||||||
if isinstance(template_items, dict):
|
if isinstance(template_items, dict):
|
||||||
for k in template_items.keys():
|
for k in template_items.keys():
|
||||||
await Prompt.create_async(template_items[k], k)
|
await Prompt.create_async(template_items[k], k)
|
||||||
print(f"注册{template_items[k]},{k}")
|
logger.debug(f"注册{template_items[k]},{k}")
|
||||||
else:
|
else:
|
||||||
template_group_name = None
|
template_group_name = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,17 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
from ...common.database.database import db
|
from ...common.database.database import db
|
||||||
from ...common.database.database_model import ChatStreams # 新增导入
|
from ...common.database.database_model import ChatStreams # 新增导入
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .message import MessageRecv
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
|
@ -18,6 +22,23 @@ install(extra_lines=3)
|
||||||
logger = get_logger("chat_stream")
|
logger = get_logger("chat_stream")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageContext:
|
||||||
|
"""聊天消息上下文,存储消息的上下文信息"""
|
||||||
|
|
||||||
|
def __init__(self, message: "MessageRecv"):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def get_template_name(self) -> str:
|
||||||
|
"""获取模板名称"""
|
||||||
|
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||||
|
return self.message.message_info.template_info.template_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_message(self) -> "MessageRecv":
|
||||||
|
"""获取最后一条消息"""
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
class ChatStream:
|
class ChatStream:
|
||||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
|
|
||||||
|
|
@ -36,6 +57,7 @@ class ChatStream:
|
||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
|
|
@ -67,6 +89,10 @@ class ChatStream:
|
||||||
self.last_active_time = time.time()
|
self.last_active_time = time.time()
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
|
def set_context(self, message: "MessageRecv"):
|
||||||
|
"""设置聊天消息上下文"""
|
||||||
|
self.context = ChatMessageContext(message)
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
class ChatManager:
|
||||||
"""聊天管理器,管理所有聊天流"""
|
"""聊天管理器,管理所有聊天流"""
|
||||||
|
|
@ -82,6 +108,7 @@ class ChatManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
|
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||||
try:
|
try:
|
||||||
db.connect(reuse_if_open=True)
|
db.connect(reuse_if_open=True)
|
||||||
# 确保 ChatStreams 表存在
|
# 确保 ChatStreams 表存在
|
||||||
|
|
@ -113,6 +140,16 @@ class ChatManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||||
|
|
||||||
|
def register_message(self, message: "MessageRecv"):
|
||||||
|
"""注册消息到聊天流"""
|
||||||
|
stream_id = self._generate_stream_id(
|
||||||
|
message.message_info.platform,
|
||||||
|
message.message_info.user_info,
|
||||||
|
message.message_info.group_info,
|
||||||
|
)
|
||||||
|
self.last_messages[stream_id] = message
|
||||||
|
logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
|
|
@ -146,12 +183,19 @@ class ChatManager:
|
||||||
# 检查内存中是否存在
|
# 检查内存中是否存在
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
|
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
stream.update_active_time()
|
stream.update_active_time()
|
||||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||||
stream.user_info = user_info
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
|
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||||
|
|
||||||
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
else:
|
||||||
|
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
|
|
@ -202,14 +246,26 @@ class ChatManager:
|
||||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
stream = copy.deepcopy(stream)
|
||||||
|
from .message import MessageRecv # 延迟导入,避免循环引用
|
||||||
|
|
||||||
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
else:
|
||||||
|
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||||
# 保存到内存和数据库
|
# 保存到内存和数据库
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
await self._save_stream(stream)
|
await self._save_stream(stream)
|
||||||
return copy.deepcopy(stream)
|
return stream
|
||||||
|
|
||||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||||
"""通过stream_id获取聊天流"""
|
"""通过stream_id获取聊天流"""
|
||||||
return self.streams.get(stream_id)
|
stream = self.streams.get(stream_id)
|
||||||
|
if not stream:
|
||||||
|
return None
|
||||||
|
if stream_id in self.last_messages:
|
||||||
|
stream.set_context(self.last_messages[stream_id])
|
||||||
|
return stream
|
||||||
|
|
||||||
def get_stream_by_info(
|
def get_stream_by_info(
|
||||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||||
|
|
@ -306,6 +362,8 @@ class ChatManager:
|
||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
self.streams[stream.stream_id] = stream
|
self.streams[stream.stream_id] = stream
|
||||||
|
if stream.stream_id in self.last_messages:
|
||||||
|
stream.set_context(self.last_messages[stream.stream_id])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, TYPE_CHECKING
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from .chat_stream import ChatStream
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .chat_stream import ChatStream
|
||||||
from ..utils.utils_image import image_manager
|
from ..utils.utils_image import image_manager
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message(MessageBase):
|
class Message(MessageBase):
|
||||||
chat_stream: ChatStream = None
|
chat_stream: "ChatStream" = None
|
||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
detailed_plain_text: str = ""
|
detailed_plain_text: str = ""
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
|
|
@ -34,7 +36,7 @@ class Message(MessageBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
timestamp: Optional[float] = None,
|
timestamp: Optional[float] = None,
|
||||||
|
|
@ -111,7 +113,7 @@ class MessageRecv(Message):
|
||||||
self.detailed_plain_text = "" # 初始化为空字符串
|
self.detailed_plain_text = "" # 初始化为空字符串
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
|
|
||||||
def update_chat_stream(self, chat_stream: ChatStream):
|
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
|
|
@ -165,7 +167,7 @@ class MessageProcessBase(Message):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["MessageRecv"] = None,
|
||||||
|
|
@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["MessageRecv"] = None,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
|
|
@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
|
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
|
|
@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
|
||||||
class MessageSet:
|
class MessageSet:
|
||||||
"""消息集合类,可以存储多个发送消息"""
|
"""消息集合类,可以存储多个发送消息"""
|
||||||
|
|
||||||
def __init__(self, chat_stream: ChatStream, message_id: str):
|
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.messages: list[MessageSending] = []
|
self.messages: list[MessageSending] = []
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,10 @@ from src.common.logger_manager import get_logger
|
||||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||||
from src.manager.mood_manager import mood_manager
|
from src.manager.mood_manager import mood_manager
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.person_info.relationship_manager import relationship_manager
|
||||||
from src.chat.utils.info_catcher import info_catcher_manager
|
from src.chat.utils.info_catcher import info_catcher_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from .normal_chat_generator import NormalChatGenerator
|
from .normal_chat_generator import NormalChatGenerator
|
||||||
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from src.chat.message_receive.message_sender import message_manager
|
from src.chat.message_receive.message_sender import message_manager
|
||||||
|
|
@ -189,30 +190,31 @@ class NormalChat:
|
||||||
通常由start_monitoring_interest()启动
|
通常由start_monitoring_interest()启动
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(0.5) # 每秒检查一次
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
# 检查任务是否已被取消
|
await asyncio.sleep(0.5) # 每秒检查一次
|
||||||
if self._chat_task is None or self._chat_task.cancelled():
|
# 检查任务是否已被取消
|
||||||
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
|
if self._chat_task is None or self._chat_task.cancelled():
|
||||||
break
|
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
|
||||||
|
break
|
||||||
|
|
||||||
items_to_process = list(self.interest_dict.items())
|
items_to_process = list(self.interest_dict.items())
|
||||||
if not items_to_process:
|
if not items_to_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 处理每条兴趣消息
|
# 处理每条兴趣消息
|
||||||
for msg_id, (message, interest_value, is_mentioned) in items_to_process:
|
for msg_id, (message, interest_value, is_mentioned) in items_to_process:
|
||||||
try:
|
try:
|
||||||
# 处理消息
|
# 处理消息
|
||||||
await self.normal_response(
|
await self.normal_response(
|
||||||
message=message,
|
message=message,
|
||||||
is_mentioned=is_mentioned,
|
is_mentioned=is_mentioned,
|
||||||
interested_rate=interest_value,
|
interested_rate=interest_value,
|
||||||
rewind_response=False,
|
rewind_response=False,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}")
|
logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}")
|
||||||
finally:
|
finally:
|
||||||
self.interest_dict.pop(msg_id, None)
|
self.interest_dict.pop(msg_id, None)
|
||||||
|
|
||||||
# 改为实例方法, 移除 chat 参数
|
# 改为实例方法, 移除 chat 参数
|
||||||
async def normal_response(
|
async def normal_response(
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
import random
|
import random
|
||||||
from ..models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from ...config.config import global_config
|
from src.config.config import global_config
|
||||||
from ..message_receive.message import MessageThinking
|
from src.chat.message_receive.message import MessageThinking
|
||||||
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
|
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
|
||||||
from src.chat.utils.utils import process_llm_response
|
from src.chat.utils.utils import process_llm_response
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
|
from src.person_info.person_info import person_info_manager, PersonInfoManager
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.chat.person_info.person_info import person_info_manager
|
from src.person_info.person_info import person_info_manager
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
|
||||||
import re
|
import re
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
# import traceback
|
# import traceback
|
||||||
|
|
@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
|
||||||
class PromptContext:
|
class PromptContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||||
self._current_context: Optional[str] = None
|
# 使用contextvars创建协程上下文变量
|
||||||
self._context_lock = asyncio.Lock() # 添加异步锁
|
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||||
|
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _current_context(self) -> Optional[str]:
|
||||||
|
"""获取当前协程的上下文ID"""
|
||||||
|
return self._current_context_var.get()
|
||||||
|
|
||||||
|
@_current_context.setter
|
||||||
|
def _current_context(self, value: Optional[str]):
|
||||||
|
"""设置当前协程的上下文ID"""
|
||||||
|
self._current_context_var.set(value)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_scope(self, context_id: str):
|
async def async_scope(self, context_id: Optional[str] = None):
|
||||||
"""创建一个异步的临时提示模板作用域"""
|
"""创建一个异步的临时提示模板作用域"""
|
||||||
async with self._context_lock:
|
# 保存当前上下文并设置新上下文
|
||||||
if context_id not in self._context_prompts:
|
if context_id is not None:
|
||||||
self._context_prompts[context_id] = {}
|
async with self._context_lock:
|
||||||
|
if context_id not in self._context_prompts:
|
||||||
|
self._context_prompts[context_id] = {}
|
||||||
|
|
||||||
|
# 保存当前协程的上下文值,不影响其他协程
|
||||||
previous_context = self._current_context
|
previous_context = self._current_context
|
||||||
self._current_context = context_id
|
# 设置当前协程的新上下文
|
||||||
|
token = self._current_context_var.set(context_id)
|
||||||
|
else:
|
||||||
|
# 如果没有提供新上下文,保持当前上下文不变
|
||||||
|
previous_context = self._current_context
|
||||||
|
token = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self
|
yield self
|
||||||
finally:
|
finally:
|
||||||
async with self._context_lock:
|
# 恢复之前的上下文
|
||||||
self._current_context = previous_context
|
if context_id is not None:
|
||||||
|
if token:
|
||||||
|
self._current_context_var.reset(token)
|
||||||
|
else:
|
||||||
|
self._current_context = previous_context
|
||||||
|
|
||||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||||
"""异步获取当前作用域中的提示模板"""
|
"""异步获取当前作用域中的提示模板"""
|
||||||
async with self._context_lock:
|
async with self._context_lock:
|
||||||
if self._current_context and name in self._context_prompts[self._current_context]:
|
current_context = self._current_context
|
||||||
return self._context_prompts[self._current_context][name]
|
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||||
|
if (
|
||||||
|
current_context
|
||||||
|
and current_context in self._context_prompts
|
||||||
|
and name in self._context_prompts[current_context]
|
||||||
|
):
|
||||||
|
return self._context_prompts[current_context][name]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||||
|
|
@ -56,8 +87,8 @@ class PromptManager:
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_message_scope(self, message_id: str):
|
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||||
"""为消息处理创建异步临时作用域"""
|
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
|
||||||
async with self._context.async_scope(message_id):
|
async with self._context.async_scope(message_id):
|
||||||
yield self
|
yield self
|
||||||
|
|
||||||
|
|
@ -65,9 +96,11 @@ class PromptManager:
|
||||||
# 首先尝试从当前上下文获取
|
# 首先尝试从当前上下文获取
|
||||||
context_prompt = await self._context.get_prompt_async(name)
|
context_prompt = await self._context.get_prompt_async(name)
|
||||||
if context_prompt is not None:
|
if context_prompt is not None:
|
||||||
|
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||||
return context_prompt
|
return context_prompt
|
||||||
# 如果上下文中不存在,则使用全局提示模板
|
# 如果上下文中不存在,则使用全局提示模板
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
|
logger.debug(f"从全局获取提示词: {name}")
|
||||||
if name not in self._prompts:
|
if name not in self._prompts:
|
||||||
raise KeyError(f"Prompt '{name}' not found")
|
raise KeyError(f"Prompt '{name}' not found")
|
||||||
return self._prompts[name]
|
return self._prompts[name]
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from maim_message import UserInfo
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.manager.mood_manager import mood_manager
|
from src.manager.mood_manager import mood_manager
|
||||||
from ..message_receive.message import MessageRecv
|
from ..message_receive.message import MessageRecv
|
||||||
from ..models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from .typo_generator import ChineseTypoGenerator
|
from .typo_generator import ChineseTypoGenerator
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from ...common.message_repository import find_messages, count_messages
|
from ...common.message_repository import find_messages, count_messages
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ import io
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from ...common.database.database import db
|
from src.common.database.database import db
|
||||||
from ...common.database.database_model import Images, ImageDescriptions
|
from src.common.database.database_model import Images, ImageDescriptions
|
||||||
from ...config.config import global_config
|
from src.config.config import global_config
|
||||||
from ..models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
|
||||||
|
|
@ -869,6 +869,23 @@ API_SERVER_STYLE_CONFIG = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# maim_message 消息服务样式配置
|
||||||
|
MAIM_MESSAGE_STYLE_CONFIG = {
|
||||||
|
"advanced": {
|
||||||
|
"console_format": (
|
||||||
|
"<white>{time:YYYY-MM-DD HH:mm:ss}</white> | "
|
||||||
|
"<level>{level: <8}</level> | "
|
||||||
|
"<fg #00B2FF>消息服务</fg #00B2FF> | "
|
||||||
|
"<level>{message}</level>"
|
||||||
|
),
|
||||||
|
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息服务 | {message}",
|
||||||
|
},
|
||||||
|
"simple": {
|
||||||
|
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #00B2FF>消息服务</fg #00B2FF> | {message}",
|
||||||
|
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息服务 | {message}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# 根据SIMPLE_OUTPUT选择配置
|
# 根据SIMPLE_OUTPUT选择配置
|
||||||
MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"]
|
MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"]
|
||||||
|
|
@ -946,6 +963,9 @@ CHAT_MESSAGE_STYLE_CONFIG = (
|
||||||
CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"]
|
CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"]
|
||||||
INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"]
|
INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"]
|
||||||
API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"]
|
API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"]
|
||||||
|
MAIM_MESSAGE_STYLE_CONFIG = (
|
||||||
|
MAIM_MESSAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIM_MESSAGE_STYLE_CONFIG["advanced"]
|
||||||
|
)
|
||||||
INTEREST_CHAT_STYLE_CONFIG = (
|
INTEREST_CHAT_STYLE_CONFIG = (
|
||||||
INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"]
|
INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,27 @@
|
||||||
from src.common.server import global_server
|
from src.common.server import global_server
|
||||||
import os
|
import os
|
||||||
|
import importlib.metadata
|
||||||
from maim_message import MessageServer
|
from maim_message import MessageServer
|
||||||
|
from src.common.logger_manager import get_logger
|
||||||
|
|
||||||
|
|
||||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
# 检查maim_message版本
|
||||||
|
try:
|
||||||
|
maim_message_version = importlib.metadata.version("maim_message")
|
||||||
|
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 0]
|
||||||
|
except (importlib.metadata.PackageNotFoundError, ValueError):
|
||||||
|
version_compatible = False
|
||||||
|
|
||||||
|
# 根据版本决定是否使用自定义logger
|
||||||
|
kwargs = {
|
||||||
|
"host": os.environ["HOST"],
|
||||||
|
"port": int(os.environ["PORT"]),
|
||||||
|
"app": global_server.get_app(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 只有在版本 >= 0.3.0 时才使用自定义logger
|
||||||
|
if version_compatible:
|
||||||
|
maim_message_logger = get_logger("maim_message")
|
||||||
|
kwargs["custom_logger"] = maim_message_logger
|
||||||
|
|
||||||
|
global_api = MessageServer(**kwargs)
|
||||||
|
|
|
||||||
|
|
@ -68,30 +68,30 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{TELEMETRY_SERVER_URL}/stat/reg_client",
|
f"{TELEMETRY_SERVER_URL}/stat/reg_client",
|
||||||
json={"deploy_time": local_storage["deploy_time"]},
|
json={"deploy_time": local_storage["deploy_time"]},
|
||||||
|
timeout=5, # 设置超时时间为5秒
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
|
|
||||||
|
|
||||||
logger.debug(local_storage["deploy_time"])
|
|
||||||
|
|
||||||
logger.debug(response)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
client_id = data.get("mmc_uuid")
|
|
||||||
if client_id:
|
|
||||||
# 将UUID存储到本地
|
|
||||||
local_storage["mmc_uuid"] = client_id
|
|
||||||
self.client_uuid = client_id
|
|
||||||
logger.info(f"成功获取UUID: {self.client_uuid}")
|
|
||||||
return True # 成功获取UUID,返回True
|
|
||||||
else:
|
|
||||||
logger.error("无效的服务端响应")
|
|
||||||
else:
|
|
||||||
logger.error(f"请求UUID失败,状态码: {response.status_code}, 响应内容: {response.text}")
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.error(f"请求UUID时出错: {e}") # 可能是网络问题
|
logger.error(f"请求UUID时出错: {e}") # 可能是网络问题
|
||||||
|
|
||||||
|
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
|
||||||
|
|
||||||
|
logger.debug(local_storage["deploy_time"])
|
||||||
|
|
||||||
|
logger.debug(response)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if client_id := data.get("mmc_uuid"):
|
||||||
|
# 将UUID存储到本地
|
||||||
|
local_storage["mmc_uuid"] = client_id
|
||||||
|
self.client_uuid = client_id
|
||||||
|
logger.info(f"成功获取UUID: {self.client_uuid}")
|
||||||
|
return True # 成功获取UUID,返回True
|
||||||
|
else:
|
||||||
|
logger.error("无效的服务端响应")
|
||||||
|
else:
|
||||||
|
logger.error(f"请求UUID失败,状态码: {response.status_code}, 响应内容: {response.text}")
|
||||||
|
|
||||||
# 请求失败,重试次数+1
|
# 请求失败,重试次数+1
|
||||||
try_count += 1
|
try_count += 1
|
||||||
if try_count > 3:
|
if try_count > 3:
|
||||||
|
|
@ -100,47 +100,48 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
# 如果可以重试,等待后继续(指数退避)
|
# 如果可以重试,等待后继续(指数退避)
|
||||||
|
logger.info(f"获取UUID失败,将于 {4**try_count} 秒后重试...")
|
||||||
await asyncio.sleep(4**try_count)
|
await asyncio.sleep(4**try_count)
|
||||||
|
|
||||||
async def _send_heartbeat(self):
|
async def _send_heartbeat(self):
|
||||||
"""向服务器发送心跳"""
|
"""向服务器发送心跳"""
|
||||||
|
headers = {
|
||||||
|
"Client-UUID": self.client_uuid,
|
||||||
|
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||||
|
|
||||||
|
logger.debug(headers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers = {
|
|
||||||
"Client-UUID": self.client_uuid,
|
|
||||||
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
|
||||||
|
|
||||||
logger.debug(headers)
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.server_url}/stat/client_heartbeat",
|
f"{self.server_url}/stat/client_heartbeat",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=self.info_dict,
|
json=self.info_dict,
|
||||||
|
timeout=5, # 设置超时时间为5秒
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
logger.debug(response)
|
|
||||||
|
|
||||||
# 处理响应
|
|
||||||
if 200 <= response.status_code < 300:
|
|
||||||
# 成功
|
|
||||||
logger.debug(f"心跳发送成功,状态码: {response.status_code}")
|
|
||||||
elif response.status_code == 403:
|
|
||||||
# 403 Forbidden
|
|
||||||
logger.error(
|
|
||||||
"心跳发送失败,403 Forbidden: 可能是UUID无效或未注册。"
|
|
||||||
"处理措施:重置UUID,下次发送心跳时将尝试重新注册。"
|
|
||||||
)
|
|
||||||
self.client_uuid = None
|
|
||||||
del local_storage["mmc_uuid"] # 删除本地存储的UUID
|
|
||||||
else:
|
|
||||||
# 其他错误
|
|
||||||
logger.error(f"心跳发送失败,状态码: {response.status_code}, 响应内容: {response.text}")
|
|
||||||
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.error(f"心跳发送失败: {e}")
|
logger.error(f"心跳发送失败: {e}")
|
||||||
|
|
||||||
|
logger.debug(response)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
if 200 <= response.status_code < 300:
|
||||||
|
# 成功
|
||||||
|
logger.debug(f"心跳发送成功,状态码: {response.status_code}")
|
||||||
|
elif response.status_code == 403:
|
||||||
|
# 403 Forbidden
|
||||||
|
logger.error(
|
||||||
|
"心跳发送失败,403 Forbidden: 可能是UUID无效或未注册。"
|
||||||
|
"处理措施:重置UUID,下次发送心跳时将尝试重新注册。"
|
||||||
|
)
|
||||||
|
self.client_uuid = None
|
||||||
|
del local_storage["mmc_uuid"] # 删除本地存储的UUID
|
||||||
|
else:
|
||||||
|
# 其他错误
|
||||||
|
logger.error(f"心跳发送失败,状态码: {response.status_code}, 响应内容: {response.text}")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
# 发送心跳
|
# 发送心跳
|
||||||
if global_config.telemetry.enable:
|
if global_config.telemetry.enable:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import time
|
import time
|
||||||
from typing import Tuple, Optional # 增加了 Optional
|
from typing import Tuple, Optional # 增加了 Optional
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.experimental.PFC.chat_observer import ChatObserver
|
from src.experimental.PFC.chat_observer import ChatObserver
|
||||||
from src.experimental.PFC.pfc_utils import get_items_from_json
|
from src.experimental.PFC.pfc_utils import get_items_from_json
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Tuple, TYPE_CHECKING
|
from typing import List, Tuple, TYPE_CHECKING
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.experimental.PFC.chat_observer import ChatObserver
|
from src.experimental.PFC.chat_observer import ChatObserver
|
||||||
from src.experimental.PFC.pfc_utils import get_items_from_json
|
from src.experimental.PFC.pfc_utils import get_items_from_json
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.message import Message
|
from src.chat.message_receive.message import Message
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
from typing import Tuple, List, Dict, Any
|
from typing import Tuple, List, Dict, Any
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.experimental.PFC.chat_observer import ChatObserver
|
from src.experimental.PFC.chat_observer import ChatObserver
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Tuple, List, Dict, Any
|
from typing import Tuple, List, Dict, Any
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.experimental.PFC.chat_observer import ChatObserver
|
from src.experimental.PFC.chat_observer import ChatObserver
|
||||||
from src.experimental.PFC.reply_checker import ReplyChecker
|
from src.experimental.PFC.reply_checker import ReplyChecker
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import random
|
import random
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,8 @@ import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Tuple, Union, Dict, Any
|
from typing import Tuple, Union, Dict, Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp.client import ClientResponse
|
from aiohttp.client import ClientResponse
|
||||||
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
import base64
|
import base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
@ -14,7 +12,7 @@ import io
|
||||||
import os
|
import os
|
||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||||
from ...config.config import global_config
|
from src.config.config import global_config
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
@ -6,7 +6,7 @@ from .manager.async_task_manager import async_task_manager
|
||||||
from .chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from .chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from .manager.mood_manager import MoodPrintTask, MoodUpdateTask
|
from .manager.mood_manager import MoodPrintTask, MoodUpdateTask
|
||||||
from .chat.emoji_system.emoji_manager import emoji_manager
|
from .chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from .chat.person_info.person_info import person_info_manager
|
from .person_info.person_info import person_info_manager
|
||||||
from .chat.normal_chat.willing.willing_manager import willing_manager
|
from .chat.normal_chat.willing.willing_manager import willing_manager
|
||||||
from .chat.message_receive.chat_stream import chat_manager
|
from .chat.message_receive.chat_stream import chat_manager
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
from src.chat.heart_flow.heartflow import heartflow
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,21 @@ class LocalStoreManager:
|
||||||
|
|
||||||
def __getitem__(self, item: str) -> str | list | dict | int | float | bool | None:
|
def __getitem__(self, item: str) -> str | list | dict | int | float | bool | None:
|
||||||
"""获取本地存储数据"""
|
"""获取本地存储数据"""
|
||||||
return self.store.get(item, None)
|
return self.store.get(item)
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: str | list | dict | int | float | bool):
|
def __setitem__(self, key: str, value: str | list | dict | int | float | bool):
|
||||||
"""设置本地存储数据"""
|
"""设置本地存储数据"""
|
||||||
self.store[key] = value
|
self.store[key] = value
|
||||||
self.save_local_store()
|
self.save_local_store()
|
||||||
|
|
||||||
|
def __delitem__(self, key: str):
|
||||||
|
"""删除本地存储数据"""
|
||||||
|
if key in self.store:
|
||||||
|
del self.store[key]
|
||||||
|
self.save_local_store()
|
||||||
|
else:
|
||||||
|
logger.warning(f"尝试删除不存在的键: {key}")
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
"""检查本地存储数据是否存在"""
|
"""检查本地存储数据是否存在"""
|
||||||
return item in self.store
|
return item in self.store
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from ...common.database.database import db
|
from src.common.database.database import db
|
||||||
from ...common.database.database_model import PersonInfo # 新增导入
|
from src.common.database.database_model import PersonInfo # 新增导入
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.individuality.individuality import individuality
|
from src.individuality.individuality import individuality
|
||||||
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from ..message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
import math
|
import math
|
||||||
from bson.decimal128 import Decimal128
|
from bson.decimal128 import Decimal128
|
||||||
from .person_info import person_info_manager
|
from src.person_info.person_info import person_info_manager
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from ...manager.mood_manager import mood_manager
|
from src.manager.mood_manager import mood_manager
|
||||||
|
|
||||||
# import re
|
# import re
|
||||||
# import traceback
|
# import traceback
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""测试插件动作模块"""
|
"""测试插件动作模块"""
|
||||||
|
|
||||||
# 导入所有动作模块以确保装饰器被执行
|
# 导入所有动作模块以确保装饰器被执行
|
||||||
# from . import test_action # noqa
|
from . import test_action # noqa
|
||||||
|
|
||||||
# from . import online_action # noqa
|
from . import online_action # noqa
|
||||||
# from . import mute_action # noqa
|
from . import mute_action # noqa
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class MuteAction(PluginAction):
|
||||||
"当千石可乐或可乐酱要求你禁言时使用",
|
"当千石可乐或可乐酱要求你禁言时使用",
|
||||||
"当你想回避某个话题时使用",
|
"当你想回避某个话题时使用",
|
||||||
]
|
]
|
||||||
default = True # 不是默认动作,需要手动添加到使用集
|
default = False # 不是默认动作,需要手动添加到使用集
|
||||||
|
|
||||||
async def process(self) -> Tuple[bool, str]:
|
async def process(self) -> Tuple[bool, str]:
|
||||||
"""处理测试动作"""
|
"""处理测试动作"""
|
||||||
|
|
@ -40,7 +40,11 @@ class MuteAction(PluginAction):
|
||||||
await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪")
|
await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.send_message(f"[command]mute,{user_id},{duration}")
|
await self.send_message(
|
||||||
|
type="text",
|
||||||
|
data=f"[command]mute,{user_id},{duration}",
|
||||||
|
# target = target
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")
|
logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class CheckOnlineAction(PluginAction):
|
||||||
"mode参数为version时查看在线版本状态,默认用这种",
|
"mode参数为version时查看在线版本状态,默认用这种",
|
||||||
"mode参数为type时查看在线系统类型分布",
|
"mode参数为type时查看在线系统类型分布",
|
||||||
]
|
]
|
||||||
default = True # 不是默认动作,需要手动添加到使用集
|
default = False # 不是默认动作,需要手动添加到使用集
|
||||||
|
|
||||||
async def process(self) -> Tuple[bool, str]:
|
async def process(self) -> Tuple[bool, str]:
|
||||||
"""处理测试动作"""
|
"""处理测试动作"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""测试插件包:图片发送"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
这是一个测试插件,用于测试图片发送功能
|
||||||
|
"""
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
"""测试插件动作模块"""
|
||||||
|
|
||||||
|
# 导入所有动作模块以确保装饰器被执行
|
||||||
|
from . import pic_action # noqa
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
CONFIG_CONTENT = """\
|
||||||
|
# 请替换为您的火山引擎 Access Key ID
|
||||||
|
volcano_ak = "YOUR_VOLCANO_ENGINE_ACCESS_KEY_ID_HERE"
|
||||||
|
# 请替换为您的火山引擎 Secret Access Key
|
||||||
|
volcano_sk = "YOUR_VOLCANO_ENGINE_SECRET_ACCESS_KEY_HERE"
|
||||||
|
# 火山方舟 API 的基础 URL
|
||||||
|
base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||||
|
# 默认图片生成模型
|
||||||
|
default_model = "doubao-seedream-3-0-t2i-250415"
|
||||||
|
# 默认图片尺寸
|
||||||
|
default_size = "1024x1024"
|
||||||
|
# 用于图片生成的API密钥
|
||||||
|
# PicAction 当前配置为在HTTP请求体和Authorization头中使用此密钥。
|
||||||
|
# 如果您的API认证方式不同,请相应调整或移除。
|
||||||
|
volcano_generate_api_key = "YOUR_VOLCANO_GENERATE_API_KEY_HERE"
|
||||||
|
|
||||||
|
# 是否默认开启水印
|
||||||
|
default_watermark = true
|
||||||
|
# 默认引导强度
|
||||||
|
default_guidance_scale = 2.5
|
||||||
|
# 默认随机种子
|
||||||
|
default_seed = 42
|
||||||
|
|
||||||
|
# 更多插件特定配置可以在此添加...
|
||||||
|
# custom_parameter = "some_value"
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_config():
|
||||||
|
# 获取当前脚本所在的目录
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
config_file_path = os.path.join(current_dir, "pic_action_config.toml")
|
||||||
|
|
||||||
|
if not os.path.exists(config_file_path):
|
||||||
|
try:
|
||||||
|
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(CONFIG_CONTENT)
|
||||||
|
print(f"配置文件已生成: {config_file_path}")
|
||||||
|
print("请记得编辑该文件,填入您的火山引擎 AK/SK 和 API 密钥。")
|
||||||
|
except IOError as e:
|
||||||
|
print(f"错误:无法写入配置文件 {config_file_path}。原因: {e}")
|
||||||
|
else:
|
||||||
|
print(f"配置文件已存在: {config_file_path}")
|
||||||
|
print("未进行任何更改。如果您想重新生成,请先删除或重命名现有文件。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
generate_config()
|
||||||
|
|
@ -0,0 +1,264 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
import base64 # 新增:用于Base64编码
|
||||||
|
import traceback # 新增:用于打印堆栈跟踪
|
||||||
|
from typing import Tuple
|
||||||
|
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
|
||||||
|
from src.common.logger_manager import get_logger
|
||||||
|
from .generate_pic_config import generate_config
|
||||||
|
|
||||||
|
logger = get_logger("pic_action")
|
||||||
|
|
||||||
|
# 当此模块被加载时,尝试生成配置文件(如果它不存在)
|
||||||
|
# 注意:在某些插件加载机制下,这可能会在每次机器人启动或插件重载时执行
|
||||||
|
# 考虑是否需要更复杂的逻辑来决定何时运行 (例如,仅在首次安装时)
|
||||||
|
generate_config()
|
||||||
|
|
||||||
|
|
||||||
|
@register_action
|
||||||
|
class PicAction(PluginAction):
|
||||||
|
"""根据描述使用火山引擎HTTP API生成图片的动作处理类"""
|
||||||
|
|
||||||
|
action_name = "pic_action"
|
||||||
|
action_description = "可以根据特定的描述,使用火山引擎模型生成并发送一张图片 (通过HTTP API)"
|
||||||
|
action_parameters = {
|
||||||
|
"description": "图片描述,输入你想要生成并发送的图片的描述,必填",
|
||||||
|
"size": "图片尺寸,例如 '1024x1024' (可选, 默认从配置或 '1024x1024')",
|
||||||
|
}
|
||||||
|
action_require = [
|
||||||
|
"当有人要求你生成并发送一张图片时使用",
|
||||||
|
"当有人让你画一张图时使用",
|
||||||
|
]
|
||||||
|
default = False
|
||||||
|
action_config_file_name = "pic_action_config.toml"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_data: dict,
|
||||||
|
reasoning: str,
|
||||||
|
cycle_timers: dict,
|
||||||
|
thinking_id: str,
|
||||||
|
global_config: dict = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(action_data, reasoning, cycle_timers, thinking_id, global_config, **kwargs)
|
||||||
|
|
||||||
|
http_base_url = self.config.get("base_url")
|
||||||
|
http_api_key = self.config.get("volcano_generate_api_key")
|
||||||
|
|
||||||
|
if not (http_base_url and http_api_key):
|
||||||
|
logger.error(
|
||||||
|
f"{self.log_prefix} PicAction初始化, 但HTTP配置 (base_url 或 volcano_generate_api_key) 缺失. HTTP图片生成将失败."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"{self.log_prefix} HTTP方式初始化完成. Base URL: {http_base_url}, API Key已配置.")
|
||||||
|
|
||||||
|
# _restore_env_vars 方法不再需要,已移除
|
||||||
|
|
||||||
|
async def process(self) -> Tuple[bool, str]:
|
||||||
|
"""处理图片生成动作(通过HTTP API)"""
|
||||||
|
logger.info(f"{self.log_prefix} 执行 pic_action (HTTP): {self.reasoning}")
|
||||||
|
|
||||||
|
http_base_url = self.config.get("base_url")
|
||||||
|
http_api_key = self.config.get("volcano_generate_api_key")
|
||||||
|
|
||||||
|
if not (http_base_url and http_api_key):
|
||||||
|
error_msg = "抱歉,图片生成功能所需的HTTP配置(如API地址或密钥)不完整,无法提供服务。"
|
||||||
|
await self.send_message_by_expressor(error_msg)
|
||||||
|
logger.error(f"{self.log_prefix} HTTP调用配置缺失: base_url 或 volcano_generate_api_key.")
|
||||||
|
return False, "HTTP配置不完整"
|
||||||
|
|
||||||
|
description = self.action_data.get("description")
|
||||||
|
if not description:
|
||||||
|
logger.warning(f"{self.log_prefix} 图片描述为空,无法生成图片。")
|
||||||
|
await self.send_message_by_expressor("你需要告诉我想要画什么样的图片哦~")
|
||||||
|
return False, "图片描述为空"
|
||||||
|
|
||||||
|
default_model = self.config.get("default_model", "doubao-seedream-3-0-t2i-250415")
|
||||||
|
image_size = self.action_data.get("size", self.config.get("default_size", "1024x1024"))
|
||||||
|
|
||||||
|
# guidance_scale 现在完全由配置文件控制
|
||||||
|
guidance_scale_input = self.config.get("default_guidance_scale", 2.5) # 默认2.5
|
||||||
|
guidance_scale_val = 2.5 # Fallback default
|
||||||
|
try:
|
||||||
|
guidance_scale_val = float(guidance_scale_input)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 配置文件中的 default_guidance_scale 值 '{guidance_scale_input}' 无效 (应为浮点数),使用默认值 2.5。"
|
||||||
|
)
|
||||||
|
guidance_scale_val = 2.5
|
||||||
|
|
||||||
|
# Seed parameter - ensure it's always an integer
|
||||||
|
seed_config_value = self.config.get("default_seed")
|
||||||
|
seed_val = 42 # Default seed if not configured or invalid
|
||||||
|
if seed_config_value is not None:
|
||||||
|
try:
|
||||||
|
seed_val = int(seed_config_value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 配置文件中的 default_seed ('{seed_config_value}') 无效,将使用默认种子 42。"
|
||||||
|
)
|
||||||
|
# seed_val is already 42
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"{self.log_prefix} 未在配置中找到 default_seed,将使用默认种子 42。建议在配置文件中添加 default_seed。"
|
||||||
|
)
|
||||||
|
# seed_val is already 42
|
||||||
|
|
||||||
|
# Watermark 现在完全由配置文件控制
|
||||||
|
effective_watermark_source = self.config.get("default_watermark", True) # 默认True
|
||||||
|
if isinstance(effective_watermark_source, bool):
|
||||||
|
watermark_val = effective_watermark_source
|
||||||
|
elif isinstance(effective_watermark_source, str):
|
||||||
|
watermark_val = effective_watermark_source.lower() == "true"
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.log_prefix} 配置文件中的 default_watermark 值 '{effective_watermark_source}' 无效 (应为布尔值或 'true'/'false'),使用默认值 True。"
|
||||||
|
)
|
||||||
|
watermark_val = True
|
||||||
|
|
||||||
|
await self.send_message_by_expressor(
|
||||||
|
f"收到!正在为您生成关于 '{description}' 的图片,请稍候...(模型: {default_model}, 尺寸: {image_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
success, result = await asyncio.to_thread(
|
||||||
|
self._make_http_image_request,
|
||||||
|
prompt=description,
|
||||||
|
model=default_model,
|
||||||
|
size=image_size,
|
||||||
|
seed=seed_val,
|
||||||
|
guidance_scale=guidance_scale_val,
|
||||||
|
watermark=watermark_val,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
success = False
|
||||||
|
result = f"图片生成服务遇到意外问题: {str(e)[:100]}"
|
||||||
|
|
||||||
|
if success:
|
||||||
|
image_url = result
|
||||||
|
logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
encode_success = False
|
||||||
|
encode_result = f"图片下载或编码时发生内部错误: {str(e)[:100]}"
|
||||||
|
|
||||||
|
if encode_success:
|
||||||
|
base64_image_string = encode_result
|
||||||
|
send_success = await self.send_message(type="emoji", data=base64_image_string)
|
||||||
|
if send_success:
|
||||||
|
await self.send_message_by_expressor("图片表情已发送!")
|
||||||
|
return True, "图片表情已发送"
|
||||||
|
else:
|
||||||
|
await self.send_message_by_expressor("图片已处理为Base64,但作为表情发送失败了。")
|
||||||
|
return False, "图片表情发送失败 (Base64)"
|
||||||
|
else:
|
||||||
|
await self.send_message_by_expressor(f"获取到图片URL,但在处理图片时失败了:{encode_result}")
|
||||||
|
return False, f"图片处理失败(Base64): {encode_result}"
|
||||||
|
else:
|
||||||
|
error_message = result
|
||||||
|
await self.send_message_by_expressor(f"哎呀,生成图片时遇到问题:{error_message}")
|
||||||
|
return False, f"图片生成失败: {error_message}"
|
||||||
|
|
||||||
|
def _download_and_encode_base64(self, image_url: str) -> Tuple[bool, str]:
|
||||||
|
"""下载图片并将其编码为Base64字符串"""
|
||||||
|
logger.info(f"{self.log_prefix} (B64) 下载并编码图片: {image_url[:70]}...")
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(image_url, timeout=30) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
image_bytes = response.read()
|
||||||
|
base64_encoded_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
logger.info(f"{self.log_prefix} (B64) 图片下载编码完成. Base64长度: {len(base64_encoded_image)}")
|
||||||
|
return True, base64_encoded_image
|
||||||
|
else:
|
||||||
|
error_msg = f"下载图片失败 (状态: {response.status})"
|
||||||
|
logger.error(f"{self.log_prefix} (B64) {error_msg} URL: {image_url}")
|
||||||
|
return False, error_msg
|
||||||
|
except Exception as e: # Catches all exceptions from urlopen, b64encode, etc.
|
||||||
|
logger.error(f"{self.log_prefix} (B64) 下载或编码时错误: {e!r}", exc_info=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
return False, f"下载或编码图片时发生错误: {str(e)[:100]}"
|
||||||
|
|
||||||
|
def _make_http_image_request(
|
||||||
|
self, prompt: str, model: str, size: str, seed: int | None, guidance_scale: float, watermark: bool
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
base_url = self.config.get("base_url")
|
||||||
|
generate_api_key = self.config.get("volcano_generate_api_key")
|
||||||
|
|
||||||
|
endpoint = f"{base_url.rstrip('/')}/images/generations"
|
||||||
|
|
||||||
|
payload_dict = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"response_format": "url",
|
||||||
|
"size": size,
|
||||||
|
"guidance_scale": guidance_scale,
|
||||||
|
"watermark": watermark,
|
||||||
|
"seed": seed, # seed is now always an int from process()
|
||||||
|
"api-key": generate_api_key,
|
||||||
|
}
|
||||||
|
# if seed is not None: # No longer needed, seed is always an int
|
||||||
|
# payload_dict["seed"] = seed
|
||||||
|
|
||||||
|
data = json.dumps(payload_dict).encode("utf-8")
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Authorization": f"Bearer {generate_api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"{self.log_prefix} (HTTP) 发起图片请求: {model}, Prompt: {prompt[:30]}... To: {endpoint}")
|
||||||
|
logger.debug(
|
||||||
|
f"{self.log_prefix} (HTTP) Request Headers: {{...Authorization: Bearer {generate_api_key[:10]}...}}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.log_prefix} (HTTP) Request Body (api-key omitted): {json.dumps({k: v for k, v in payload_dict.items() if k != 'api-key'})}"
|
||||||
|
)
|
||||||
|
|
||||||
|
req = urllib.request.Request(endpoint, data=data, headers=headers, method="POST")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=60) as response:
|
||||||
|
response_status = response.status
|
||||||
|
response_body_bytes = response.read()
|
||||||
|
response_body_str = response_body_bytes.decode("utf-8")
|
||||||
|
|
||||||
|
logger.info(f"{self.log_prefix} (HTTP) 响应: {response_status}. Preview: {response_body_str[:150]}...")
|
||||||
|
|
||||||
|
if 200 <= response_status < 300:
|
||||||
|
response_data = json.loads(response_body_str)
|
||||||
|
image_url = None
|
||||||
|
if (
|
||||||
|
isinstance(response_data.get("data"), list)
|
||||||
|
and response_data["data"]
|
||||||
|
and isinstance(response_data["data"][0], dict)
|
||||||
|
):
|
||||||
|
image_url = response_data["data"][0].get("url")
|
||||||
|
elif response_data.get("url"):
|
||||||
|
image_url = response_data.get("url")
|
||||||
|
|
||||||
|
if image_url:
|
||||||
|
logger.info(f"{self.log_prefix} (HTTP) 图片生成成功,URL: {image_url[:70]}...")
|
||||||
|
return True, image_url
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"{self.log_prefix} (HTTP) API成功但无图片URL. 响应预览: {response_body_str[:300]}..."
|
||||||
|
)
|
||||||
|
return False, "图片生成API响应成功但未找到图片URL"
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"{self.log_prefix} (HTTP) API请求失败. 状态: {response.status}. 正文: {response_body_str[:300]}..."
|
||||||
|
)
|
||||||
|
return False, f"图片API请求失败(状态码 {response.status})"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} (HTTP) 图片生成时意外错误: {e!r}", exc_info=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
return False, f"图片生成HTTP请求时发生意外错误: {str(e)[:100]}"
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
# 请替换为您的火山引擎 Access Key ID
|
||||||
|
volcano_ak = "YOUR_VOLCANO_ENGINE_ACCESS_KEY_ID_HERE"
|
||||||
|
# 请替换为您的火山引擎 Secret Access Key
|
||||||
|
volcano_sk = "YOUR_VOLCANO_ENGINE_SECRET_ACCESS_KEY_HERE"
|
||||||
|
# 火山方舟 API 的基础 URL
|
||||||
|
base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||||
|
# 默认图片生成模型
|
||||||
|
default_model = "doubao-seedream-3-0-t2i-250415"
|
||||||
|
# 默认图片尺寸
|
||||||
|
default_size = "1024x1024"
|
||||||
|
# 用于图片生成的API密钥
|
||||||
|
# PicAction 当前配置为在HTTP请求体和Authorization头中使用此密钥。
|
||||||
|
# 如果您的API认证方式不同,请相应调整或移除。
|
||||||
|
volcano_generate_api_key = "YOUR_VOLCANO_GENERATE_API_KEY_HERE"
|
||||||
|
|
||||||
|
# 是否默认开启水印
|
||||||
|
default_watermark = true
|
||||||
|
# 默认引导强度
|
||||||
|
default_guidance_scale = 2.5
|
||||||
|
# 默认随机种子
|
||||||
|
default_seed = 42
|
||||||
|
|
||||||
|
# 更多插件特定配置可以在此添加...
|
||||||
|
# custom_parameter = "some_value"
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
|
||||||
from src.chat.person_info.person_info import person_info_manager
|
from src.person_info.person_info import person_info_manager
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.chat.models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import json
|
import json
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue