pull/1001/head
SnowindMe 2025-05-23 21:01:49 +08:00
commit 774e72340d
63 changed files with 850 additions and 301 deletions

View File

@ -1,5 +1,12 @@
name: Ruff name: Ruff
on: [ push ]
on:
push:
branches:
- main
- dev
- dev-refactor # 例如:匹配所有以 feature/ 开头的分支
# 添加你希望触发此 workflow 的其他分支
permissions: permissions:
contents: write contents: write
@ -7,6 +14,10 @@ permissions:
jobs: jobs:
ruff: ruff:
runs-on: ubuntu-latest runs-on: ubuntu-latest
# 关键修改:添加条件判断
# 确保只有在 event_name 是 'push' 且不是由 Pull Request 引起的 push 时才运行
if: github.event_name == 'push' && !startsWith(github.ref, 'refs/pull/')
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
@ -25,5 +36,4 @@ jobs:
git config --local user.name "github-actions[bot]" git config --local user.name "github-actions[bot]"
git add -A git add -A
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]" git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
git push git push

View File

@ -65,11 +65,11 @@
## 💬 讨论 ## 💬 讨论
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) | - [一群](https://qm.qq.com/q/VQ3XZrWgMs) |
[四群](https://qm.qq.com/q/wGePTl1UyY) |
[二群](https://qm.qq.com/q/RzmCiRtHEW) | [二群](https://qm.qq.com/q/RzmCiRtHEW) |
[五群](https://qm.qq.com/q/JxvHZnxyec) | [五群](https://qm.qq.com/q/JxvHZnxyec)(已满) |
[三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)| [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)
[四群](https://qm.qq.com/q/wGePTl1UyY)(已满)
## 📚 文档 ## 📚 文档

View File

@ -27,6 +27,11 @@
- 表达器:装饰语言风格 - 表达器:装饰语言风格
- 可通过插件添加和自定义HFC部件目前只支持action定义 - 可通过插件添加和自定义HFC部件目前只支持action定义
**插件系统**
- 添加示例插件
- 示例插件:禁言插件
- 示例插件:豆包绘图插件
**新增表达方式学习** **新增表达方式学习**
- 自主学习群聊中的表达方式,更贴近群友 - 自主学习群聊中的表达方式,更贴近群友
- 可自定义的学习频率和开关 - 可自定义的学习频率和开关
@ -45,7 +50,6 @@
**优化** **优化**
- 移除日程系统,减少幻觉(将会在未来版本回归) - 移除日程系统,减少幻觉(将会在未来版本回归)
- 移除主心流思考和LLM进入聊天判定 - 移除主心流思考和LLM进入聊天判定
-
## [0.6.3-fix-4] - 2025-5-18 ## [0.6.3-fix-4] - 2025-5-18

View File

@ -4,7 +4,7 @@
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9 # 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
# 请小心使用任何一键脚本! # 请小心使用任何一键脚本!
INSTALLER_VERSION="0.0.4-refactor" INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8 LANG=C.UTF-8
# 如无法访问GitHub请修改此处镜像地址 # 如无法访问GitHub请修改此处镜像地址
@ -33,7 +33,6 @@ SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web" SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter" SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
IS_INSTALL_MONGODB=false
IS_INSTALL_NAPCAT=false IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false IS_INSTALL_DEPENDENCIES=false
@ -255,7 +254,6 @@ run_installation() {
return return
elif [[ "$ID" == "arch" ]]; then elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60 whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
whiptail --title "⚠️ 兼容性警告" --msgbox "MongoDB无可用的 Arch Linux 官方安装方法将无法自动安装MongoDB。\n\n您可尝试在AUR中搜索相关包。" 10 60
return return
else else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60 whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
@ -282,16 +280,6 @@ run_installation() {
;; ;;
esac esac
# 检查MongoDB
check_mongodb() {
if command -v mongod &>/dev/null; then
MONGO_INSTALLED=true
else
MONGO_INSTALLED=false
fi
}
check_mongodb
# 检查NapCat # 检查NapCat
check_napcat() { check_napcat() {
if command -v napcat &>/dev/null; then if command -v napcat &>/dev/null; then
@ -330,19 +318,7 @@ run_installation() {
fi fi
} }
install_packages install_packages
# 安装MongoDB
install_mongodb() {
[[ $MONGO_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装MongoDB是否安装\n如果您想使用远程数据库请跳过此步。" 10 60 && {
IS_INSTALL_MONGODB=true
}
}
# 仅在非Arch系统上安装MongoDB
[[ "$ID" != "arch" ]] && install_mongodb
# 安装NapCat # 安装NapCat
install_napcat() { install_napcat() {
[[ $NAPCAT_INSTALLED == true ]] && return [[ $NAPCAT_INSTALLED == true ]] && return
@ -413,9 +389,8 @@ run_installation() {
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n" confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
confirm_msg+="🔀 分支: $BRANCH\n" confirm_msg+="🔀 分支: $BRANCH\n"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n" [[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
[[ $IS_INSTALL_MONGODB == true || $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n" [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
[[ $IS_INSTALL_MONGODB == true ]] && confirm_msg+=" - MongoDB\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n" [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。" confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
@ -440,39 +415,6 @@ run_installation() {
esac esac
fi fi
if [[ $IS_INSTALL_MONGODB == true ]]; then
echo -e "${GREEN}安装 MongoDB...${RESET}"
case "$ID" in
debian)
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
apt update
apt install -y mongodb-org
systemctl enable --now mongod
;;
ubuntu)
curl -fsSL https://www.mongodb.org/static/pgp/server-8.0.asc | gpg -o /usr/share/keyrings/mongodb-server-8.0.gpg --dearmor
echo "deb [ signed-by=/usr/share/keyrings/mongodb-server-8.0.gpg ] http://repo.mongodb.org/apt/debian bookworm/mongodb-org/8.0 main" | tee /etc/apt/sources.list.d/mongodb-org-8.0.list
apt update
apt install -y mongodb-org
systemctl enable --now mongod
;;
centos)
cat > /etc/yum.repos.d/mongodb-org-8.0.repo <<EOF
[mongodb-org-8.0]
name=MongoDB Repository
baseurl=https://repo.mongodb.org/yum/redhat/9/mongodb-org/8.0/x86_64/
gpgcheck=1
enabled=1
gpgkey=https://pgp.mongodb.com/server-8.0.asc
EOF
yum install -y mongodb-org
systemctl enable --now mongod
;;
esac
fi
if [[ $IS_INSTALL_NAPCAT == true ]]; then if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}" echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
@ -537,7 +479,7 @@ EOF
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
[Unit] [Unit]
Description=MaiCore Description=MaiCore
After=network.target mongod.service ${SERVICE_NAME_NBADAPTER}.service After=network.target ${SERVICE_NAME_NBADAPTER}.service
[Service] [Service]
Type=simple Type=simple
@ -550,21 +492,21 @@ RestartSec=10s
WantedBy=multi-user.target WantedBy=multi-user.target
EOF EOF
cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF # cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
[Unit] # [Unit]
Description=MaiCore WebUI # Description=MaiCore WebUI
After=network.target mongod.service ${SERVICE_NAME}.service # After=network.target ${SERVICE_NAME}.service
[Service] # [Service]
Type=simple # Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot # WorkingDirectory=${INSTALL_DIR}/MaiBot
ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py # ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
Restart=always # Restart=always
RestartSec=10s # RestartSec=10s
[Install] # [Install]
WantedBy=multi-user.target # WantedBy=multi-user.target
EOF # EOF
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
[Unit] [Unit]

View File

@ -5,7 +5,7 @@ MaiBot模块系统
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager from src.chat.normal_chat.willing.willing_manager import willing_manager
# 导出主要组件供外部使用 # 导出主要组件供外部使用

View File

@ -12,11 +12,11 @@ import re
# from gradio_client import file # from gradio_client import file
from ...common.database.database_model import Emoji from src.common.database.database_model import Emoji
from ...common.database.database import db as peewee_db from src.common.database.database import db as peewee_db
from ...config.config import global_config from src.config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager from src.chat.utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from rich.traceback import install from rich.traceback import install

View File

@ -5,7 +5,7 @@ from src.chat.message_receive.message import Seg # Local import needed after mo
from src.chat.message_receive.message import UserInfo from src.chat.message_receive.message import UserInfo
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.timer_calculator import Timer # <--- Import Timer

View File

@ -2,7 +2,7 @@ import time
import random import random
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.focus_chat.heartflow_prompt_builder import Prompt, global_prompt_manager from src.chat.focus_chat.heartflow_prompt_builder import Prompt, global_prompt_manager

View File

@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Deque
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from rich.traceback import install from rich.traceback import install
from src.chat.utils.prompt_builder import global_prompt_manager
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
@ -30,13 +31,14 @@ from src.config.config import global_config
install(extra_lines=3) install(extra_lines=3)
# 定义处理器映射 # 定义处理器映射:键是处理器名称,值是 (处理器类, 可选的配置键名)
# 如果配置键名为 None则该处理器默认启用且不能通过 focus_chat_processor 配置禁用
PROCESSOR_CLASSES = { PROCESSOR_CLASSES = {
"ChattingInfoProcessor": ChattingInfoProcessor, "ChattingInfoProcessor": (ChattingInfoProcessor, None),
"MindProcessor": MindProcessor, "MindProcessor": (MindProcessor, None),
"ToolProcessor": ToolProcessor, "ToolProcessor": (ToolProcessor, "tool_use_processor"),
"WorkingMemoryProcessor": WorkingMemoryProcessor, "WorkingMemoryProcessor": (WorkingMemoryProcessor, "working_memory_processor"),
"SelfProcessor": SelfProcessor, "SelfProcessor": (SelfProcessor, "self_identify_processor"),
} }
@ -101,27 +103,19 @@ class HeartFChatting:
) )
# 根据配置文件和默认规则确定启用的处理器 # 根据配置文件和默认规则确定启用的处理器
PROCESSOR_NAME_TO_CONFIG_KEY_MAP = {
"SelfProcessor": "self_identify_processor",
"ToolProcessor": "tool_use_processor",
"WorkingMemoryProcessor": "working_memory_processor",
}
self.enabled_processor_names: List[str] = [] self.enabled_processor_names: List[str] = []
config_processor_settings = global_config.focus_chat_processor # 获取处理器配置,若不存在则为空字典 config_processor_settings = global_config.focus_chat_processor
for proc_name in PROCESSOR_CLASSES.keys(): for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
config_key = PROCESSOR_NAME_TO_CONFIG_KEY_MAP.get(proc_name) if config_key: # 此处理器可通过配置控制
if config_key: if getattr(config_processor_settings, config_key, True): # 默认启用 (如果配置中未指定该键)
# 此处理器可通过配置控制
# getattr(config_processor_settings, config_key, True)
# 如果config_processor_settings是字典则用 config_processor_settings.get(config_key, True)
if getattr(config_processor_settings, config_key, True): # 默认启用,如果配置中未指定
self.enabled_processor_names.append(proc_name) self.enabled_processor_names.append(proc_name)
else: else: # 此处理器不在配置映射中 (config_key is None),默认启用
# 此处理器不在配置映射中,默认启用
self.enabled_processor_names.append(proc_name) self.enabled_processor_names.append(proc_name)
logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}") logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
self.processors: List[BaseProcessor] = []
self._register_default_processors()
self.expressor = DefaultExpressor(chat_id=self.stream_id) self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager() self.action_manager = ActionManager()
@ -130,9 +124,6 @@ class HeartFChatting:
self.hfcloop_observation.set_action_manager(self.action_manager) self.hfcloop_observation.set_action_manager(self.action_manager)
self.all_observations = observations self.all_observations = observations
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
self._register_default_processors()
# 初始化状态控制 # 初始化状态控制
self._initialized = False self._initialized = False
@ -186,19 +177,20 @@ class HeartFChatting:
"""根据 self.enabled_processor_names 注册信息处理器""" """根据 self.enabled_processor_names 注册信息处理器"""
self.processors = [] # 清空已有的 self.processors = [] # 清空已有的
# self.enabled_processor_names 由 __init__ 保证是一个列表
for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc. for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc.
processor_class = PROCESSOR_CLASSES.get(name) processor_info = PROCESSOR_CLASSES.get(name) # processor_info is (ProcessorClass, config_key)
if processor_class: if processor_info:
processor_actual_class = processor_info[0] # 获取实际的类定义
# 根据处理器类名判断是否需要 subheartflow_id # 根据处理器类名判断是否需要 subheartflow_id
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor"]: if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor"]:
self.processors.append(processor_class(subheartflow_id=self.stream_id)) self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
elif name == "ChattingInfoProcessor": elif name == "ChattingInfoProcessor":
self.processors.append(processor_class()) self.processors.append(processor_actual_class())
else: else:
# 对于PROCESSOR_CLASSES中定义但此处未明确处理构造的处理器 # 对于PROCESSOR_CLASSES中定义但此处未明确处理构造的处理器
# (例如, 新增了一个处理器到PROCESSOR_CLASSES, 它不需要id, 也不叫ChattingInfoProcessor)
try: try:
self.processors.append(processor_class()) # 尝试无参构造 self.processors.append(processor_actual_class()) # 尝试无参构造
logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).") logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).")
except TypeError: except TypeError:
logger.error( logger.error(
@ -206,7 +198,9 @@ class HeartFChatting:
) )
else: else:
# 这理论上不应该发生,因为 enabled_processor_names 是从 PROCESSOR_CLASSES 的键生成的 # 这理论上不应该发生,因为 enabled_processor_names 是从 PROCESSOR_CLASSES 的键生成的
logger.warning(f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器,将跳过注册。") logger.warning(
f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器定义,将跳过注册。"
)
if self.processors: if self.processors:
logger.info( logger.info(
@ -286,8 +280,9 @@ class HeartFChatting:
thinking_id = "tid" + str(round(time.time(), 2)) thinking_id = "tid" + str(round(time.time(), 2))
self._current_cycle.set_thinking_id(thinking_id) self._current_cycle.set_thinking_id(thinking_id)
# 主循环:思考->决策->执行 # 主循环:思考->决策->执行
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id) logger.debug(f"模板 {self.chat_stream.context.get_template_name()}")
loop_info = await self._observe_process_plan_action_loop(cycle_timers, thinking_id)
self._current_cycle.set_loop_info(loop_info) self._current_cycle.set_loop_info(loop_info)
@ -430,7 +425,10 @@ class HeartFChatting:
self.all_observations = observations self.all_observations = observations
with Timer("回忆", cycle_timers): with Timer("回忆", cycle_timers):
logger.debug(f"{self.log_prefix} 开始回忆")
running_memorys = await self.memory_activator.activate_memory(observations) running_memorys = await self.memory_activator.activate_memory(observations)
logger.debug(f"{self.log_prefix} 回忆完成")
print(running_memorys)
with Timer("执行 信息处理器", cycle_timers): with Timer("执行 信息处理器", cycle_timers):
all_plan_info = await self._process_processors(observations, running_memorys, cycle_timers) all_plan_info = await self._process_processors(observations, running_memorys, cycle_timers)

View File

@ -11,7 +11,7 @@ from ..message_receive.chat_stream import chat_manager
# from ..message_receive.message_buffer import message_buffer # from ..message_receive.message_buffer import message_buffer
from ..utils.timer_calculator import Timer from ..utils.timer_calculator import Timer
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from typing import Optional, Tuple, Dict, Any from typing import Optional, Tuple, Dict, Any
logger = get_logger("chat") logger = get_logger("chat")

View File

@ -3,7 +3,7 @@ from src.common.logger_manager import get_logger
from src.individuality.individuality import individuality from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
import time import time
from typing import Optional from typing import Optional
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
@ -120,7 +120,6 @@ class PromptBuilder:
relation_prompt += await relationship_manager.build_relationship_info(person) relation_prompt += await relationship_manager.build_relationship_info(person)
else: else:
logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}") logger.warning(f"Invalid person tuple encountered for relationship prompt: {person}")
mood_prompt = mood_manager.get_mood_prompt() mood_prompt = mood_manager.get_mood_prompt()
reply_styles1 = [ reply_styles1 = [
("然后给出日常且口语化的回复,平淡一些", 0.4), ("然后给出日常且口语化的回复,平淡一些", 0.4),
@ -141,9 +140,11 @@ class PromptBuilder:
[style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1 [style[0] for style in reply_styles2], weights=[style[1] for style in reply_styles2], k=1
)[0] )[0]
memory_prompt = "" memory_prompt = ""
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
) )
related_memory_info = "" related_memory_info = ""
if related_memory: if related_memory:
for memory in related_memory: for memory in related_memory:

View File

@ -37,4 +37,4 @@ class SelfInfo(InfoBase):
Returns: Returns:
str: 处理后的信息 str: 处理后的信息
""" """
return self.get_self_info() return self.get_self_info() or ""

View File

@ -67,3 +67,16 @@ class StructuredInfo:
value: 要设置的属性值 value: 要设置的属性值
""" """
self.data[key] = value self.data[key] = value
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息字符串
"""
info_str = ""
for key, value in self.data.items():
info_str += f"信息类型:{key},信息内容:{value}\n"
return info_str

View File

@ -6,7 +6,7 @@ from .base_processor import BaseProcessor
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from typing import Dict from typing import Dict
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import random import random

View File

@ -9,7 +9,7 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
from src.chat.focus_chat.info.cycle_info import CycleInfo from src.chat.focus_chat.info.cycle_info import CycleInfo
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
logger = get_logger("processor") logger = get_logger("processor")

View File

@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import time import time
import traceback import traceback
@ -9,7 +9,7 @@ from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional from typing import List, Optional

View File

@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import time import time
import traceback import traceback
@ -8,7 +8,7 @@ from src.common.logger_manager import get_logger
from src.individuality.individuality import individuality from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor from .base_processor import BaseProcessor
from typing import List, Optional from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
@ -33,12 +33,13 @@ def init_prompt():
现在请你根据现有的信息思考自我认同 现在请你根据现有的信息思考自我认同
1. 你是一个什么样的人,你和群里的人关系如何 1. 你是一个什么样的人,你和群里的人关系如何
2. 思考有没有人提到你或者图片与你有关 2. 你的形象是什么
3. 你的自我认同是否有助于你的回答如果你需要自我相关的信息来帮你参与聊天请输出否则请输出十个字以内的简短自我认同 3. 思考有没有人提到你或者图片与你有关
4. 一般情况下不用输出自我认同只需要输出十几个字的简短自我认同就好除非有明显需要自我认同的场景 4. 你的自我认同是否有助于你的回答如果你需要自我相关的信息来帮你参与聊天请输出否则请输出十几个字的简短自我认同
5. 一般情况下不用输出自我认同只需要输出十几个字的简短自我认同就好除非有明显需要自我认同的场景
请思考的平淡一些简短一些说中文不要浮夸平淡一些 输出内容平淡一些说中文不要浮夸平淡一些
请注意不要输出多余内容(包括前后缀冒号和引号括号()表情包at或 @等 )只输出自我认同内容 请注意不要输出多余内容(包括前后缀冒号和引号括号()表情包at或 @等 )只输出自我认同内容记得明确说明这是你的自我认同
""" """
Prompt(indentify_prompt, "indentify_prompt") Prompt(indentify_prompt, "indentify_prompt")

View File

@ -1,5 +1,5 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import time import time
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
@ -7,7 +7,7 @@ from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser from src.tools.tool_use import ToolUser
from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor from .base_processor import BaseProcessor
from typing import List, Optional, Dict from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation

View File

@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import time import time
import traceback import traceback

View File

@ -1,7 +1,7 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.structure_observation import StructureObservation from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
@ -61,6 +61,8 @@ class MemoryActivator:
elif isinstance(observation, HFCloopObservation): elif isinstance(observation, HFCloopObservation):
obs_info_text += observation.get_observe_info() obs_info_text += observation.get_observe_info()
logger.debug(f"回忆待检索内容obs_info_text: {obs_info_text}")
# prompt = await global_prompt_manager.format_prompt( # prompt = await global_prompt_manager.format_prompt(
# "memory_activator_prompt", # "memory_activator_prompt",
# obs_info_text=obs_info_text, # obs_info_text=obs_info_text,
@ -81,7 +83,7 @@ class MemoryActivator:
# valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3 # valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
# ) # )
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=obs_info_text, max_memory_num=3, max_memory_length=2, max_depth=3, fast_retrieval=True text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=True
) )
# logger.debug(f"获取到的记忆: {related_memory}") # logger.debug(f"获取到的记忆: {related_memory}")

View File

@ -78,7 +78,7 @@ class ExitFocusChatAction(BaseAction):
if self.sub_heartflow: if self.sub_heartflow:
try: try:
# 转换为normal_chat状态 # 转换为normal_chat状态
await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT) await self.sub_heartflow.change_chat_state(ChatState.CHAT)
status_message = "已成功切换到普通聊天模式" status_message = "已成功切换到普通聊天模式"
logger.info(f"{self.log_prefix} {status_message}") logger.info(f"{self.log_prefix} {status_message}")
except Exception as e: except Exception as e:

View File

@ -1,11 +1,14 @@
import traceback import traceback
from typing import Tuple, Dict, List, Any, Optional from typing import Tuple, Dict, List, Any, Optional
from src.chat.focus_chat.planners.actions.base_action import BaseAction from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action # noqa F401
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager
from abc import abstractmethod from abc import abstractmethod
import os
import inspect
import toml # 导入 toml 库
logger = get_logger("plugin_action") logger = get_logger("plugin_action")
@ -16,12 +19,24 @@ class PluginAction(BaseAction):
封装了主程序内部依赖提供简化的API接口给插件开发者 封装了主程序内部依赖提供简化的API接口给插件开发者
""" """
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs): action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
global_config: Optional[dict] = None,
**kwargs,
):
"""初始化插件动作基类""" """初始化插件动作基类"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id) super().__init__(action_data, reasoning, cycle_timers, thinking_id)
# 存储内部服务和对象引用 # 存储内部服务和对象引用
self._services = {} self._services = {}
self._global_config = global_config # 存储全局配置的只读引用
self.config: Dict[str, Any] = {} # 用于存储插件自身的配置
# 从kwargs提取必要的内部服务 # 从kwargs提取必要的内部服务
if "observations" in kwargs: if "observations" in kwargs:
@ -32,6 +47,61 @@ class PluginAction(BaseAction):
self._services["chat_stream"] = kwargs["chat_stream"] self._services["chat_stream"] = kwargs["chat_stream"]
self.log_prefix = kwargs.get("log_prefix", "") self.log_prefix = kwargs.get("log_prefix", "")
self._load_plugin_config() # 初始化时加载插件配置
def _load_plugin_config(self):
"""
加载插件自身的配置文件
配置文件应与插件模块在同一目录下
插件可以通过覆盖 `action_config_file_name` 类属性来指定文件名
如果 `action_config_file_name` 未指定则不加载配置
仅支持 TOML (.toml) 格式
"""
if not self.action_config_file_name:
logger.debug(
f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name不加载插件配置。"
)
return
try:
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
if not os.path.exists(config_file_path):
logger.warning(
f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
)
return
file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
if file_ext == ".toml":
with open(config_file_path, "r", encoding="utf-8") as f:
self.config = toml.load(f) or {}
logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
else:
logger.warning(
f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
)
self.config = {} # 确保未加载时为空字典
return
except Exception as e:
logger.error(
f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
)
self.config = {} # 出错时确保 config 是一个空字典
def get_global_config(self, key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值
插件应使用此方法读取全局配置以保证只读和隔离性
"""
if self._global_config:
return self._global_config.get(key, default)
logger.debug(f"{self.log_prefix} 尝试访问全局配置项 '{key}',但全局配置未提供。")
return default
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]: async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
"""根据用户名获取用户ID""" """根据用户名获取用户ID"""
@ -41,7 +111,7 @@ class PluginAction(BaseAction):
return platform, user_id return platform, user_id
# 提供简化的API方法 # 提供简化的API方法
async def send_message(self, text: str, target: Optional[str] = None) -> bool: async def send_message(self, type: str, data: str, target: Optional[str] = "") -> bool:
"""发送消息的简化方法 """发送消息的简化方法
Args: Args:
@ -60,7 +130,7 @@ class PluginAction(BaseAction):
return False return False
# 构造简化的动作数据 # 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []} # reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有) # 获取锚定消息(如果有)
observations = self._services.get("observations", []) observations = self._services.get("observations", [])
@ -68,7 +138,8 @@ class PluginAction(BaseAction):
chatting_observation: ChattingObservation = next( chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation) obs for obs in observations if isinstance(obs, ChattingObservation)
) )
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
anchor_message = chatting_observation.search_message_by_text(target)
# 如果没有找到锚点消息,创建一个占位符 # 如果没有找到锚点消息,创建一个占位符
if not anchor_message: if not anchor_message:
@ -80,7 +151,7 @@ class PluginAction(BaseAction):
anchor_message.update_chat_stream(chat_stream) anchor_message.update_chat_stream(chat_stream)
response_set = [ response_set = [
("text", text), (type, data),
] ]
# 调用内部方法发送消息 # 调用内部方法发送消息

View File

@ -2,7 +2,7 @@ import json # <--- 确保导入 json
import traceback import traceback
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from rich.traceback import install from rich.traceback import install
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.obs_info import ObsInfo from src.chat.focus_chat.info.obs_info import ObsInfo
@ -10,6 +10,7 @@ from src.chat.focus_chat.info.cycle_info import CycleInfo
from src.chat.focus_chat.info.mind_info import MindInfo from src.chat.focus_chat.info.mind_info import MindInfo
from src.chat.focus_chat.info.action_info import ActionInfo from src.chat.focus_chat.info.action_info import ActionInfo
from src.chat.focus_chat.info.structured_info import StructuredInfo from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.chat.focus_chat.info.self_info import SelfInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import individuality from src.individuality.individuality import individuality
@ -22,7 +23,11 @@ install(extra_lines=3)
def init_prompt(): def init_prompt():
Prompt( Prompt(
"""{extra_info_block} """
你的自我认知是
{self_info_block}
{extra_info_block}
你需要基于以下信息决定如何参与对话 你需要基于以下信息决定如何参与对话
这些信息可能会有冲突请你整合这些信息并选择一个最合适的action 这些信息可能会有冲突请你整合这些信息并选择一个最合适的action
@ -127,6 +132,8 @@ class ActionPlanner:
current_mind = info.get_current_mind() current_mind = info.get_current_mind()
elif isinstance(info, CycleInfo): elif isinstance(info, CycleInfo):
cycle_info = info.get_observe_info() cycle_info = info.get_observe_info()
elif isinstance(info, SelfInfo):
self_info = info.get_processed_info()
elif isinstance(info, StructuredInfo): elif isinstance(info, StructuredInfo):
_structured_info = info.get_data() _structured_info = info.get_data()
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
@ -148,6 +155,7 @@ class ActionPlanner:
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- # --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt = await self.build_planner_prompt( prompt = await self.build_planner_prompt(
self_info_block=self_info,
is_group_chat=is_group_chat, # <-- Pass HFC state is_group_chat=is_group_chat, # <-- Pass HFC state
chat_target_info=None, chat_target_info=None,
observed_messages_str=observed_messages_str, # <-- Pass local variable observed_messages_str=observed_messages_str, # <-- Pass local variable
@ -236,6 +244,7 @@ class ActionPlanner:
async def build_planner_prompt( async def build_planner_prompt(
self, self,
self_info_block: str,
is_group_chat: bool, # Now passed as argument is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument
observed_messages_str: str, observed_messages_str: str,
@ -301,7 +310,8 @@ class ActionPlanner:
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format( prompt = planner_prompt_template.format(
bot_name=global_config.bot.nickname, self_info_block=self_info_block,
# bot_name=global_config.bot.nickname,
prompt_personality=personality_block, prompt_personality=personality_block,
chat_context_description=chat_context_description, chat_context_description=chat_context_description,
chat_content_block=chat_content_block, chat_content_block=chat_content_block,

View File

@ -3,7 +3,7 @@ import traceback
from json_repair import repair_json from json_repair import repair_json
from rich.traceback import install from rich.traceback import install
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.focus_chat.working_memory.memory_item import MemoryItem from src.chat.focus_chat.working_memory.memory_item import MemoryItem
import json # 添加json模块导入 import json # 添加json模块导入

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import traceback import traceback
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (

View File

@ -88,5 +88,6 @@ class HFCloopObservation:
for action_name, action_info in using_actions.items(): for action_name, action_info in using_actions.items():
action_description = action_info["description"] action_description = action_info["description"]
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n" cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
cycle_info_block += "注意,除了上述动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"
self.observe_info = cycle_info_block self.observe_info = cycle_info_block

View File

@ -2,7 +2,7 @@ import asyncio
from typing import Optional, Tuple, Dict from typing import Optional, Tuple, Dict
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager
logger = get_logger("heartflow_utils") logger = get_logger("heartflow_utils")

View File

@ -11,7 +11,7 @@ import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from ...chat.models.utils_model import LLMRequest from ...llm_models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from ..utils.chat_message_builder import ( from ..utils.chat_message_builder import (
@ -338,7 +338,8 @@ class Hippocampus:
# 去重 # 去重
keywords = list(set(keywords)) keywords = list(set(keywords))
# 限制关键词数量 # 限制关键词数量
keywords = keywords[:5] logger.debug(f"提取关键词: {keywords}")
else: else:
# 使用LLM提取关键词 # 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
@ -361,7 +362,7 @@ class Hippocampus:
# 过滤掉不存在于记忆图中的关键词 # 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords: if not valid_keywords:
# logger.info("没有找到有效的关键词节点") logger.info("没有找到有效的关键词节点")
return [] return []
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")

View File

@ -1,8 +1,8 @@
from ..emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
from ..person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from .chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from .message_sender import message_manager from src.chat.message_receive.message_sender import message_manager
from .storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
__all__ = [ __all__ = [

View File

@ -77,6 +77,7 @@ class ChatBot:
message = MessageRecv(message_data) message = MessageRecv(message_data)
group_info = message.message_info.group_info group_info = message.message_info.group_info
user_info = message.message_info.user_info user_info = message.message_info.user_info
chat_manager.register_message(message)
# 确认从接口发来的message是否有自定义的prompt模板信息 # 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default: if message.message_info.template_info and not message.message_info.template_info.template_default:
@ -86,7 +87,7 @@ class ChatBot:
if isinstance(template_items, dict): if isinstance(template_items, dict):
for k in template_items.keys(): for k in template_items.keys():
await Prompt.create_async(template_items[k], k) await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}") logger.debug(f"注册{template_items[k]},{k}")
else: else:
template_group_name = None template_group_name = None

View File

@ -2,13 +2,17 @@ import asyncio
import hashlib import hashlib
import time import time
import copy import copy
from typing import Dict, Optional from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入 from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from rich.traceback import install from rich.traceback import install
@ -18,6 +22,23 @@ install(extra_lines=3)
logger = get_logger("chat_stream") logger = get_logger("chat_stream")
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return None
def get_last_message(self) -> "MessageRecv":
"""获取最后一条消息"""
return self.message
class ChatStream: class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文""" """聊天流对象,存储一个完整的聊天上下文"""
@ -36,6 +57,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time() self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
@ -67,6 +89,10 @@ class ChatStream:
self.last_active_time = time.time() self.last_active_time = time.time()
self.saved = False self.saved = False
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
self.context = ChatMessageContext(message)
class ChatManager: class ChatManager:
"""聊天管理器,管理所有聊天流""" """聊天管理器,管理所有聊天流"""
@ -82,6 +108,7 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在 # 确保 ChatStreams 表存在
@ -113,6 +140,16 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.group_info,
)
self.last_messages[stream_id] = message
logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
@ -146,12 +183,19 @@ class ChatManager:
# 检查内存中是否存在 # 检查内存中是否存在
if stream_id in self.streams: if stream_id in self.streams:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
@ -202,14 +246,26 @@ class ChatManager:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 保存到内存和数据库 # 保存到内存和数据库
self.streams[stream_id] = stream self.streams[stream_id] = stream
await self._save_stream(stream) await self._save_stream(stream)
return copy.deepcopy(stream) return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]: def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流""" """通过stream_id获取聊天流"""
return self.streams.get(stream_id) stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info( def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
@ -306,6 +362,8 @@ class ChatManager:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
stream.saved = True stream.saved = True
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
except Exception as e: except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)

View File

@ -1,12 +1,14 @@
import time import time
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any, TYPE_CHECKING
import urllib3 import urllib3
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import image_manager from ..utils.utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install from rich.traceback import install
@ -25,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass @dataclass
class Message(MessageBase): class Message(MessageBase):
chat_stream: ChatStream = None chat_stream: "ChatStream" = None
reply: Optional["Message"] = None reply: Optional["Message"] = None
detailed_plain_text: str = "" detailed_plain_text: str = ""
processed_plain_text: str = "" processed_plain_text: str = ""
@ -34,7 +36,7 @@ class Message(MessageBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
user_info: UserInfo, user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
@ -111,7 +113,7 @@ class MessageRecv(Message):
self.detailed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji = False self.is_emoji = False
def update_chat_stream(self, chat_stream: ChatStream): def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream self.chat_stream = chat_stream
async def process(self) -> None: async def process(self) -> None:
@ -165,7 +167,7 @@ class MessageProcessBase(Message):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
@ -241,7 +243,7 @@ class MessageThinking(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0, thinking_start_time: float = 0,
@ -269,7 +271,7 @@ class MessageSending(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复 sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
message_segment: Seg, message_segment: Seg,
@ -353,7 +355,7 @@ class MessageSending(MessageProcessBase):
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: ChatStream, message_id: str): def __init__(self, chat_stream: "ChatStream", message_id: str):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.message_id = message_id self.message_id = message_id
self.messages: list[MessageSending] = [] self.messages: list[MessageSending] = []

View File

@ -11,9 +11,10 @@ from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
from src.chat.message_receive.chat_stream import ChatStream, chat_manager from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from src.chat.utils.info_catcher import info_catcher_manager from src.chat.utils.info_catcher import info_catcher_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager
from .normal_chat_generator import NormalChatGenerator from .normal_chat_generator import NormalChatGenerator
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
from src.chat.message_receive.message_sender import message_manager from src.chat.message_receive.message_sender import message_manager
@ -189,30 +190,31 @@ class NormalChat:
通常由start_monitoring_interest()启动 通常由start_monitoring_interest()启动
""" """
while True: while True:
await asyncio.sleep(0.5) # 每秒检查一次 async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
# 检查任务是否已被取消 await asyncio.sleep(0.5) # 每秒检查一次
if self._chat_task is None or self._chat_task.cancelled(): # 检查任务是否已被取消
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出") if self._chat_task is None or self._chat_task.cancelled():
break logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
break
items_to_process = list(self.interest_dict.items()) items_to_process = list(self.interest_dict.items())
if not items_to_process: if not items_to_process:
continue continue
# 处理每条兴趣消息 # 处理每条兴趣消息
for msg_id, (message, interest_value, is_mentioned) in items_to_process: for msg_id, (message, interest_value, is_mentioned) in items_to_process:
try: try:
# 处理消息 # 处理消息
await self.normal_response( await self.normal_response(
message=message, message=message,
is_mentioned=is_mentioned, is_mentioned=is_mentioned,
interested_rate=interest_value, interested_rate=interest_value,
rewind_response=False, rewind_response=False,
) )
except Exception as e: except Exception as e:
logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}") logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}")
finally: finally:
self.interest_dict.pop(msg_id, None) self.interest_dict.pop(msg_id, None)
# 改为实例方法, 移除 chat 参数 # 改为实例方法, 移除 chat 参数
async def normal_response( async def normal_response(

View File

@ -1,8 +1,8 @@
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import random import random
from ..models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from ...config.config import global_config from src.config.config import global_config
from ..message_receive.message import MessageThinking from src.chat.message_receive.message import MessageThinking
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager from src.person_info.person_info import person_info_manager, PersonInfoManager
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import importlib import importlib
from typing import Dict, Optional from typing import Dict, Optional

View File

@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间
import random import random
import re import re
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.chat.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable from src.chat.utils.utils import translate_timestamp_to_human_readable

View File

@ -2,6 +2,7 @@ from typing import Dict, Any, Optional, List, Union
import re import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
import contextvars
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
# import traceback # import traceback
@ -15,29 +16,59 @@ logger = get_module_logger("prompt_build")
class PromptContext: class PromptContext:
def __init__(self): def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context: Optional[str] = None # 使用contextvars创建协程上下文变量
self._context_lock = asyncio.Lock() # 添加异步锁 self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value)
@asynccontextmanager @asynccontextmanager
async def async_scope(self, context_id: str): async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域""" """创建一个异步的临时提示模板作用域"""
async with self._context_lock: # 保存当前上下文并设置新上下文
if context_id not in self._context_prompts: if context_id is not None:
self._context_prompts[context_id] = {} async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context previous_context = self._current_context
self._current_context = context_id # 设置当前协程的新上下文
token = self._current_context_var.set(context_id)
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try: try:
yield self yield self
finally: finally:
async with self._context_lock: # 恢复之前的上下文
self._current_context = previous_context if context_id is not None:
if token:
self._current_context_var.reset(token)
else:
self._current_context = previous_context
async def get_prompt_async(self, name: str) -> Optional["Prompt"]: async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板""" """异步获取当前作用域中的提示模板"""
async with self._context_lock: async with self._context_lock:
if self._current_context and name in self._context_prompts[self._current_context]: current_context = self._current_context
return self._context_prompts[self._current_context][name] logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
@ -56,8 +87,8 @@ class PromptManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@asynccontextmanager @asynccontextmanager
async def async_message_scope(self, message_id: str): async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域""" """为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id): async with self._context.async_scope(message_id):
yield self yield self
@ -65,9 +96,11 @@ class PromptManager:
# 首先尝试从当前上下文获取 # 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name) context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None: if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt return context_prompt
# 如果上下文中不存在,则使用全局提示模板 # 如果上下文中不存在,则使用全局提示模板
async with self._lock: async with self._lock:
logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts: if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found") raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name] return self._prompts[name]

View File

@ -10,7 +10,7 @@ from maim_message import UserInfo
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator from .typo_generator import ChineseTypoGenerator
from ...config.config import global_config from ...config.config import global_config
from ...common.message_repository import find_messages, count_messages from ...common.message_repository import find_messages, count_messages

View File

@ -8,10 +8,10 @@ import io
import numpy as np import numpy as np
from ...common.database.database import db from src.common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions from src.common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config from src.config.config import global_config
from ..models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from rich.traceback import install from rich.traceback import install

View File

@ -869,6 +869,23 @@ API_SERVER_STYLE_CONFIG = {
}, },
} }
# maim_message 消息服务样式配置
MAIM_MESSAGE_STYLE_CONFIG = {
"advanced": {
"console_format": (
"<white>{time:YYYY-MM-DD HH:mm:ss}</white> | "
"<level>{level: <8}</level> | "
"<fg #00B2FF>消息服务</fg #00B2FF> | "
"<level>{message}</level>"
),
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息服务 | {message}",
},
"simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #00B2FF>消息服务</fg #00B2FF> | {message}",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息服务 | {message}",
},
}
# 根据SIMPLE_OUTPUT选择配置 # 根据SIMPLE_OUTPUT选择配置
MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"] MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"]
@ -946,6 +963,9 @@ CHAT_MESSAGE_STYLE_CONFIG = (
CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"] CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"]
INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"] INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"]
API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"] API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"]
MAIM_MESSAGE_STYLE_CONFIG = (
MAIM_MESSAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIM_MESSAGE_STYLE_CONFIG["advanced"]
)
INTEREST_CHAT_STYLE_CONFIG = ( INTEREST_CHAT_STYLE_CONFIG = (
INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"] INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"]
) )

View File

@ -1,6 +1,27 @@
from src.common.server import global_server from src.common.server import global_server
import os import os
import importlib.metadata
from maim_message import MessageServer from maim_message import MessageServer
from src.common.logger_manager import get_logger
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) # 检查maim_message版本
try:
maim_message_version = importlib.metadata.version("maim_message")
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 0]
except (importlib.metadata.PackageNotFoundError, ValueError):
version_compatible = False
# 根据版本决定是否使用自定义logger
kwargs = {
"host": os.environ["HOST"],
"port": int(os.environ["PORT"]),
"app": global_server.get_app(),
}
# 只有在版本 >= 0.3.0 时才使用自定义logger
if version_compatible:
maim_message_logger = get_logger("maim_message")
kwargs["custom_logger"] = maim_message_logger
global_api = MessageServer(**kwargs)

View File

@ -68,30 +68,30 @@ class TelemetryHeartBeatTask(AsyncTask):
response = requests.post( response = requests.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client", f"{TELEMETRY_SERVER_URL}/stat/reg_client",
json={"deploy_time": local_storage["deploy_time"]}, json={"deploy_time": local_storage["deploy_time"]},
timeout=5, # 设置超时时间为5秒
) )
except Exception as e:
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
logger.debug(local_storage["deploy_time"])
logger.debug(response)
if response.status_code == 200:
data = response.json()
client_id = data.get("mmc_uuid")
if client_id:
# 将UUID存储到本地
local_storage["mmc_uuid"] = client_id
self.client_uuid = client_id
logger.info(f"成功获取UUID: {self.client_uuid}")
return True # 成功获取UUID返回True
else:
logger.error("无效的服务端响应")
else:
logger.error(f"请求UUID失败状态码: {response.status_code}, 响应内容: {response.text}")
except requests.RequestException as e:
logger.error(f"请求UUID时出错: {e}") # 可能是网络问题 logger.error(f"请求UUID时出错: {e}") # 可能是网络问题
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
logger.debug(local_storage["deploy_time"])
logger.debug(response)
if response.status_code == 200:
data = response.json()
if client_id := data.get("mmc_uuid"):
# 将UUID存储到本地
local_storage["mmc_uuid"] = client_id
self.client_uuid = client_id
logger.info(f"成功获取UUID: {self.client_uuid}")
return True # 成功获取UUID返回True
else:
logger.error("无效的服务端响应")
else:
logger.error(f"请求UUID失败状态码: {response.status_code}, 响应内容: {response.text}")
# 请求失败,重试次数+1 # 请求失败,重试次数+1
try_count += 1 try_count += 1
if try_count > 3: if try_count > 3:
@ -100,47 +100,48 @@ class TelemetryHeartBeatTask(AsyncTask):
return False return False
else: else:
# 如果可以重试,等待后继续(指数退避) # 如果可以重试,等待后继续(指数退避)
logger.info(f"获取UUID失败将于 {4**try_count} 秒后重试...")
await asyncio.sleep(4**try_count) await asyncio.sleep(4**try_count)
async def _send_heartbeat(self): async def _send_heartbeat(self):
"""向服务器发送心跳""" """向服务器发送心跳"""
headers = {
"Client-UUID": self.client_uuid,
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
logger.debug(headers)
try: try:
headers = {
"Client-UUID": self.client_uuid,
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
logger.debug(headers)
response = requests.post( response = requests.post(
f"{self.server_url}/stat/client_heartbeat", f"{self.server_url}/stat/client_heartbeat",
headers=headers, headers=headers,
json=self.info_dict, json=self.info_dict,
timeout=5, # 设置超时时间为5秒
) )
except Exception as e:
logger.debug(response)
# 处理响应
if 200 <= response.status_code < 300:
# 成功
logger.debug(f"心跳发送成功,状态码: {response.status_code}")
elif response.status_code == 403:
# 403 Forbidden
logger.error(
"心跳发送失败403 Forbidden: 可能是UUID无效或未注册。"
"处理措施重置UUID下次发送心跳时将尝试重新注册。"
)
self.client_uuid = None
del local_storage["mmc_uuid"] # 删除本地存储的UUID
else:
# 其他错误
logger.error(f"心跳发送失败,状态码: {response.status_code}, 响应内容: {response.text}")
except requests.RequestException as e:
logger.error(f"心跳发送失败: {e}") logger.error(f"心跳发送失败: {e}")
logger.debug(response)
# 处理响应
if 200 <= response.status_code < 300:
# 成功
logger.debug(f"心跳发送成功,状态码: {response.status_code}")
elif response.status_code == 403:
# 403 Forbidden
logger.error(
"心跳发送失败403 Forbidden: 可能是UUID无效或未注册。"
"处理措施重置UUID下次发送心跳时将尝试重新注册。"
)
self.client_uuid = None
del local_storage["mmc_uuid"] # 删除本地存储的UUID
else:
# 其他错误
logger.error(f"心跳发送失败,状态码: {response.status_code}, 响应内容: {response.text}")
async def run(self): async def run(self):
# 发送心跳 # 发送心跳
if global_config.telemetry.enable: if global_config.telemetry.enable:

View File

@ -1,7 +1,7 @@
import time import time
from typing import Tuple, Optional # 增加了 Optional from typing import Tuple, Optional # 增加了 Optional
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.experimental.PFC.chat_observer import ChatObserver from src.experimental.PFC.chat_observer import ChatObserver
from src.experimental.PFC.pfc_utils import get_items_from_json from src.experimental.PFC.pfc_utils import get_items_from_json

View File

@ -1,6 +1,6 @@
from typing import List, Tuple, TYPE_CHECKING from typing import List, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.experimental.PFC.chat_observer import ChatObserver from src.experimental.PFC.chat_observer import ChatObserver
from src.experimental.PFC.pfc_utils import get_items_from_json from src.experimental.PFC.pfc_utils import get_items_from_json

View File

@ -1,7 +1,7 @@
from typing import List, Tuple from typing import List, Tuple
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import Message from src.chat.message_receive.message import Message
from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.knowledge.knowledge_lib import qa_manager

View File

@ -1,7 +1,7 @@
import json import json
from typing import Tuple, List, Dict, Any from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.experimental.PFC.chat_observer import ChatObserver from src.experimental.PFC.chat_observer import ChatObserver
from maim_message import UserInfo from maim_message import UserInfo

View File

@ -1,6 +1,6 @@
from typing import Tuple, List, Dict, Any from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.experimental.PFC.chat_observer import ChatObserver from src.experimental.PFC.chat_observer import ChatObserver
from src.experimental.PFC.reply_checker import ReplyChecker from src.experimental.PFC.reply_checker import ReplyChecker

View File

@ -1,6 +1,6 @@
import random import random
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from typing import List, Tuple from typing import List, Tuple

View File

@ -3,10 +3,8 @@ import json
import re import re
from datetime import datetime from datetime import datetime
from typing import Tuple, Union, Dict, Any from typing import Tuple, Union, Dict, Any
import aiohttp import aiohttp
from aiohttp.client import ClientResponse from aiohttp.client import ClientResponse
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
import base64 import base64
from PIL import Image from PIL import Image
@ -14,7 +12,7 @@ import io
import os import os
from src.common.database.database import db # 确保 db 被导入用于 create_tables from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config from src.config.config import global_config
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)

View File

@ -6,7 +6,7 @@ from .manager.async_task_manager import async_task_manager
from .chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from .chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from .manager.mood_manager import MoodPrintTask, MoodUpdateTask from .manager.mood_manager import MoodPrintTask, MoodUpdateTask
from .chat.emoji_system.emoji_manager import emoji_manager from .chat.emoji_system.emoji_manager import emoji_manager
from .chat.person_info.person_info import person_info_manager from .person_info.person_info import person_info_manager
from .chat.normal_chat.willing.willing_manager import willing_manager from .chat.normal_chat.willing.willing_manager import willing_manager
from .chat.message_receive.chat_stream import chat_manager from .chat.message_receive.chat_stream import chat_manager
from src.chat.heart_flow.heartflow import heartflow from src.chat.heart_flow.heartflow import heartflow

View File

@ -22,13 +22,21 @@ class LocalStoreManager:
def __getitem__(self, item: str) -> str | list | dict | int | float | bool | None: def __getitem__(self, item: str) -> str | list | dict | int | float | bool | None:
"""获取本地存储数据""" """获取本地存储数据"""
return self.store.get(item, None) return self.store.get(item)
def __setitem__(self, key: str, value: str | list | dict | int | float | bool): def __setitem__(self, key: str, value: str | list | dict | int | float | bool):
"""设置本地存储数据""" """设置本地存储数据"""
self.store[key] = value self.store[key] = value
self.save_local_store() self.save_local_store()
def __delitem__(self, key: str):
"""删除本地存储数据"""
if key in self.store:
del self.store[key]
self.save_local_store()
else:
logger.warning(f"尝试删除不存在的键: {key}")
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
"""检查本地存储数据是否存在""" """检查本地存储数据是否存在"""
return item in self.store return item in self.store

View File

@ -1,13 +1,13 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from ...common.database.database import db from src.common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入 from src.common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
import datetime import datetime
import asyncio import asyncio
import numpy as np import numpy as np
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.individuality.individuality import individuality from src.individuality.individuality import individuality

View File

@ -1,13 +1,13 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from ..message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
import math import math
from bson.decimal128 import Decimal128 from bson.decimal128 import Decimal128
from .person_info import person_info_manager from src.person_info.person_info import person_info_manager
import time import time
import random import random
from maim_message import UserInfo from maim_message import UserInfo
from ...manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
# import re # import re
# import traceback # import traceback

View File

@ -1,7 +1,7 @@
"""测试插件动作模块""" """测试插件动作模块"""
# 导入所有动作模块以确保装饰器被执行 # 导入所有动作模块以确保装饰器被执行
# from . import test_action # noqa from . import test_action # noqa
# from . import online_action # noqa from . import online_action # noqa
# from . import mute_action # noqa from . import mute_action # noqa

View File

@ -25,7 +25,7 @@ class MuteAction(PluginAction):
"当千石可乐或可乐酱要求你禁言时使用", "当千石可乐或可乐酱要求你禁言时使用",
"当你想回避某个话题时使用", "当你想回避某个话题时使用",
] ]
default = True # 不是默认动作,需要手动添加到使用集 default = False # 不是默认动作,需要手动添加到使用集
async def process(self) -> Tuple[bool, str]: async def process(self) -> Tuple[bool, str]:
"""处理测试动作""" """处理测试动作"""
@ -40,7 +40,11 @@ class MuteAction(PluginAction):
await self.send_message_by_expressor(f"我要禁言{target}{platform},时长{duration}秒,理由{reason},表达情绪") await self.send_message_by_expressor(f"我要禁言{target}{platform},时长{duration}秒,理由{reason},表达情绪")
try: try:
await self.send_message(f"[command]mute,{user_id},{duration}") await self.send_message(
type="text",
data=f"[command]mute,{user_id},{duration}",
# target = target
)
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}") logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")

View File

@ -17,7 +17,7 @@ class CheckOnlineAction(PluginAction):
"mode参数为version时查看在线版本状态默认用这种", "mode参数为version时查看在线版本状态默认用这种",
"mode参数为type时查看在线系统类型分布", "mode参数为type时查看在线系统类型分布",
] ]
default = True # 不是默认动作,需要手动添加到使用集 default = False # 不是默认动作,需要手动添加到使用集
async def process(self) -> Tuple[bool, str]: async def process(self) -> Tuple[bool, str]:
"""处理测试动作""" """处理测试动作"""

View File

@ -0,0 +1,5 @@
"""测试插件包:图片发送"""
"""
这是一个测试插件用于测试图片发送功能
"""

View File

@ -0,0 +1,4 @@
"""测试插件动作模块"""
# 导入所有动作模块以确保装饰器被执行
from . import pic_action # noqa

View File

@ -0,0 +1,50 @@
import os
CONFIG_CONTENT = """\
# 请替换为您的火山引擎 Access Key ID
volcano_ak = "YOUR_VOLCANO_ENGINE_ACCESS_KEY_ID_HERE"
# 请替换为您的火山引擎 Secret Access Key
volcano_sk = "YOUR_VOLCANO_ENGINE_SECRET_ACCESS_KEY_HERE"
# 火山方舟 API 的基础 URL
base_url = "https://ark.cn-beijing.volces.com/api/v3"
# 默认图片生成模型
default_model = "doubao-seedream-3-0-t2i-250415"
# 默认图片尺寸
default_size = "1024x1024"
# 用于图片生成的API密钥
# PicAction 当前配置为在HTTP请求体和Authorization头中使用此密钥。
# 如果您的API认证方式不同请相应调整或移除。
volcano_generate_api_key = "YOUR_VOLCANO_GENERATE_API_KEY_HERE"
# 是否默认开启水印
default_watermark = true
# 默认引导强度
default_guidance_scale = 2.5
# 默认随机种子
default_seed = 42
# 更多插件特定配置可以在此添加...
# custom_parameter = "some_value"
"""
def generate_config():
# 获取当前脚本所在的目录
current_dir = os.path.dirname(os.path.abspath(__file__))
config_file_path = os.path.join(current_dir, "pic_action_config.toml")
if not os.path.exists(config_file_path):
try:
with open(config_file_path, "w", encoding="utf-8") as f:
f.write(CONFIG_CONTENT)
print(f"配置文件已生成: {config_file_path}")
print("请记得编辑该文件,填入您的火山引擎 AK/SK 和 API 密钥。")
except IOError as e:
print(f"错误:无法写入配置文件 {config_file_path}。原因: {e}")
else:
print(f"配置文件已存在: {config_file_path}")
print("未进行任何更改。如果您想重新生成,请先删除或重命名现有文件。")
if __name__ == "__main__":
generate_config()

View File

@ -0,0 +1,264 @@
import asyncio
import json
import urllib.request
import urllib.error
import base64 # 新增用于Base64编码
import traceback # 新增:用于打印堆栈跟踪
from typing import Tuple
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from src.common.logger_manager import get_logger
from .generate_pic_config import generate_config
logger = get_logger("pic_action")
# 当此模块被加载时,尝试生成配置文件(如果它不存在)
# 注意:在某些插件加载机制下,这可能会在每次机器人启动或插件重载时执行
# 考虑是否需要更复杂的逻辑来决定何时运行 (例如,仅在首次安装时)
generate_config()
@register_action
class PicAction(PluginAction):
"""根据描述使用火山引擎HTTP API生成图片的动作处理类"""
action_name = "pic_action"
action_description = "可以根据特定的描述,使用火山引擎模型生成并发送一张图片 (通过HTTP API)"
action_parameters = {
"description": "图片描述,输入你想要生成并发送的图片的描述,必填",
"size": "图片尺寸,例如 '1024x1024' (可选, 默认从配置或 '1024x1024')",
}
action_require = [
"当有人要求你生成并发送一张图片时使用",
"当有人让你画一张图时使用",
]
default = False
action_config_file_name = "pic_action_config.toml"
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
global_config: dict = None,
**kwargs,
):
super().__init__(action_data, reasoning, cycle_timers, thinking_id, global_config, **kwargs)
http_base_url = self.config.get("base_url")
http_api_key = self.config.get("volcano_generate_api_key")
if not (http_base_url and http_api_key):
logger.error(
f"{self.log_prefix} PicAction初始化, 但HTTP配置 (base_url 或 volcano_generate_api_key) 缺失. HTTP图片生成将失败."
)
else:
logger.info(f"{self.log_prefix} HTTP方式初始化完成. Base URL: {http_base_url}, API Key已配置.")
# _restore_env_vars 方法不再需要,已移除
async def process(self) -> Tuple[bool, str]:
"""处理图片生成动作通过HTTP API"""
logger.info(f"{self.log_prefix} 执行 pic_action (HTTP): {self.reasoning}")
http_base_url = self.config.get("base_url")
http_api_key = self.config.get("volcano_generate_api_key")
if not (http_base_url and http_api_key):
error_msg = "抱歉图片生成功能所需的HTTP配置如API地址或密钥不完整无法提供服务。"
await self.send_message_by_expressor(error_msg)
logger.error(f"{self.log_prefix} HTTP调用配置缺失: base_url 或 volcano_generate_api_key.")
return False, "HTTP配置不完整"
description = self.action_data.get("description")
if not description:
logger.warning(f"{self.log_prefix} 图片描述为空,无法生成图片。")
await self.send_message_by_expressor("你需要告诉我想要画什么样的图片哦~")
return False, "图片描述为空"
default_model = self.config.get("default_model", "doubao-seedream-3-0-t2i-250415")
image_size = self.action_data.get("size", self.config.get("default_size", "1024x1024"))
# guidance_scale 现在完全由配置文件控制
guidance_scale_input = self.config.get("default_guidance_scale", 2.5) # 默认2.5
guidance_scale_val = 2.5 # Fallback default
try:
guidance_scale_val = float(guidance_scale_input)
except (ValueError, TypeError):
logger.warning(
f"{self.log_prefix} 配置文件中的 default_guidance_scale 值 '{guidance_scale_input}' 无效 (应为浮点数),使用默认值 2.5。"
)
guidance_scale_val = 2.5
# Seed parameter - ensure it's always an integer
seed_config_value = self.config.get("default_seed")
seed_val = 42 # Default seed if not configured or invalid
if seed_config_value is not None:
try:
seed_val = int(seed_config_value)
except (ValueError, TypeError):
logger.warning(
f"{self.log_prefix} 配置文件中的 default_seed ('{seed_config_value}') 无效,将使用默认种子 42。"
)
# seed_val is already 42
else:
logger.info(
f"{self.log_prefix} 未在配置中找到 default_seed将使用默认种子 42。建议在配置文件中添加 default_seed。"
)
# seed_val is already 42
# Watermark 现在完全由配置文件控制
effective_watermark_source = self.config.get("default_watermark", True) # 默认True
if isinstance(effective_watermark_source, bool):
watermark_val = effective_watermark_source
elif isinstance(effective_watermark_source, str):
watermark_val = effective_watermark_source.lower() == "true"
else:
logger.warning(
f"{self.log_prefix} 配置文件中的 default_watermark 值 '{effective_watermark_source}' 无效 (应为布尔值或 'true'/'false'),使用默认值 True。"
)
watermark_val = True
await self.send_message_by_expressor(
f"收到!正在为您生成关于 '{description}' 的图片,请稍候...(模型: {default_model}, 尺寸: {image_size}"
)
try:
success, result = await asyncio.to_thread(
self._make_http_image_request,
prompt=description,
model=default_model,
size=image_size,
seed=seed_val,
guidance_scale=guidance_scale_val,
watermark=watermark_val,
)
except Exception as e:
logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True)
traceback.print_exc()
success = False
result = f"图片生成服务遇到意外问题: {str(e)[:100]}"
if success:
image_url = result
logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.")
try:
encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url)
except Exception as e:
logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True)
traceback.print_exc()
encode_success = False
encode_result = f"图片下载或编码时发生内部错误: {str(e)[:100]}"
if encode_success:
base64_image_string = encode_result
send_success = await self.send_message(type="emoji", data=base64_image_string)
if send_success:
await self.send_message_by_expressor("图片表情已发送!")
return True, "图片表情已发送"
else:
await self.send_message_by_expressor("图片已处理为Base64但作为表情发送失败了。")
return False, "图片表情发送失败 (Base64)"
else:
await self.send_message_by_expressor(f"获取到图片URL但在处理图片时失败了{encode_result}")
return False, f"图片处理失败(Base64): {encode_result}"
else:
error_message = result
await self.send_message_by_expressor(f"哎呀,生成图片时遇到问题:{error_message}")
return False, f"图片生成失败: {error_message}"
def _download_and_encode_base64(self, image_url: str) -> Tuple[bool, str]:
"""下载图片并将其编码为Base64字符串"""
logger.info(f"{self.log_prefix} (B64) 下载并编码图片: {image_url[:70]}...")
try:
with urllib.request.urlopen(image_url, timeout=30) as response:
if response.status == 200:
image_bytes = response.read()
base64_encoded_image = base64.b64encode(image_bytes).decode("utf-8")
logger.info(f"{self.log_prefix} (B64) 图片下载编码完成. Base64长度: {len(base64_encoded_image)}")
return True, base64_encoded_image
else:
error_msg = f"下载图片失败 (状态: {response.status})"
logger.error(f"{self.log_prefix} (B64) {error_msg} URL: {image_url}")
return False, error_msg
except Exception as e: # Catches all exceptions from urlopen, b64encode, etc.
logger.error(f"{self.log_prefix} (B64) 下载或编码时错误: {e!r}", exc_info=True)
traceback.print_exc()
return False, f"下载或编码图片时发生错误: {str(e)[:100]}"
def _make_http_image_request(
self, prompt: str, model: str, size: str, seed: int | None, guidance_scale: float, watermark: bool
) -> Tuple[bool, str]:
base_url = self.config.get("base_url")
generate_api_key = self.config.get("volcano_generate_api_key")
endpoint = f"{base_url.rstrip('/')}/images/generations"
payload_dict = {
"model": model,
"prompt": prompt,
"response_format": "url",
"size": size,
"guidance_scale": guidance_scale,
"watermark": watermark,
"seed": seed, # seed is now always an int from process()
"api-key": generate_api_key,
}
# if seed is not None: # No longer needed, seed is always an int
# payload_dict["seed"] = seed
data = json.dumps(payload_dict).encode("utf-8")
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {generate_api_key}",
}
logger.info(f"{self.log_prefix} (HTTP) 发起图片请求: {model}, Prompt: {prompt[:30]}... To: {endpoint}")
logger.debug(
f"{self.log_prefix} (HTTP) Request Headers: {{...Authorization: Bearer {generate_api_key[:10]}...}}"
)
logger.debug(
f"{self.log_prefix} (HTTP) Request Body (api-key omitted): {json.dumps({k: v for k, v in payload_dict.items() if k != 'api-key'})}"
)
req = urllib.request.Request(endpoint, data=data, headers=headers, method="POST")
try:
with urllib.request.urlopen(req, timeout=60) as response:
response_status = response.status
response_body_bytes = response.read()
response_body_str = response_body_bytes.decode("utf-8")
logger.info(f"{self.log_prefix} (HTTP) 响应: {response_status}. Preview: {response_body_str[:150]}...")
if 200 <= response_status < 300:
response_data = json.loads(response_body_str)
image_url = None
if (
isinstance(response_data.get("data"), list)
and response_data["data"]
and isinstance(response_data["data"][0], dict)
):
image_url = response_data["data"][0].get("url")
elif response_data.get("url"):
image_url = response_data.get("url")
if image_url:
logger.info(f"{self.log_prefix} (HTTP) 图片生成成功URL: {image_url[:70]}...")
return True, image_url
else:
logger.error(
f"{self.log_prefix} (HTTP) API成功但无图片URL. 响应预览: {response_body_str[:300]}..."
)
return False, "图片生成API响应成功但未找到图片URL"
else:
logger.error(
f"{self.log_prefix} (HTTP) API请求失败. 状态: {response.status}. 正文: {response_body_str[:300]}..."
)
return False, f"图片API请求失败(状态码 {response.status})"
except Exception as e:
logger.error(f"{self.log_prefix} (HTTP) 图片生成时意外错误: {e!r}", exc_info=True)
traceback.print_exc()
return False, f"图片生成HTTP请求时发生意外错误: {str(e)[:100]}"

View File

@ -0,0 +1,24 @@
# 请替换为您的火山引擎 Access Key ID
volcano_ak = "YOUR_VOLCANO_ENGINE_ACCESS_KEY_ID_HERE"
# 请替换为您的火山引擎 Secret Access Key
volcano_sk = "YOUR_VOLCANO_ENGINE_SECRET_ACCESS_KEY_HERE"
# 火山方舟 API 的基础 URL
base_url = "https://ark.cn-beijing.volces.com/api/v3"
# 默认图片生成模型
default_model = "doubao-seedream-3-0-t2i-250415"
# 默认图片尺寸
default_size = "1024x1024"
# 用于图片生成的API密钥
# PicAction 当前配置为在HTTP请求体和Authorization头中使用此密钥。
# 如果您的API认证方式不同请相应调整或移除。
volcano_generate_api_key = "YOUR_VOLCANO_GENERATE_API_KEY_HERE"
# 是否默认开启水印
default_watermark = true
# 默认引导强度
default_guidance_scale = 2.5
# 默认随机种子
default_seed = 42
# 更多插件特定配置可以在此添加...
# custom_parameter = "some_value"

View File

@ -1,5 +1,5 @@
from src.tools.tool_can_use.base_tool import BaseTool, register_tool from src.tools.tool_can_use.base_tool import BaseTool, register_tool
from src.chat.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
import time import time

View File

@ -1,4 +1,4 @@
from src.chat.models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
import json import json
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger