mirror of https://github.com/Mai-with-u/MaiBot.git
同步远程dev分支的更改
commit
d8f3338f38
|
|
@ -25,8 +25,8 @@
|
|||
|
||||
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
|
||||
|
||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理。
|
||||
- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制。
|
||||
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。
|
||||
- 🔌 **强大插件系统**:全面重构的插件架构,更多API。
|
||||
- 🤔 **实时思维系统**:模拟人类思考过程。
|
||||
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
|
||||
- 💝 **情感表达系统**:情绪系统和表情包系统。
|
||||
|
|
@ -46,7 +46,7 @@
|
|||
|
||||
## 🔥 更新和安装
|
||||
|
||||
**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md))
|
||||
**最新版本: v0.10.0** ([更新日志](changelogs/changelog.md))
|
||||
|
||||
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
|
||||
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
|
||||
|
|
@ -56,7 +56,6 @@
|
|||
- `classical`: 旧版本(停止维护)
|
||||
|
||||
### 最新版本部署教程
|
||||
- [从0.6/0.7升级须知](https://docs.mai-mai.org/faq/maibot/update_to_07.html)
|
||||
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
|
||||
|
||||
> [!WARNING]
|
||||
|
|
@ -64,7 +63,6 @@
|
|||
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
|
||||
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
|
||||
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
|
||||
> - 由于持续迭代,可能存在一些已知或未知的 bug。
|
||||
> - 由于程序处于开发中,可能消耗较多 token。
|
||||
|
||||
## 💬 讨论
|
||||
|
|
|
|||
|
|
@ -1,16 +1,74 @@
|
|||
# Changelog
|
||||
|
||||
## [0.10.0] - 2025-7-1
|
||||
### 主要功能更改
|
||||
### 🌟 主要功能更改
|
||||
- 优化的回复生成,现在的回复对上下文把控更加精准
|
||||
- 新的回复逻辑控制,现在合并了normal和focus模式,更加统一
|
||||
- 优化表达方式系统,现在学习和使用更加精准
|
||||
- 新的关系系统,现在的关系构建更精准也更克制
|
||||
- 工具系统重构,现在合并到了插件系统中
|
||||
- 彻底重构了整个LLM Request了,现在支持模型轮询和更多灵活的参数
|
||||
- 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
|
||||
- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API
|
||||
- 具体相比于之前的更改可以查看[changes.md](./changes.md)
|
||||
- **警告所有插件开发者:插件系统即将迎来不稳定时期,随时会发动更改。**
|
||||
|
||||
#### 🔧 工具系统重构
|
||||
- **工具系统整合**: 工具系统现在完全合并到插件系统中,提供统一的扩展能力
|
||||
- **工具启用控制**: 支持配置是否启用特定工具,提供更人性化的直接调用方式
|
||||
- **配置文件读取**: 工具现在支持读取配置文件,增强配置灵活性
|
||||
|
||||
#### 🚀 LLM系统全面重构
|
||||
- **LLM Request重构**: 彻底重构了整个LLM Request系统,现在支持模型轮询和更多灵活的参数
|
||||
- **模型配置升级**: 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
|
||||
- **任务类型支持**: 新增任务类型和能力字段至模型配置,增强模型初始化逻辑
|
||||
- **异常处理增强**: 增强LLMRequest类的异常处理,添加统一的模型异常处理方法
|
||||
|
||||
#### 🔌 插件系统稳定化
|
||||
- **插件系统重构完成**: 随着LLM Request的重构,插件系统彻底重构完成,进入稳定状态
|
||||
- **API扩展**: 仅增加新的API,保持向后兼容性
|
||||
- **插件管理优化**: 让插件管理配置真正有用,提升管理体验
|
||||
|
||||
#### 💾 记忆系统优化
|
||||
- **及时构建**: 记忆系统再优化,现在及时构建,并且不会重复构建
|
||||
- **精确提取**: 记忆提取更精确,提升记忆质量
|
||||
|
||||
#### 🎭 表达方式系统
|
||||
- **表达方式记录**: 记录使用的表达方式,提供更好的学习追踪
|
||||
- **学习优化**: 优化表达方式提取,修复表达学习出错问题
|
||||
- **配置优化**: 优化表达方式配置和逻辑,提升系统稳定性
|
||||
|
||||
#### 🔄 聊天系统统一
|
||||
- **normal和focus合并**: 彻底合并normal和focus,完全基于planner决定target message
|
||||
- **no_reply内置**: 将no_reply功能移动到主循环中,简化系统架构
|
||||
- **回复优化**: 优化reply,填补缺失值,让麦麦可以回复自己的消息
|
||||
- **频率控制API**: 加入聊天频率控制相关API,提供更精细的控制
|
||||
|
||||
#### 日志系统改进
|
||||
- **日志颜色优化**: 修改了log的颜色,更加护眼
|
||||
- **日志清理优化**: 修复了日志清理先等24h的问题,提升系统性能
|
||||
- **计时定位**: 通过计时定位LLM异常延时,提升问题排查效率
|
||||
|
||||
### 🐛 问题修复
|
||||
|
||||
#### 代码质量提升
|
||||
- **lint问题修复**: 修复了lint爆炸的问题,代码更加规范了
|
||||
- **导入优化**: 修复导入爆炸和文档错误,优化代码结构
|
||||
|
||||
#### 系统稳定性
|
||||
- **循环导入**: 修复了import时循环导入的问题
|
||||
- **并行动作**: 修复并行动作炸裂问题,提升并发处理能力
|
||||
- **空响应处理**: 空响应就raise,避免系统异常
|
||||
|
||||
#### 功能修复
|
||||
- **API问题**: 修复api问题,提升系统可用性
|
||||
- **notice问题**: 为组件方法提供新参数,暂时解决notice问题
|
||||
- **关系构建**: 修复不认识的用户构建关系问题
|
||||
- **流式解析**: 修复流式解析越界问题,避免空choices的SSE帧错误
|
||||
|
||||
#### 配置和兼容性
|
||||
- **默认值**: 添加默认值,提升配置灵活性
|
||||
- **类型问题**: 修复类型问题,提升代码健壮性
|
||||
- **配置加载**: 优化配置加载逻辑,提升系统启动稳定性
|
||||
|
||||
### 细节优化
|
||||
- 修复了lint爆炸的问题,代码更加规范了
|
||||
- 修改了log的颜色,更加护眼
|
||||
|
||||
## [0.9.1] - 2025-7-26
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ services:
|
|||
# - ./data/MaiMBot:/data/MaiMBot
|
||||
# networks:
|
||||
# - maim_bot
|
||||
|
||||
volumes:
|
||||
site-packages:
|
||||
networks:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import math
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from rich.traceback import install
|
||||
|
|
@ -8,6 +9,7 @@ from collections import deque
|
|||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
|
|
@ -15,21 +17,19 @@ from src.chat.planner_actions.planner import ActionPlanner
|
|||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
|
||||
from src.chat.frequency_control.focus_value_control import focus_value_control
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ChatMode, EventType
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
import math
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
# no_action逻辑已集成到heartFC_chat.py中,不再需要导入
|
||||
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
|
||||
# 导入记忆系统
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
|
||||
from src.chat.frequency_control.focus_value_control import focus_value_control
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
|
|
@ -62,10 +62,7 @@ class HeartFChatting:
|
|||
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_id: str,
|
||||
):
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
HeartFChatting 初始化函数
|
||||
|
||||
|
|
@ -83,7 +80,7 @@ class HeartFChatting:
|
|||
|
||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
|
||||
self.talk_frequency_control = talk_frequency_control.get_talk_frequency_control(self.stream_id)
|
||||
self.focus_value_control = focus_value_control.get_focus_value_control(self.stream_id)
|
||||
|
||||
|
|
@ -104,7 +101,7 @@ class HeartFChatting:
|
|||
self.plan_timeout_count = 0
|
||||
|
||||
self.last_read_time = time.time() - 1
|
||||
|
||||
|
||||
self.focus_energy = 1
|
||||
self.no_action_consecutive = 0
|
||||
# 最近三次no_action的新消息兴趣度记录
|
||||
|
|
@ -166,27 +163,26 @@ class HeartFChatting:
|
|||
|
||||
# 获取动作类型,兼容新旧格式
|
||||
action_type = "未知动作"
|
||||
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
|
||||
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
|
||||
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
||||
if isinstance(loop_plan_info, dict):
|
||||
action_result = loop_plan_info.get('action_result', {})
|
||||
action_result = loop_plan_info.get("action_result", {})
|
||||
if isinstance(action_result, dict):
|
||||
# 旧格式:action_result是字典
|
||||
action_type = action_result.get('action_type', '未知动作')
|
||||
action_type = action_result.get("action_type", "未知动作")
|
||||
elif isinstance(action_result, list) and action_result:
|
||||
# 新格式:action_result是actions列表
|
||||
action_type = action_result[0].get('action_type', '未知动作')
|
||||
action_type = action_result[0].get("action_type", "未知动作")
|
||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||
# 直接是actions列表的情况
|
||||
action_type = loop_plan_info[0].get('action_type', '未知动作')
|
||||
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
|
||||
f"选择动作: {action_type}"
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
|
||||
def _determine_form_type(self) -> None:
|
||||
"""判断使用哪种形式的no_action"""
|
||||
# 如果连续no_action次数少于3次,使用waiting形式
|
||||
|
|
@ -195,42 +191,44 @@ class HeartFChatting:
|
|||
else:
|
||||
# 计算最近三次记录的兴趣度总和
|
||||
total_recent_interest = sum(self.recent_interest_records)
|
||||
|
||||
|
||||
# 计算调整后的阈值
|
||||
adjusted_threshold = 1 / self.talk_frequency_control.get_current_talk_frequency()
|
||||
|
||||
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
|
||||
)
|
||||
|
||||
# 如果兴趣度总和小于阈值,进入breaking形式
|
||||
if total_recent_interest < adjusted_threshold:
|
||||
logger.info(f"{self.log_prefix} 兴趣度不足,进入休息")
|
||||
self.focus_energy = random.randint(3, 6)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
|
||||
self.focus_energy = 1
|
||||
|
||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
|
||||
self.focus_energy = 1
|
||||
|
||||
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
|
||||
"""
|
||||
判断是否应该处理消息
|
||||
|
||||
|
||||
Args:
|
||||
new_message: 新消息列表
|
||||
mode: 当前聊天模式
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该处理消息
|
||||
"""
|
||||
new_message_count = len(new_message)
|
||||
talk_frequency = self.talk_frequency_control.get_current_talk_frequency()
|
||||
|
||||
|
||||
modified_exit_count_threshold = self.focus_energy * 0.5 / talk_frequency
|
||||
modified_exit_interest_threshold = 1.5 / talk_frequency
|
||||
total_interest = 0.0
|
||||
for msg_dict in new_message:
|
||||
interest_value = msg_dict.get("interest_value")
|
||||
if interest_value is not None and msg_dict.get("processed_plain_text", ""):
|
||||
for msg in new_message:
|
||||
interest_value = msg.interest_value
|
||||
if interest_value is not None and msg.processed_plain_text:
|
||||
total_interest += float(interest_value)
|
||||
|
||||
|
||||
if new_message_count >= modified_exit_count_threshold:
|
||||
self.recent_interest_records.append(total_interest)
|
||||
logger.info(
|
||||
|
|
@ -244,9 +242,11 @@ class HeartFChatting:
|
|||
if new_message_count > 0:
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}"
|
||||
)
|
||||
self._last_accumulated_interest = total_interest
|
||||
|
||||
|
||||
if total_interest >= modified_exit_interest_threshold:
|
||||
# 记录兴趣度到列表
|
||||
self.recent_interest_records.append(total_interest)
|
||||
|
|
@ -261,29 +261,25 @@ class HeartFChatting:
|
|||
f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累计兴趣{total_interest:.1f},继续等待..."
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return False,0.0
|
||||
|
||||
return False, 0.0
|
||||
|
||||
async def _loopbody(self):
|
||||
recent_messages_dict = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
start_time=self.last_read_time,
|
||||
end_time=time.time(),
|
||||
limit = 10,
|
||||
limit=10,
|
||||
limit_mode="latest",
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict]
|
||||
)
|
||||
# 统一的消息处理逻辑
|
||||
should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
|
||||
should_process, interest_value = await self._should_process_messages(recent_messages_dict)
|
||||
|
||||
if should_process:
|
||||
self.last_read_time = time.time()
|
||||
await self._observe(interest_value = interest_value)
|
||||
await self._observe(interest_value=interest_value)
|
||||
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
|
|
@ -298,22 +294,21 @@ class HeartFChatting:
|
|||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
selected_expressions:List[int] = None,
|
||||
selected_expressions: List[int] = None,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(
|
||||
reply_set=response_set,
|
||||
message_data=action_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||
|
||||
person = Person(platform = platform ,user_id = action_message.get("user_id", ""))
|
||||
|
||||
person = Person(platform=platform, user_id=action_message.get("user_id", ""))
|
||||
person_name = person.person_name
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
|
|
@ -342,12 +337,10 @@ class HeartFChatting:
|
|||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(self,interest_value:float = 0.0) -> bool:
|
||||
|
||||
async def _observe(self, interest_value: float = 0.0) -> bool:
|
||||
action_type = "no_action"
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
||||
|
|
@ -361,12 +354,14 @@ class HeartFChatting:
|
|||
x0 = 1.0 # 控制曲线中心点
|
||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||
|
||||
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / self.talk_frequency_control.get_current_talk_frequency()
|
||||
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency()
|
||||
|
||||
# 根据概率决定使用哪种模式
|
||||
if random.random() < normal_mode_probability:
|
||||
mode = ChatMode.NORMAL
|
||||
logger.info(f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability*100:.0f}%概率下选择回复")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability * 100:.0f}%概率下选择回复"
|
||||
)
|
||||
else:
|
||||
mode = ChatMode.FOCUS
|
||||
|
||||
|
|
@ -387,10 +382,9 @@ class HeartFChatting:
|
|||
await hippocampus_manager.build_memory_for_chat(self.stream_id)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
|
||||
|
||||
|
||||
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
|
||||
#如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
|
||||
# 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
|
||||
actions = [
|
||||
{
|
||||
"action_type": "no_action",
|
||||
|
|
@ -420,23 +414,21 @@ class HeartFChatting:
|
|||
):
|
||||
return False
|
||||
with Timer("规划器", cycle_timers):
|
||||
actions, _= await self.action_planner.plan(
|
||||
actions, _ = await self.action_planner.plan(
|
||||
mode=mode,
|
||||
loop_start_time=self.last_read_time,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# 3. 并行执行所有动作
|
||||
async def execute_action(action_info,actions):
|
||||
async def execute_action(action_info, actions):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
if action_info["action_type"] == "no_action":
|
||||
# 直接处理no_action逻辑,不再通过动作系统
|
||||
reason = action_info.get("reasoning", "选择不回复")
|
||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
|
||||
# 存储no_action信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
|
|
@ -447,13 +439,8 @@ class HeartFChatting:
|
|||
action_data={"reason": reason},
|
||||
action_name="no_action",
|
||||
)
|
||||
|
||||
return {
|
||||
"action_type": "no_action",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": ""
|
||||
}
|
||||
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
elif action_info["action_type"] != "reply":
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
|
|
@ -463,20 +450,19 @@ class HeartFChatting:
|
|||
action_info["action_data"],
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_info["action_message"]
|
||||
action_info["action_message"],
|
||||
)
|
||||
return {
|
||||
"action_type": action_info["action_type"],
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command
|
||||
"command": command,
|
||||
}
|
||||
else:
|
||||
|
||||
try:
|
||||
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_message = action_info["action_message"],
|
||||
reply_message=action_info["action_message"],
|
||||
available_actions=available_actions,
|
||||
choosen_actions=actions,
|
||||
reply_reason=action_info.get("reasoning", ""),
|
||||
|
|
@ -485,29 +471,21 @@ class HeartFChatting:
|
|||
from_plugin=False,
|
||||
return_expressions=True,
|
||||
)
|
||||
|
||||
|
||||
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
|
||||
_,selected_expressions = prompt_selected_expressions
|
||||
_, selected_expressions = prompt_selected_expressions
|
||||
else:
|
||||
selected_expressions = []
|
||||
|
||||
if not success or not response_set:
|
||||
logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
||||
)
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None
|
||||
}
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
response_set=response_set,
|
||||
|
|
@ -521,7 +499,7 @@ class HeartFChatting:
|
|||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info
|
||||
"loop_info": loop_info,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
|
|
@ -531,26 +509,26 @@ class HeartFChatting:
|
|||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
action_tasks = [asyncio.create_task(execute_action(action,actions)) for action in actions]
|
||||
|
||||
|
||||
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
|
||||
|
||||
# 并行执行所有任务
|
||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理执行结果
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
action_command = ""
|
||||
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
|
||||
_cur_action = actions[i]
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
|
|
@ -590,7 +568,6 @@ class HeartFChatting:
|
|||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
|
||||
if s4u_config.enable_s4u:
|
||||
await stop_typing()
|
||||
|
|
@ -602,7 +579,7 @@ class HeartFChatting:
|
|||
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||
|
||||
action_type = actions[0]["action_type"] if actions else "no_action"
|
||||
|
||||
|
||||
# 管理no_action计数器:当执行了非no_action动作时,重置计数器
|
||||
if action_type != "no_action":
|
||||
# no_action逻辑已集成到heartFC_chat.py中,直接重置计数器
|
||||
|
|
@ -610,7 +587,7 @@ class HeartFChatting:
|
|||
self.no_action_consecutive = 0
|
||||
logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_action计数器")
|
||||
return True
|
||||
|
||||
|
||||
if action_type == "no_action":
|
||||
self.no_action_consecutive += 1
|
||||
self._determine_form_type()
|
||||
|
|
@ -692,11 +669,12 @@ class HeartFChatting:
|
|||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def _send_response(self,
|
||||
reply_set,
|
||||
message_data,
|
||||
selected_expressions:List[int] = None,
|
||||
) -> str:
|
||||
async def _send_response(
|
||||
self,
|
||||
reply_set,
|
||||
message_data,
|
||||
selected_expressions: List[int] = None,
|
||||
) -> str:
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||
)
|
||||
|
|
@ -714,7 +692,7 @@ class HeartFChatting:
|
|||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message = message_data,
|
||||
reply_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
selected_expressions=selected_expressions,
|
||||
|
|
@ -724,7 +702,7 @@ class HeartFChatting:
|
|||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_message = message_data,
|
||||
reply_message=message_data,
|
||||
set_reply=False,
|
||||
typing=True,
|
||||
selected_expressions=selected_expressions,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import List, Dict, Optional, Any, Tuple
|
|||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
|
|
@ -346,21 +347,17 @@ class ExpressionLearner:
|
|||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None
|
||||
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
chat_id: str = random_msg[0]["chat_id"]
|
||||
chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
|
|
|
|||
|
|
@ -70,11 +70,9 @@ class ActionModifier:
|
|||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
)
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
|
||||
|
||||
chat_content = build_readable_messages(
|
||||
temp_msg_list_before_now_half,
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
|
|
|
|||
|
|
@ -280,11 +280,8 @@ class ActionPlanner:
|
|||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=temp_msg_list_before_now,
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
|
|
|
|||
|
|
@ -2,18 +2,21 @@ import time # 导入 time 模块以获取当前时间
|
|||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Union
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import MessageAndActionModel
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import Person, get_person_id
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
|
||||
def replace_user_references_sync(
|
||||
|
|
@ -349,7 +352,9 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
|
|||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
|
||||
def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
|
|
@ -390,7 +395,7 @@ def num_new_messages_since_with_users(
|
|||
|
||||
|
||||
def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: List[MessageAndActionModel],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
|
|
@ -398,7 +403,7 @@ def _build_readable_messages_internal(
|
|||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
message_id_list: Optional[List[DatabaseMessages]] = None,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
|
@ -429,14 +434,15 @@ def _build_readable_messages_internal(
|
|||
timestamp_to_id = {}
|
||||
if message_id_list:
|
||||
for item in message_id_list:
|
||||
message = item.get("message", {})
|
||||
timestamp = message.get("time")
|
||||
timestamp = item.time
|
||||
if timestamp is not None:
|
||||
timestamp_to_id[timestamp] = item.get("id", "")
|
||||
timestamp_to_id[timestamp] = item.message_id
|
||||
|
||||
def process_pic_ids(content: str) -> str:
|
||||
def process_pic_ids(content: Optional[str]) -> str:
|
||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
||||
nonlocal current_pic_counter
|
||||
if content is None:
|
||||
logger.warning("Content is None when processing pic IDs.")
|
||||
raise ValueError("Content is None")
|
||||
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
|
@ -456,38 +462,23 @@ def _build_readable_messages_internal(
|
|||
# 1 & 2: 获取发送者信息并提取消息组件
|
||||
for msg in messages:
|
||||
# 检查是否是动作记录
|
||||
if msg.get("is_action_record", False):
|
||||
if msg.is_action_record:
|
||||
is_action = True
|
||||
timestamp: float = msg.get("time") # type: ignore
|
||||
content = msg.get("display_message", "")
|
||||
timestamp: float = msg.time
|
||||
content = msg.display_message
|
||||
# 对于动作记录,也处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
|
||||
continue
|
||||
|
||||
# 检查并修复缺少的user_info字段
|
||||
if "user_info" not in msg:
|
||||
# 创建user_info字段
|
||||
msg["user_info"] = {
|
||||
"platform": msg.get("user_platform", ""),
|
||||
"user_id": msg.get("user_id", ""),
|
||||
"user_nickname": msg.get("user_nickname", ""),
|
||||
"user_cardname": msg.get("user_cardname", ""),
|
||||
}
|
||||
platform = msg.user_platform
|
||||
user_id = msg.user_id
|
||||
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform")
|
||||
user_id = user_info.get("user_id")
|
||||
user_nickname = msg.user_nickname
|
||||
user_cardname = msg.user_cardname
|
||||
|
||||
user_nickname = user_info.get("user_nickname")
|
||||
user_cardname = user_info.get("user_cardname")
|
||||
|
||||
timestamp: float = msg.get("time") # type: ignore
|
||||
content: str
|
||||
if msg.get("display_message"):
|
||||
content = msg.get("display_message", "")
|
||||
else:
|
||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||
timestamp = msg.time
|
||||
content = msg.display_message or msg.processed_plain_text or ""
|
||||
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
|
|
@ -776,7 +767,7 @@ async def build_readable_messages_with_list(
|
|||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
|
|
@ -807,7 +798,7 @@ def build_readable_messages_with_id(
|
|||
|
||||
|
||||
def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
|
|
@ -815,7 +806,7 @@ def build_readable_messages(
|
|||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
message_id_list: Optional[List[DatabaseMessages]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
|
|
@ -831,19 +822,32 @@ def build_readable_messages(
|
|||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
"""
|
||||
# WIP HERE and BELOW ----------------------------------------------
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
copy_messages = [msg.copy() for msg in messages]
|
||||
copy_messages: List[MessageAndActionModel] = [
|
||||
MessageAndActionModel(
|
||||
msg.time,
|
||||
msg.user_info.user_id,
|
||||
msg.user_info.platform,
|
||||
msg.user_info.user_nickname,
|
||||
msg.user_info.user_cardname,
|
||||
msg.processed_plain_text,
|
||||
msg.display_message,
|
||||
msg.chat_info.platform,
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
min_time = min(msg.get("time", 0) for msg in copy_messages)
|
||||
max_time = max(msg.get("time", 0) for msg in copy_messages)
|
||||
min_time = min(msg.time or 0 for msg in copy_messages)
|
||||
max_time = max(msg.time or 0 for msg in copy_messages)
|
||||
|
||||
# 从第一条消息中获取chat_id
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
chat_id = messages[0].chat_id if messages else None
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
|
|
@ -863,27 +867,28 @@ def build_readable_messages(
|
|||
)
|
||||
|
||||
# 合并两部分动作记录
|
||||
actions = list(actions_in_range) + list(action_after_latest)
|
||||
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
# 将动作记录转换为消息格式
|
||||
for action in actions:
|
||||
# 只有当build_into_prompt为True时才添加动作记录
|
||||
if action.action_build_into_prompt:
|
||||
action_msg = {
|
||||
"time": action.time,
|
||||
"user_id": global_config.bot.qq_account, # 使用机器人的QQ账号
|
||||
"user_nickname": global_config.bot.nickname, # 使用机器人的昵称
|
||||
"user_cardname": "", # 机器人没有群名片
|
||||
"processed_plain_text": f"{action.action_prompt_display}",
|
||||
"display_message": f"{action.action_prompt_display}",
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
"is_action_record": True, # 添加标识字段
|
||||
"action_name": action.action_name, # 保存动作名称
|
||||
}
|
||||
action_msg = MessageAndActionModel(
|
||||
time=float(action.time), # type: ignore
|
||||
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
|
||||
user_platform=global_config.bot.platform, # 使用机器人的平台
|
||||
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
|
||||
user_cardname="", # 机器人没有群名片
|
||||
processed_plain_text=f"{action.action_prompt_display}",
|
||||
display_message=f"{action.action_prompt_display}",
|
||||
chat_info_platform=str(action.chat_info_platform),
|
||||
is_action_record=True, # 添加标识字段
|
||||
action_name=str(action.action_name), # 保存动作名称
|
||||
)
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
# 重新按时间排序
|
||||
copy_messages.sort(key=lambda x: x.get("time", 0))
|
||||
copy_messages.sort(key=lambda x: x.time or 0)
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
|
|
@ -905,8 +910,8 @@ def build_readable_messages(
|
|||
return formatted_string
|
||||
else:
|
||||
# 按 read_mark 分割消息
|
||||
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
|
||||
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
|
||||
messages_before_mark = [msg for msg in copy_messages if (msg.time or 0) <= read_mark]
|
||||
messages_after_mark = [msg for msg in copy_messages if (msg.time or 0) > read_mark]
|
||||
|
||||
# 共享的图片映射字典和计数器
|
||||
pic_id_mapping = {}
|
||||
|
|
@ -960,13 +965,13 @@ def build_readable_messages(
|
|||
return "".join(result_parts)
|
||||
|
||||
|
||||
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
||||
"""
|
||||
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
||||
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
||||
"""
|
||||
if not messages:
|
||||
print("111111111111没有消息,无法构建匿名消息")
|
||||
logger.warning("没有消息,无法构建匿名消息")
|
||||
return ""
|
||||
|
||||
person_map = {}
|
||||
|
|
@ -1017,14 +1022,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||
|
||||
for msg in messages:
|
||||
try:
|
||||
platform: str = msg.get("chat_info_platform") # type: ignore
|
||||
user_id = msg.get("user_id")
|
||||
_timestamp = msg.get("time")
|
||||
content: str = ""
|
||||
if msg.get("display_message"):
|
||||
content = msg.get("display_message", "")
|
||||
else:
|
||||
content = msg.get("processed_plain_text", "")
|
||||
platform = msg.chat_info.platform
|
||||
user_id = msg.user_info.user_id
|
||||
content = msg.display_message or msg.processed_plain_text or ""
|
||||
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
|
|
|
|||
|
|
@ -3,18 +3,21 @@ from dataclasses import dataclass, field, fields, MISSING
|
|||
|
||||
from . import AbstractClassFlag
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseUserInfo(AbstractClassFlag):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.platform, str), "platform must be a string"
|
||||
assert isinstance(self.user_id, str), "user_id must be a string"
|
||||
assert isinstance(self.user_nickname, str), "user_nickname must be a string"
|
||||
assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None"
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.platform, str), "platform must be a string"
|
||||
# assert isinstance(self.user_id, str), "user_id must be a string"
|
||||
# assert isinstance(self.user_nickname, str), "user_nickname must be a string"
|
||||
# assert isinstance(self.user_cardname, str) or self.user_cardname is None, (
|
||||
# "user_cardname must be a string or None"
|
||||
# )
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -22,11 +25,13 @@ class DatabaseGroupInfo(AbstractClassFlag):
|
|||
group_id: str = field(default_factory=str)
|
||||
group_name: str = field(default_factory=str)
|
||||
group_platform: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.group_id, str), "group_id must be a string"
|
||||
assert isinstance(self.group_name, str), "group_name must be a string"
|
||||
assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None"
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.group_id, str), "group_id must be a string"
|
||||
# assert isinstance(self.group_name, str), "group_name must be a string"
|
||||
# assert isinstance(self.group_platform, str) or self.group_platform is None, (
|
||||
# "group_platform must be a string or None"
|
||||
# )
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -38,123 +43,117 @@ class DatabaseChatInfo(AbstractClassFlag):
|
|||
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
|
||||
group_info: Optional[DatabaseGroupInfo] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.stream_id, str), "stream_id must be a string"
|
||||
assert isinstance(self.platform, str), "platform must be a string"
|
||||
assert isinstance(self.create_time, float), "create_time must be a float"
|
||||
assert isinstance(self.last_active_time, float), "last_active_time must be a float"
|
||||
assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
|
||||
assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None"
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.stream_id, str), "stream_id must be a string"
|
||||
# assert isinstance(self.platform, str), "platform must be a string"
|
||||
# assert isinstance(self.create_time, float), "create_time must be a float"
|
||||
# assert isinstance(self.last_active_time, float), "last_active_time must be a float"
|
||||
# assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
|
||||
# assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, (
|
||||
# "group_info must be a DatabaseGroupInfo instance or None"
|
||||
# )
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseMessages(AbstractClassFlag):
|
||||
# chat_info: DatabaseChatInfo
|
||||
# user_info: DatabaseUserInfo
|
||||
# group_info: Optional[DatabaseGroupInfo] = None
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str = "",
|
||||
time: float = 0.0,
|
||||
chat_id: str = "",
|
||||
reply_to: Optional[str] = None,
|
||||
interest_value: Optional[float] = None,
|
||||
key_words: Optional[str] = None,
|
||||
key_words_lite: Optional[str] = None,
|
||||
is_mentioned: Optional[bool] = None,
|
||||
processed_plain_text: Optional[str] = None,
|
||||
display_message: Optional[str] = None,
|
||||
priority_mode: Optional[str] = None,
|
||||
priority_info: Optional[str] = None,
|
||||
additional_config: Optional[str] = None,
|
||||
is_emoji: bool = False,
|
||||
is_picid: bool = False,
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
user_platform: str = "",
|
||||
chat_info_group_id: Optional[str] = None,
|
||||
chat_info_group_name: Optional[str] = None,
|
||||
chat_info_group_platform: Optional[str] = None,
|
||||
chat_info_user_id: str = "",
|
||||
chat_info_user_nickname: str = "",
|
||||
chat_info_user_cardname: Optional[str] = None,
|
||||
chat_info_user_platform: str = "",
|
||||
chat_info_stream_id: str = "",
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.message_id = message_id
|
||||
self.time = time
|
||||
self.chat_id = chat_id
|
||||
self.reply_to = reply_to
|
||||
self.interest_value = interest_value
|
||||
|
||||
message_id: str = field(default_factory=str)
|
||||
time: float = field(default_factory=float)
|
||||
chat_id: str = field(default_factory=str)
|
||||
reply_to: Optional[str] = None
|
||||
interest_value: Optional[float] = None
|
||||
self.key_words = key_words
|
||||
self.key_words_lite = key_words_lite
|
||||
self.is_mentioned = is_mentioned
|
||||
|
||||
key_words: Optional[str] = None
|
||||
key_words_lite: Optional[str] = None
|
||||
is_mentioned: Optional[bool] = None
|
||||
self.processed_plain_text = processed_plain_text
|
||||
self.display_message = display_message
|
||||
|
||||
# 从 chat_info 扁平化而来的字段
|
||||
# chat_info_stream_id: str = field(default_factory=str)
|
||||
# chat_info_platform: str = field(default_factory=str)
|
||||
# chat_info_user_platform: str = field(default_factory=str)
|
||||
# chat_info_user_id: str = field(default_factory=str)
|
||||
# chat_info_user_nickname: str = field(default_factory=str)
|
||||
# chat_info_user_cardname: Optional[str] = None
|
||||
# chat_info_group_platform: Optional[str] = None
|
||||
# chat_info_group_id: Optional[str] = None
|
||||
# chat_info_group_name: Optional[str] = None
|
||||
# chat_info_create_time: float = field(default_factory=float)
|
||||
# chat_info_last_active_time: float = field(default_factory=float)
|
||||
self.priority_mode = priority_mode
|
||||
self.priority_info = priority_info
|
||||
|
||||
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
|
||||
# user_platform: str = field(default_factory=str)
|
||||
# user_id: str = field(default_factory=str)
|
||||
# user_nickname: str = field(default_factory=str)
|
||||
# user_cardname: Optional[str] = None
|
||||
self.additional_config = additional_config
|
||||
self.is_emoji = is_emoji
|
||||
self.is_picid = is_picid
|
||||
self.is_command = is_command
|
||||
self.is_notify = is_notify
|
||||
|
||||
processed_plain_text: Optional[str] = None # 处理后的纯文本消息
|
||||
display_message: Optional[str] = None # 显示的消息
|
||||
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[str] = None
|
||||
|
||||
additional_config: Optional[str] = None
|
||||
is_emoji: bool = False
|
||||
is_picid: bool = False
|
||||
is_command: bool = False
|
||||
is_notify: bool = False
|
||||
|
||||
selected_expressions: Optional[str] = None
|
||||
|
||||
# def __post_init__(self):
|
||||
|
||||
# if self.chat_info_group_id and self.chat_info_group_name:
|
||||
# self.group_info = DatabaseGroupInfo(
|
||||
# group_id=self.chat_info_group_id,
|
||||
# group_name=self.chat_info_group_name,
|
||||
# group_platform=self.chat_info_group_platform,
|
||||
# )
|
||||
|
||||
# chat_user_info = DatabaseUserInfo(
|
||||
# user_id=self.chat_info_user_id,
|
||||
# user_nickname=self.chat_info_user_nickname,
|
||||
# user_cardname=self.chat_info_user_cardname,
|
||||
# platform=self.chat_info_user_platform,
|
||||
# )
|
||||
# self.chat_info = DatabaseChatInfo(
|
||||
# stream_id=self.chat_info_stream_id,
|
||||
# platform=self.chat_info_platform,
|
||||
# create_time=self.chat_info_create_time,
|
||||
# last_active_time=self.chat_info_last_active_time,
|
||||
# user_info=chat_user_info,
|
||||
# group_info=self.group_info,
|
||||
# )
|
||||
def __init__(self, **kwargs: Any):
|
||||
defined = {f.name: f for f in fields(self.__class__)}
|
||||
for name, f in defined.items():
|
||||
if name in kwargs:
|
||||
setattr(self, name, kwargs.pop(name))
|
||||
elif f.default is not MISSING:
|
||||
setattr(self, name, f.default)
|
||||
else:
|
||||
raise TypeError(f"缺失必需字段: {name}")
|
||||
self.selected_expressions = selected_expressions
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
user_id=kwargs.get("user_id"), # type: ignore
|
||||
user_nickname=kwargs.get("user_nickname"), # type: ignore
|
||||
user_cardname=kwargs.get("user_cardname"), # type: ignore
|
||||
platform=kwargs.get("user_platform"), # type: ignore
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
platform=user_platform,
|
||||
)
|
||||
if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"):
|
||||
if chat_info_group_id and chat_info_group_name:
|
||||
self.group_info = DatabaseGroupInfo(
|
||||
group_id=kwargs.get("chat_info_group_id"), # type: ignore
|
||||
group_name=kwargs.get("chat_info_group_name"), # type: ignore
|
||||
group_platform=kwargs.get("chat_info_group_platform"), # type: ignore
|
||||
group_id=chat_info_group_id,
|
||||
group_name=chat_info_group_name,
|
||||
group_platform=chat_info_group_platform,
|
||||
)
|
||||
|
||||
chat_user_info = DatabaseUserInfo(
|
||||
user_id=kwargs.get("chat_info_user_id"), # type: ignore
|
||||
user_nickname=kwargs.get("chat_info_user_nickname"), # type: ignore
|
||||
user_cardname=kwargs.get("chat_info_user_cardname"), # type: ignore
|
||||
platform=kwargs.get("chat_info_user_platform"), # type: ignore
|
||||
)
|
||||
|
||||
self.chat_info = DatabaseChatInfo(
|
||||
stream_id=kwargs.get("chat_info_stream_id"), # type: ignore
|
||||
platform=kwargs.get("chat_info_platform"), # type: ignore
|
||||
create_time=kwargs.get("chat_info_create_time"), # type: ignore
|
||||
last_active_time=kwargs.get("chat_info_last_active_time"), # type: ignore
|
||||
user_info=chat_user_info,
|
||||
stream_id=chat_info_stream_id,
|
||||
platform=chat_info_platform,
|
||||
create_time=chat_info_create_time,
|
||||
last_active_time=chat_info_last_active_time,
|
||||
user_info=DatabaseUserInfo(
|
||||
user_id=chat_info_user_id,
|
||||
user_nickname=chat_info_user_nickname,
|
||||
user_cardname=chat_info_user_cardname,
|
||||
platform=chat_info_user_platform,
|
||||
),
|
||||
group_info=self.group_info,
|
||||
)
|
||||
|
||||
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.message_id, str), "message_id must be a string"
|
||||
# assert isinstance(self.time, float), "time must be a float"
|
||||
# assert isinstance(self.chat_id, str), "chat_id must be a string"
|
||||
# assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None"
|
||||
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
|
||||
# "interest_value must be a float or None"
|
||||
# )
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAndActionModel:
|
||||
time: float = field(default_factory=float)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_platform: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
processed_plain_text: Optional[str] = None
|
||||
display_message: Optional[str] = None
|
||||
chat_info_platform: str = field(default_factory=str)
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试del_memory函数的脚本
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加src目录到Python路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
|
||||
from person_info.person_info import Person
|
||||
|
||||
def test_del_memory():
|
||||
"""测试del_memory函数"""
|
||||
print("开始测试del_memory函数...")
|
||||
|
||||
# 创建一个测试用的Person实例(不连接数据库)
|
||||
person = Person.__new__(Person)
|
||||
person.person_id = "test_person"
|
||||
person.memory_points = [
|
||||
"性格:这个人很友善:5.0",
|
||||
"性格:这个人很友善:4.0",
|
||||
"爱好:喜欢打游戏:3.0",
|
||||
"爱好:喜欢打游戏:2.0",
|
||||
"工作:是一名程序员:1.0",
|
||||
"性格:这个人很友善:6.0"
|
||||
]
|
||||
|
||||
print(f"原始记忆点数量: {len(person.memory_points)}")
|
||||
print("原始记忆点:")
|
||||
for i, memory in enumerate(person.memory_points):
|
||||
print(f" {i+1}. {memory}")
|
||||
|
||||
# 测试删除"性格"分类中"这个人很友善"的记忆
|
||||
print("\n测试1: 删除'性格'分类中'这个人很友善'的记忆")
|
||||
deleted_count = person.del_memory("性格", "这个人很友善")
|
||||
print(f"删除了 {deleted_count} 个记忆点")
|
||||
print("删除后的记忆点:")
|
||||
for i, memory in enumerate(person.memory_points):
|
||||
print(f" {i+1}. {memory}")
|
||||
|
||||
# 测试删除"爱好"分类中"喜欢打游戏"的记忆
|
||||
print("\n测试2: 删除'爱好'分类中'喜欢打游戏'的记忆")
|
||||
deleted_count = person.del_memory("爱好", "喜欢打游戏")
|
||||
print(f"删除了 {deleted_count} 个记忆点")
|
||||
print("删除后的记忆点:")
|
||||
for i, memory in enumerate(person.memory_points):
|
||||
print(f" {i+1}. {memory}")
|
||||
|
||||
# 测试相似度匹配
|
||||
print("\n测试3: 测试相似度匹配")
|
||||
person.memory_points = [
|
||||
"性格:这个人非常友善:5.0",
|
||||
"性格:这个人很友善:4.0",
|
||||
"性格:这个人友善:3.0"
|
||||
]
|
||||
print("原始记忆点:")
|
||||
for i, memory in enumerate(person.memory_points):
|
||||
print(f" {i+1}. {memory}")
|
||||
|
||||
# 删除"这个人很友善"(应该匹配"这个人很友善"和"这个人友善")
|
||||
deleted_count = person.del_memory("性格", "这个人很友善", similarity_threshold=0.8)
|
||||
print(f"删除了 {deleted_count} 个记忆点")
|
||||
print("删除后的记忆点:")
|
||||
for i, memory in enumerate(person.memory_points):
|
||||
print(f" {i+1}. {memory}")
|
||||
|
||||
print("\n测试完成!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_del_memory()
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试修复后的memory_points处理
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加src目录到Python路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
|
||||
from person_info.person_info import Person
|
||||
|
||||
def test_memory_points_with_none():
|
||||
"""测试包含None值的memory_points处理"""
|
||||
print("测试包含None值的memory_points处理...")
|
||||
|
||||
# 创建一个测试Person实例
|
||||
person = Person(person_id="test_user_123")
|
||||
|
||||
# 模拟包含None值的memory_points
|
||||
person.memory_points = [
|
||||
"喜好:喜欢咖啡:1.0",
|
||||
None, # 模拟None值
|
||||
"性格:开朗:1.0",
|
||||
None, # 模拟另一个None值
|
||||
"兴趣:编程:1.0"
|
||||
]
|
||||
|
||||
print(f"原始memory_points: {person.memory_points}")
|
||||
|
||||
# 测试get_all_category方法
|
||||
try:
|
||||
categories = person.get_all_category()
|
||||
print(f"获取到的分类: {categories}")
|
||||
print("✓ get_all_category方法正常工作")
|
||||
except Exception as e:
|
||||
print(f"✗ get_all_category方法出错: {e}")
|
||||
return False
|
||||
|
||||
# 测试get_memory_list_by_category方法
|
||||
try:
|
||||
memories = person.get_memory_list_by_category("喜好")
|
||||
print(f"获取到的喜好记忆: {memories}")
|
||||
print("✓ get_memory_list_by_category方法正常工作")
|
||||
except Exception as e:
|
||||
print(f"✗ get_memory_list_by_category方法出错: {e}")
|
||||
return False
|
||||
|
||||
# 测试del_memory方法
|
||||
try:
|
||||
deleted_count = person.del_memory("喜好", "喜欢咖啡")
|
||||
print(f"删除的记忆点数量: {deleted_count}")
|
||||
print(f"删除后的memory_points: {person.memory_points}")
|
||||
print("✓ del_memory方法正常工作")
|
||||
except Exception as e:
|
||||
print(f"✗ del_memory方法出错: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_memory_points_empty():
|
||||
"""测试空的memory_points处理"""
|
||||
print("\n测试空的memory_points处理...")
|
||||
|
||||
person = Person(person_id="test_user_456")
|
||||
person.memory_points = []
|
||||
|
||||
try:
|
||||
categories = person.get_all_category()
|
||||
print(f"空列表的分类: {categories}")
|
||||
print("✓ 空列表处理正常")
|
||||
except Exception as e:
|
||||
print(f"✗ 空列表处理出错: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
memories = person.get_memory_list_by_category("测试分类")
|
||||
print(f"空列表的记忆: {memories}")
|
||||
print("✓ 空列表分类查询正常")
|
||||
except Exception as e:
|
||||
print(f"✗ 空列表分类查询出错: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_memory_points_all_none():
|
||||
"""测试全部为None的memory_points处理"""
|
||||
print("\n测试全部为None的memory_points处理...")
|
||||
|
||||
person = Person(person_id="test_user_789")
|
||||
person.memory_points = [None, None, None]
|
||||
|
||||
try:
|
||||
categories = person.get_all_category()
|
||||
print(f"全None列表的分类: {categories}")
|
||||
print("✓ 全None列表处理正常")
|
||||
except Exception as e:
|
||||
print(f"✗ 全None列表处理出错: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
memories = person.get_memory_list_by_category("测试分类")
|
||||
print(f"全None列表的记忆: {memories}")
|
||||
print("✓ 全None列表分类查询正常")
|
||||
except Exception as e:
|
||||
print(f"✗ 全None列表分类查询出错: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始测试修复后的memory_points处理...")
|
||||
|
||||
success = True
|
||||
success &= test_memory_points_with_none()
|
||||
success &= test_memory_points_empty()
|
||||
success &= test_memory_points_all_none()
|
||||
|
||||
if success:
|
||||
print("\n🎉 所有测试通过!memory_points的None值处理已修复。")
|
||||
else:
|
||||
print("\n❌ 部分测试失败,需要进一步检查。")
|
||||
Loading…
Reference in New Issue