pull/1207/head
Windpicker-owo 2025-08-21 14:31:52 +08:00
commit d93091ff1d
38 changed files with 990 additions and 1180 deletions

View File

@ -25,8 +25,8 @@
**🍔MaiCore 是一个基于大语言模型的可交互智能体**
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理
- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制
- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制
- 🔌 **强大插件系统**:全面重构的插件架构,更多API
- 🤔 **实时思维系统**:模拟人类思考过程。
- 🧠 **表达学习功能**:学习群友的说话风格和表达方式
- 💝 **情感表达系统**:情绪系统和表情包系统。
@ -46,7 +46,7 @@
## 🔥 更新和安装
**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md))
**最新版本: v0.10.0** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
@ -56,7 +56,6 @@
- `classical`: 旧版本(停止维护)
### 最新版本部署教程
- [从0.6/0.7升级须知](https://docs.mai-mai.org/faq/maibot/update_to_07.html)
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)
> [!WARNING]
@ -64,7 +63,6 @@
> - 项目处于活跃开发阶段,功能和 API 可能随时调整。
> - 文档未完善,有问题可以提交 Issue 或者 Discussion。
> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。
> - 由于持续迭代,可能存在一些已知或未知的 bug。
> - 由于程序处于开发中,可能消耗较多 token。
## 💬 讨论

View File

@ -1,16 +1,74 @@
# Changelog
## [0.10.0] - 2025-7-1
### 主要功能更改
### 🌟 主要功能更改
- 优化的回复生成,现在的回复对上下文把控更加精准
- 新的回复逻辑控制现在合并了normal和focus模式更加统一
- 优化表达方式系统,现在学习和使用更加精准
- 新的关系系统,现在的关系构建更精准也更克制
- 工具系统重构,现在合并到了插件系统中
- 彻底重构了整个LLM Request,现在支持模型轮询和更多灵活的参数
- 同时重构了整个模型配置系统升级需要重新配置llm配置文件
- 随着LLM Request的重构插件系统彻底重构完成。插件系统进入稳定状态仅增加新的API
- 具体相比于之前的更改可以查看[changes.md](./changes.md)
- **警告所有插件开发者:插件系统即将迎来不稳定时期,随时可能发生更改。**
#### 🔧 工具系统重构
- **工具系统整合**: 工具系统现在完全合并到插件系统中,提供统一的扩展能力
- **工具启用控制**: 支持配置是否启用特定工具,提供更人性化的直接调用方式
- **配置文件读取**: 工具现在支持读取配置文件,增强配置灵活性
#### 🚀 LLM系统全面重构
- **LLM Request重构**: 彻底重构了整个LLM Request系统现在支持模型轮询和更多灵活的参数
- **模型配置升级**: 同时重构了整个模型配置系统升级需要重新配置llm配置文件
- **任务类型支持**: 新增任务类型和能力字段至模型配置,增强模型初始化逻辑
- **异常处理增强**: 增强LLMRequest类的异常处理添加统一的模型异常处理方法
#### 🔌 插件系统稳定化
- **插件系统重构完成**: 随着LLM Request的重构插件系统彻底重构完成进入稳定状态
- **API扩展**: 仅增加新的API保持向后兼容性
- **插件管理优化**: 让插件管理配置真正有用,提升管理体验
#### 💾 记忆系统优化
- **及时构建**: 记忆系统再优化,现在及时构建,并且不会重复构建
- **精确提取**: 记忆提取更精确,提升记忆质量
#### 🎭 表达方式系统
- **表达方式记录**: 记录使用的表达方式,提供更好的学习追踪
- **学习优化**: 优化表达方式提取,修复表达学习出错问题
- **配置优化**: 优化表达方式配置和逻辑,提升系统稳定性
#### 🔄 聊天系统统一
- **normal和focus合并**: 彻底合并normal和focus完全基于planner决定target message
- **no_reply内置**: 将no_reply功能移动到主循环中简化系统架构
- **回复优化**: 优化reply填补缺失值让麦麦可以回复自己的消息
- **频率控制API**: 加入聊天频率控制相关API提供更精细的控制
#### 日志系统改进
- **日志颜色优化**: 修改了log的颜色更加护眼
- **日志清理优化**: 修复了日志清理先等24h的问题提升系统性能
- **计时定位**: 通过计时定位LLM异常延时提升问题排查效率
### 🐛 问题修复
#### 代码质量提升
- **lint问题修复**: 修复了lint爆炸的问题代码更加规范了
- **导入优化**: 修复导入爆炸和文档错误,优化代码结构
#### 系统稳定性
- **循环导入**: 修复了import时循环导入的问题
- **并行动作**: 修复并行动作炸裂问题,提升并发处理能力
- **空响应处理**: 空响应就raise避免系统异常
#### 功能修复
- **API问题**: 修复api问题提升系统可用性
- **notice问题**: 为组件方法提供新参数暂时解决notice问题
- **关系构建**: 修复不认识的用户构建关系问题
- **流式解析**: 修复流式解析越界问题避免空choices的SSE帧错误
#### 配置和兼容性
- **默认值**: 添加默认值,提升配置灵活性
- **类型问题**: 修复类型问题,提升代码健壮性
- **配置加载**: 优化配置加载逻辑,提升系统启动稳定性
### 细节优化
- 修复了lint爆炸的问题代码更加规范了
- 修改了log的颜色更加护眼
## [0.9.1] - 2025-7-26

View File

@ -84,6 +84,7 @@ services:
# - ./data/MaiMBot:/data/MaiMBot
# networks:
# - maim_bot
volumes:
site-packages:
networks:

View File

@ -47,3 +47,4 @@ reportportal-client
scikit-learn
seaborn
structlog
google-genai

View File

@ -6,6 +6,7 @@
import sys
import os
import asyncio
from time import sleep
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
return True
def main(): # sourcery skip: dict-comprehension
async def main_async(): # sourcery skip: dict-comprehension
# 新增确认提示
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@ -239,6 +240,29 @@ def main(): # sourcery skip: dict-comprehension
return None
def main():
    """Entry point: create a fresh event loop and run the async main function.

    The previous implementation first called ``asyncio.get_running_loop()``
    and tried to reuse that loop.  That branch was both unreachable and
    broken: ``get_running_loop()`` only returns a loop from inside a running
    coroutine/callback (such a loop can never be in the ``closed`` state),
    and calling ``run_until_complete()`` on an already-running loop raises
    ``RuntimeError``.  Always creating a new loop is the correct behavior
    for a synchronous script entry point.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Run the async main function on the freshly created loop.
        loop.run_until_complete(main_async())
    finally:
        # Make sure the loop is closed even if main_async() raises.
        if not loop.is_closed():
            loop.close()
if __name__ == "__main__":
# logger.info(f"111111111111111111111111{ROOT_PATH}")
main()

View File

@ -1,6 +1,7 @@
import asyncio
import time
import traceback
import math
import random
from typing import List, Optional, Dict, Any, Tuple
from rich.traceback import install
@ -8,6 +9,7 @@ from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
@ -15,21 +17,19 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.chat_loop.hfc_utils import CycleDetail
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
import math
from src.mais4u.s4u_config import s4u_config
# no_action逻辑已集成到heartFC_chat.py中不再需要导入
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
# 导入记忆系统
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control
ERROR_LOOP_INFO = {
"loop_plan_info": {
@ -62,10 +62,7 @@ class HeartFChatting:
其生命周期现在由其关联的 SubHeartflow FOCUSED 状态控制
"""
def __init__(
self,
chat_id: str,
):
def __init__(self, chat_id: str):
"""
HeartFChatting 初始化函数
@ -83,7 +80,7 @@ class HeartFChatting:
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.talk_frequency_control = talk_frequency_control.get_talk_frequency_control(self.stream_id)
self.focus_value_control = focus_value_control.get_focus_value_control(self.stream_id)
@ -104,7 +101,7 @@ class HeartFChatting:
self.plan_timeout_count = 0
self.last_read_time = time.time() - 1
self.focus_energy = 1
self.no_action_consecutive = 0
# 最近三次no_action的新消息兴趣度记录
@ -166,27 +163,26 @@ class HeartFChatting:
# 获取动作类型,兼容新旧格式
action_type = "未知动作"
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
loop_plan_info = self._current_cycle_detail.loop_plan_info
if isinstance(loop_plan_info, dict):
action_result = loop_plan_info.get('action_result', {})
action_result = loop_plan_info.get("action_result", {})
if isinstance(action_result, dict):
# 旧格式action_result是字典
action_type = action_result.get('action_type', '未知动作')
action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
action_type = action_result[0].get('action_type', '未知动作')
action_type = action_result[0].get("action_type", "未知动作")
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get('action_type', '未知动作')
action_type = loop_plan_info[0].get("action_type", "未知动作")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
f"选择动作: {action_type}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
def _determine_form_type(self) -> None:
"""判断使用哪种形式的no_action"""
# 如果连续no_action次数少于3次使用waiting形式
@ -195,42 +191,44 @@ class HeartFChatting:
else:
# 计算最近三次记录的兴趣度总和
total_recent_interest = sum(self.recent_interest_records)
# 计算调整后的阈值
adjusted_threshold = 1 / self.talk_frequency_control.get_current_talk_frequency()
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
logger.info(
f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
)
# 如果兴趣度总和小于阈值进入breaking形式
if total_recent_interest < adjusted_threshold:
logger.info(f"{self.log_prefix} 兴趣度不足,进入休息")
self.focus_energy = random.randint(3, 6)
else:
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
self.focus_energy = 1
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
self.focus_energy = 1
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
"""
判断是否应该处理消息
Args:
new_message: 新消息列表
mode: 当前聊天模式
Returns:
bool: 是否应该处理消息
"""
new_message_count = len(new_message)
talk_frequency = self.talk_frequency_control.get_current_talk_frequency()
modified_exit_count_threshold = self.focus_energy * 0.5 / talk_frequency
modified_exit_interest_threshold = 1.5 / talk_frequency
total_interest = 0.0
for msg_dict in new_message:
interest_value = msg_dict.get("interest_value")
if interest_value is not None and msg_dict.get("processed_plain_text", ""):
for msg in new_message:
interest_value = msg.interest_value
if interest_value is not None and msg.processed_plain_text:
total_interest += float(interest_value)
if new_message_count >= modified_exit_count_threshold:
self.recent_interest_records.append(total_interest)
logger.info(
@ -244,9 +242,11 @@ class HeartFChatting:
if new_message_count > 0:
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}")
logger.info(
f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}"
)
self._last_accumulated_interest = total_interest
if total_interest >= modified_exit_interest_threshold:
# 记录兴趣度到列表
self.recent_interest_records.append(total_interest)
@ -261,29 +261,25 @@ class HeartFChatting:
f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累计兴趣{total_interest:.1f},继续等待..."
)
await asyncio.sleep(0.5)
return False,0.0
return False, 0.0
async def _loopbody(self):
recent_messages_dict = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
limit = 10,
limit=10,
limit_mode="latest",
filter_mai=True,
filter_command=True,
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict]
)
# 统一的消息处理逻辑
should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
should_process, interest_value = await self._should_process_messages(recent_messages_dict)
if should_process:
self.last_read_time = time.time()
await self._observe(interest_value = interest_value)
await self._observe(interest_value=interest_value)
else:
# Normal模式消息数量不足等待
@ -298,22 +294,21 @@ class HeartFChatting:
cycle_timers: Dict[str, float],
thinking_id,
actions,
selected_expressions:List[int] = None,
selected_expressions: List[int] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.get("chat_info_platform")
if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform = platform ,user_id = action_message.get("user_id", ""))
person = Person(platform=platform, user_id=action_message.get("user_id", ""))
person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@ -342,12 +337,10 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self,interest_value:float = 0.0) -> bool:
async def _observe(self, interest_value: float = 0.0) -> bool:
action_type = "no_action"
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
# 当interest_value为0时概率接近0使用Focus模式
# 当interest_value很高时概率接近1使用Normal模式
@ -361,12 +354,14 @@ class HeartFChatting:
x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / self.talk_frequency_control.get_current_talk_frequency()
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency()
# 根据概率决定使用哪种模式
if random.random() < normal_mode_probability:
mode = ChatMode.NORMAL
logger.info(f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability*100:.0f}%概率下选择回复")
logger.info(
f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability * 100:.0f}%概率下选择回复"
)
else:
mode = ChatMode.FOCUS
@ -387,10 +382,9 @@ class HeartFChatting:
await hippocampus_manager.build_memory_for_chat(self.stream_id)
except Exception as e:
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
#如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
# 如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
actions = [
{
"action_type": "no_action",
@ -420,23 +414,21 @@ class HeartFChatting:
):
return False
with Timer("规划器", cycle_timers):
actions, _= await self.action_planner.plan(
actions, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
# 3. 并行执行所有动作
async def execute_action(action_info,actions):
async def execute_action(action_info, actions):
"""执行单个动作的通用函数"""
try:
if action_info["action_type"] == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_info.get("reasoning", "选择不回复")
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
@ -447,13 +439,8 @@ class HeartFChatting:
action_data={"reason": reason},
action_name="no_action",
)
return {
"action_type": "no_action",
"success": True,
"reply_text": "",
"command": ""
}
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_info["action_type"] != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
@ -463,20 +450,19 @@ class HeartFChatting:
action_info["action_data"],
cycle_timers,
thinking_id,
action_info["action_message"]
action_info["action_message"],
)
return {
"action_type": action_info["action_type"],
"success": success,
"reply_text": reply_text,
"command": command
"command": command,
}
else:
try:
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message = action_info["action_message"],
reply_message=action_info["action_message"],
available_actions=available_actions,
choosen_actions=actions,
reply_reason=action_info.get("reasoning", ""),
@ -485,29 +471,21 @@ class HeartFChatting:
from_plugin=False,
return_expressions=True,
)
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
_,selected_expressions = prompt_selected_expressions
_, selected_expressions = prompt_selected_expressions
else:
selected_expressions = []
if not success or not response_set:
logger.info(f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
logger.info(
f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
)
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set=response_set,
@ -521,7 +499,7 @@ class HeartFChatting:
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
@ -531,26 +509,26 @@ class HeartFChatting:
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e)
"error": str(e),
}
action_tasks = [asyncio.create_task(execute_action(action,actions)) for action in actions]
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
action_command = ""
for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
_cur_action = actions[i]
if result["action_type"] != "reply":
action_success = result["success"]
@ -590,7 +568,6 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
if s4u_config.enable_s4u:
await stop_typing()
@ -602,7 +579,7 @@ class HeartFChatting:
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
action_type = actions[0]["action_type"] if actions else "no_action"
# 管理no_action计数器当执行了非no_action动作时重置计数器
if action_type != "no_action":
# no_action逻辑已集成到heartFC_chat.py中直接重置计数器
@ -610,7 +587,7 @@ class HeartFChatting:
self.no_action_consecutive = 0
logger.debug(f"{self.log_prefix} 执行了{action_type}动作重置no_action计数器")
return True
if action_type == "no_action":
self.no_action_consecutive += 1
self._determine_form_type()
@ -692,11 +669,12 @@ class HeartFChatting:
traceback.print_exc()
return False, "", ""
async def _send_response(self,
reply_set,
message_data,
selected_expressions:List[int] = None,
) -> str:
async def _send_response(
self,
reply_set,
message_data,
selected_expressions: List[int] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
@ -714,7 +692,7 @@ class HeartFChatting:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message = message_data,
reply_message=message_data,
set_reply=need_reply,
typing=False,
selected_expressions=selected_expressions,
@ -724,7 +702,7 @@ class HeartFChatting:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message = message_data,
reply_message=message_data,
set_reply=False,
typing=True,
selected_expressions=selected_expressions,

View File

@ -709,36 +709,36 @@ class EmojiManager:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
"""根据哈希值获取已注册表情包的情感标签列表
Args:
emoji_hash: 表情包的哈希值
Returns:
Optional[str]: 表情包描述如果未找到则返回None
Optional[List[str]]: 情感标签列表如果未找到则返回None
"""
try:
# 先从内存中查找
emoji = await self.get_emoji_from_manager(emoji_hash)
if emoji and emoji.emotion:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...")
return ",".join(emoji.emotion)
logger.info(f"[缓存命中] 从内存获取表情包情感标签: {emoji.emotion}...")
return emoji.emotion
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
return emoji_record.emotion
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(',')
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
return None
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
logger.error(f"获取表情包情感标签失败 (Hash: {emoji_hash}): {str(e)}")
return None
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:

View File

@ -8,6 +8,7 @@ from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@ -346,21 +347,17 @@ class ExpressionLearner:
current_time = time.time()
# 获取上次学习时间
random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0]["chat_id"]
chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
# print(f"random_msg_str:{random_msg_str}")

View File

@ -117,30 +117,36 @@ class EmbeddingStore:
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
"""获取字符串的嵌入向量,处理异步调用"""
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 尝试获取当前事件循环
asyncio.get_running_loop()
# 如果在事件循环中,使用线程池执行
import concurrent.futures
def run_in_thread():
return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
result = future.result()
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
except RuntimeError:
# 没有运行的事件循环,直接运行
result = asyncio.run(get_embedding(s))
if result is None:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
if embedding and len(embedding) > 0:
return embedding
else:
logger.error(f"获取嵌入失败: {s}")
return []
return result
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
finally:
# 确保事件循环被正确关闭
try:
loop.close()
except Exception:
pass
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
@ -181,8 +187,14 @@ class EmbeddingStore:
for i, s in enumerate(chunk_strs):
try:
# 直接使用异步函数
embedding = asyncio.run(llm.get_embedding(s))
# 在线程中创建独立的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding = loop.run_until_complete(llm.get_embedding(s))
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:

View File

@ -9,7 +9,6 @@ import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict
from collections import Counter
from itertools import combinations
import traceback
from rich.traceback import install
@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages
# 添加cosine_similarity函数
def cosine_similarity(v1, v2):
"""计算余弦相似度"""
@ -51,18 +52,9 @@ def calculate_information_content(text):
return entropy
logger = get_logger("memory")
class MemoryGraph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@ -96,7 +88,7 @@ class MemoryGraph:
if "memory_items" in self.G.nodes[concept]:
# 获取现有的记忆项已经是str格式
existing_memory = self.G.nodes[concept]["memory_items"]
# 如果现有记忆不为空则使用LLM整合新旧记忆
if existing_memory and hippocampus_instance and hippocampus_instance.model_small:
try:
@ -170,16 +162,16 @@ class MemoryGraph:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str:
"""
使用LLM整合新旧记忆内容
Args:
existing_memory: 现有的记忆内容字符串格式可能包含多条记忆
new_memory: 新的记忆内容
llm_model: LLM模型实例
Returns:
str: 整合后的记忆内容
"""
@ -203,8 +195,10 @@ class MemoryGraph:
整合后的记忆"""
# 调用LLM进行整合
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt)
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
integration_prompt
)
if content and content.strip():
integrated_content = content.strip()
logger.debug(f"LLM记忆整合成功模型: {model_name}")
@ -212,7 +206,7 @@ class MemoryGraph:
else:
logger.warning("LLM返回的整合结果为空使用默认连接方式")
return f"{existing_memory} | {new_memory}"
except Exception as e:
logger.error(f"LLM记忆整合过程中出错: {e}")
return f"{existing_memory} | {new_memory}"
@ -238,7 +232,11 @@ class MemoryGraph:
if memory_items:
# 删除整个节点
self.G.remove_node(topic)
return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}"
return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
)
else:
# 如果没有记忆项,删除该节点
self.G.remove_node(topic)
@ -263,38 +261,40 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
self.model_small = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
)
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes())
def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float:
    """Compute an activation value that accounts for the target node's weight.

    Args:
        current_activation: Activation carried into this hop.
        edge_strength: Strength of the connecting edge; attenuates the
            activation by ``1 / edge_strength``.
        target_node: Name of the node being activated.

    Returns:
        float: The attenuated activation. When the node exists in the
        memory graph, it is boosted by the node's weight — each merge
        adds 10% activation, capped at a +200% bonus. Returns 0.0 once
        the attenuated value drops to zero or below.
    """
    # Base attenuation: stronger edges lose less activation per hop.
    attenuated = current_activation - (1 / edge_strength)
    if attenuated <= 0:
        return 0.0

    # Unknown nodes get no weight bonus.
    if target_node not in self.memory_graph.G:
        return attenuated

    node_weight = self.memory_graph.G.nodes[target_node].get("weight", 1.0)
    # +10% per unit of weight above 1.0, capped at +200%.
    bonus = min((node_weight - 1.0) * 0.1, 2.0)
    return attenuated * (1.0 + bonus)
@ -332,9 +332,7 @@ class Hippocampus:
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
f"如果确定找不出主题或者没有明显主题,返回<none>。"
)
return prompt
@staticmethod
@ -403,7 +401,7 @@ class Hippocampus:
memories.sort(key=lambda x: x[2], reverse=True)
return memories
async def get_keywords_from_text(self, text: str) -> list:
async def get_keywords_from_text(self, text: str) -> Tuple[List[str], List]:
"""从文本中提取关键词。
Args:
@ -418,16 +416,13 @@ class Hippocampus:
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text)
topic_num: int | list[int] = 0
words = jieba.cut(text)
keywords_lite = [word for word in words if len(word) > 1]
keywords_lite = list(set(keywords_lite))
if keywords_lite:
logger.debug(f"提取关键词极简版: {keywords_lite}")
if text_length <= 12:
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
elif text_length <= 20:
@ -455,7 +450,7 @@ class Hippocampus:
if keywords:
logger.debug(f"提取关键词: {keywords}")
return keywords,keywords_lite
return keywords, keywords_lite
async def get_memory_from_topic(
self,
@ -570,20 +565,17 @@ class Hippocampus:
for node, activation in remember_map.items():
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
if memory_items := node_data.get("memory_items", ""):
logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(keywords)
all_words = memory_words | text_words
if all_words:
if all_words := memory_words | text_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2) # 计算但不使用用_表示
# 添加完整记忆到结果中
all_memories.append((node, memory_items, activation))
else:
@ -613,7 +605,9 @@ class Hippocampus:
return result
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str], list[str]]:
"""从文本中提取关键词并获取相关记忆。
Args:
@ -627,13 +621,13 @@ class Hippocampus:
float: 激活节点数与总节点数的比值
list[str]: 有效的关键词
"""
keywords,keywords_lite = await self.get_keywords_from_text(text)
keywords, keywords_lite = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords:
# logger.info("没有找到有效的关键词节点")
return 0, keywords,keywords_lite
return 0, keywords, keywords_lite
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@ -700,7 +694,7 @@ class Hippocampus:
activation_ratio = activation_ratio * 50
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
return activation_ratio, keywords,keywords_lite
return activation_ratio, keywords, keywords_lite
# 负责海马体与其他部分的交互
@ -730,7 +724,7 @@ class EntorhinalCortex:
continue
memory_items = data.get("memory_items", "")
# 直接检查字符串是否为空,不需要分割成列表
if not memory_items or memory_items.strip() == "":
self.memory_graph.G.remove_node(concept)
@ -865,7 +859,9 @@ class EntorhinalCortex:
end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边")
logger.info(
f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
)
async def resync_memory_to_db(self):
"""清空数据库并重新同步所有记忆数据"""
@ -888,7 +884,7 @@ class EntorhinalCortex:
nodes_data = []
for concept, data in memory_nodes:
memory_items = data.get("memory_items", "")
# 直接检查字符串是否为空,不需要分割成列表
if not memory_items or memory_items.strip() == "":
self.memory_graph.G.remove_node(concept)
@ -960,7 +956,7 @@ class EntorhinalCortex:
# 清空当前图
self.memory_graph.G.clear()
# 统计加载情况
total_nodes = 0
loaded_nodes = 0
@ -969,7 +965,7 @@ class EntorhinalCortex:
# 从数据库加载所有节点
nodes = list(GraphNodes.select())
total_nodes = len(nodes)
for node in nodes:
concept = node.concept
try:
@ -978,7 +974,7 @@ class EntorhinalCortex:
logger.warning(f"节点 {concept} 的memory_items为空跳过")
skipped_nodes += 1
continue
# 直接使用memory_items
memory_items = node.memory_items.strip()
@ -999,11 +995,15 @@ class EntorhinalCortex:
last_modified = node.last_modified or current_time
# 获取权重属性
weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 添加节点到图中
self.memory_graph.G.add_node(
concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified
concept,
memory_items=memory_items,
weight=weight,
created_time=created_time,
last_modified=last_modified,
)
loaded_nodes += 1
except Exception as e:
@ -1044,9 +1044,11 @@ class EntorhinalCortex:
if need_update:
logger.info("[数据库] 已为缺失的时间字段进行补充")
# 输出加载统计信息
logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}")
logger.info(
f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}"
)
# 负责整合,遗忘,合并记忆
@ -1054,10 +1056,12 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")
async def memory_compress(self, messages: list, compress_rate=0.1):
self.memory_modify_model = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="memory.modify"
)
async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
Args:
@ -1083,7 +1087,6 @@ class ParahippocampalGyrus:
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
replace_bot_name=False, # 保留原始用户名
)
@ -1163,7 +1166,7 @@ class ParahippocampalGyrus:
similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:3]
similar_topics_dict[topic] = similar_topics
if global_config.debug.show_prompt:
logger.info(f"prompt: {topic_what_prompt}")
logger.info(f"压缩后的记忆: {compressed_memory}")
@ -1259,14 +1262,14 @@ class ParahippocampalGyrus:
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
@ -1315,8 +1318,6 @@ class ParahippocampalGyrus:
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
class HippocampusManager:
def __init__(self):
self._hippocampus: Hippocampus = None # type: ignore
@ -1361,29 +1362,32 @@ class HippocampusManager:
"""为指定chat_id构建记忆在heartFC_chat.py中调用"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
# 检查是否需要构建记忆
logger.info(f"{chat_id} 构建记忆")
if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
logger.info(f"{chat_id} 构建记忆,需要构建记忆")
messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 50)
build_probability = 0.3 * global_config.memory.memory_build_frequency
if messages and random.random() < build_probability:
logger.info(f"{chat_id} 构建记忆,消息数量: {len(messages)}")
# 调用记忆压缩和构建
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
(
compressed_memory,
similar_topics_dict,
) = await self._hippocampus.parahippocampal_gyrus.memory_compress(
messages, global_config.memory.memory_compress_rate
)
# 添加记忆节点
current_time = time.time()
for topic, memory in compressed_memory:
await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)
# 连接相似主题
if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic]
@ -1391,23 +1395,23 @@ class HippocampusManager:
if topic != similar_topic:
strength = int(similarity * 10)
self._hippocampus.memory_graph.G.add_edge(
topic, similar_topic,
topic,
similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time
last_modified=current_time,
)
# 同步到数据库
await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
logger.info(f"{chat_id} 构建记忆完成")
return True
except Exception as e:
logger.error(f"{chat_id} 构建记忆失败: {e}")
return False
return False
return False
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
@ -1424,16 +1428,20 @@ class HippocampusManager:
response = []
return response
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str], list[str]]:
"""从文本中获取激活值的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
text, max_depth, fast_retrieval
)
except Exception as e:
logger.error(f"文本产生激活值失败: {e}")
logger.error(traceback.format_exc())
return 0.0, [],[]
return 0.0, [], []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆的公共接口"""
@ -1455,81 +1463,78 @@ hippocampus_manager = HippocampusManager()
# 在Hippocampus类中添加新的记忆构建管理器
class MemoryBuilder:
"""记忆构建器
为每个chat_id维护消息缓存和触发机制类似ExpressionLearner
"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.last_update_time: float = time.time()
self.last_processed_time: float = 0.0
def should_trigger_memory_build(self) -> bool:
"""检查是否应该触发记忆构建"""
current_time = time.time()
# 检查时间间隔
time_diff = current_time - self.last_update_time
if time_diff < 600 /global_config.memory.memory_build_frequency:
if time_diff < 600 / global_config.memory.memory_build_frequency:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_update_time,
timestamp_end=current_time,
)
logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
if not recent_messages or len(recent_messages) < 30/global_config.memory.memory_build_frequency :
if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
return False
return True
def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]:
"""获取用于记忆构建的消息"""
current_time = time.time()
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_update_time,
timestamp_end=current_time,
limit=threshold,
)
tmp_msg = [msg.__dict__ for msg in messages] if messages else []
if messages:
# 更新最后处理时间
self.last_processed_time = current_time
self.last_update_time = current_time
return tmp_msg or []
return messages or []
class MemorySegmentManager:
"""记忆段管理器
管理所有chat_id的MemoryBuilder实例自动检查和触发记忆构建
"""
def __init__(self):
self.builders: Dict[str, MemoryBuilder] = {}
def get_or_create_builder(self, chat_id: str) -> MemoryBuilder:
"""获取或创建指定chat_id的MemoryBuilder"""
if chat_id not in self.builders:
self.builders[chat_id] = MemoryBuilder(chat_id)
return self.builders[chat_id]
def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
"""检查指定chat_id是否需要构建记忆如果需要则返回True"""
builder = self.get_or_create_builder(chat_id)
return builder.should_trigger_memory_build()
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]:
"""获取指定chat_id用于记忆构建的消息"""
if chat_id not in self.builders:
return []
@ -1538,4 +1543,3 @@ class MemorySegmentManager:
# 创建全局实例
memory_segment_manager = MemorySegmentManager()

View File

@ -1,17 +1,17 @@
import json
import random
from json_repair import repair_json
from typing import List, Tuple
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
import random
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
@ -75,19 +75,20 @@ class MemoryActivator:
request_type="memory.selection",
)
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
async def activate_memory_with_chat_history(
self, target_message, chat_history: List[DatabaseMessages]
) -> List[Tuple[str, str]]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
keywords_list = set()
for msg in chat_history_prompt:
keywords = parse_keywords_string(msg.get("key_words", ""))
for msg in chat_history:
keywords = parse_keywords_string(msg.key_words)
if keywords:
if len(keywords_list) < 30:
# 最多容纳30个关键词
@ -95,24 +96,22 @@ class MemoryActivator:
logger.debug(f"提取关键词: {keywords_list}")
else:
break
if not keywords_list:
logger.debug("没有提取到关键词,返回空记忆列表")
return []
# 从海马体获取相关记忆
related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
)
# logger.info(f"当前记忆关键词: {keywords_list}")
logger.debug(f"获取到的记忆: {related_memory}")
if not related_memory:
logger.debug("海马体没有返回相关记忆")
return []
used_ids = set()
candidate_memories = []
@ -120,12 +119,7 @@ class MemoryActivator:
# 为每个记忆分配随机ID并过滤相关记忆
for memory in related_memory:
keyword, content = memory
found = False
for kw in keywords_list:
if kw in content:
found = True
break
found = any(kw in content for kw in keywords_list)
if found:
# 随机分配一个不重复的2位数id
while True:
@ -138,95 +132,83 @@ class MemoryActivator:
if not candidate_memories:
logger.info("没有找到相关的候选记忆")
return []
# 如果只有少量记忆,直接返回
if len(candidate_memories) <= 2:
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
# 使用 LLM 选择合适的记忆
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
return selected_memories
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
async def _select_memories_with_llm(
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
) -> List[Tuple[str, str]]:
"""
使用 LLM 选择合适的记忆
Args:
target_message: 目标消息
chat_history_prompt: 聊天历史
candidate_memories: 候选记忆列表每个记忆包含 memory_idkeywordcontent
Returns:
List[Tuple[str, str]]: 选择的记忆列表格式为 (keyword, content)
"""
try:
# 构建聊天历史字符串
obs_info_text = build_readable_messages(
chat_history_prompt,
chat_history,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# 构建记忆信息字符串
memory_lines = []
for memory in candidate_memories:
memory_id = memory["memory_id"]
keyword = memory["keyword"]
content = memory["content"]
# 将 content 列表转换为字符串
if isinstance(content, list):
content_str = " | ".join(str(item) for item in content)
else:
content_str = str(content)
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
memory_info = "\n".join(memory_lines)
# 获取并格式化 prompt
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
formatted_prompt = prompt_template.format(
obs_info_text=obs_info_text,
target_message=target_message,
memory_info=memory_info
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
)
# 调用 LLM
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
formatted_prompt,
temperature=0.3,
max_tokens=150
formatted_prompt, temperature=0.3, max_tokens=150
)
if global_config.debug.show_prompt:
logger.info(f"记忆选择 prompt: {formatted_prompt}")
logger.info(f"LLM 记忆选择响应: {response}")
else:
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
logger.debug(f"LLM 记忆选择响应: {response}")
# 解析响应获取选择的记忆编号
try:
fixed_json = repair_json(response)
# 解析为 Python 对象
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
# 提取 memory_ids 字段
memory_ids_str = result.get("memory_ids", "")
# 解析逗号分隔的编号
if memory_ids_str:
# 提取 memory_ids 字段并解析逗号分隔的编号
if memory_ids_str := result.get("memory_ids", ""):
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
# 过滤掉空字符串和无效编号
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
@ -236,26 +218,24 @@ class MemoryActivator:
except Exception as e:
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
selected_memory_ids = []
# 根据编号筛选记忆
selected_memories = []
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
for memory_id in selected_memory_ids:
if memory_id in memory_id_to_memory:
selected_memories.append(memory_id_to_memory[memory_id])
selected_memories = [
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
]
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
except Exception as e:
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
# 出错时返回前3个候选记忆作为备选转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
init_prompt()

View File

@ -70,13 +70,10 @@ class ActionModifier:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
chat_content = build_readable_messages(
temp_msg_list_before_now_half,
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,

View File

@ -33,6 +33,9 @@ def init_prompt():
{time_block}
{identity_block}
你现在需要根据聊天内容选择的合适的action来参与聊天
请你根据以下行事风格来决定action:
{plan_style}
{chat_context_description}以下是具体的聊天内容
{chat_content_block}
@ -280,11 +283,8 @@ class ActionPlanner:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=temp_msg_list_before_now,
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.last_obs_time_mark,
truncate=True,
@ -388,6 +388,7 @@ class ActionPlanner:
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
identity_block=identity_block,
plan_style = global_config.personality.plan_style
)
return prompt, message_id_list
except Exception as e:

View File

@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
@ -156,7 +157,7 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
enable_tool: bool = True,
from_plugin: bool = True,
stream_id: Optional[str] = None,
@ -171,7 +172,7 @@ class DefaultReplyer:
extra_info: 额外信息用于补充上下文
reply_reason: 回复原因
available_actions: 可用的动作信息字典
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_tool: 是否启用工具调用
from_plugin: 是否来自插件
@ -189,7 +190,7 @@ class DefaultReplyer:
prompt, selected_expressions = await self.build_prompt_reply_context(
extra_info=extra_info,
available_actions=available_actions,
choosen_actions=choosen_actions,
chosen_actions=chosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
@ -296,7 +297,7 @@ class DefaultReplyer:
if not sender:
return ""
if sender == global_config.bot.nickname:
return ""
@ -352,7 +353,7 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str:
async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
"""构建记忆块
Args:
@ -369,7 +370,7 @@ class DefaultReplyer:
instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history
target_message=target, chat_history=chat_history
)
if global_config.memory.enable_instant_memory:
@ -433,7 +434,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]:
"""解析回复目标消息
Args:
@ -514,7 +515,7 @@ class DefaultReplyer:
return name, result, duration
def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
@ -530,16 +531,16 @@ class DefaultReplyer:
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now:
for msg in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
msg_user_id = str(msg.user_info.user_id)
reply_to = msg.reply_to
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话
core_dialogue_list.append(msg_dict)
core_dialogue_list.append(msg)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
@ -574,7 +575,6 @@ class DefaultReplyer:
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
@ -640,7 +640,7 @@ class DefaultReplyer:
action_descriptions = ""
if available_actions:
action_descriptions = "你可以做以下这些动作:\n"
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,\n"
for action_name, action_info in available_actions.items():
action_description = action_info.description
action_descriptions += f"- {action_name}: {action_description}\n"
@ -658,7 +658,7 @@ class DefaultReplyer:
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
if choosen_action_descriptions:
action_descriptions += "根据聊天情况,决定在回复的同时做以下这些动作:\n"
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
action_descriptions += choosen_action_descriptions
return action_descriptions
@ -668,7 +668,7 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[int]]:
@ -679,7 +679,7 @@ class DefaultReplyer:
extra_info: 额外信息用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_timeout: 是否启用超时处理
enable_tool: 是否启用工具调用
reply_message: 回复的原始消息
@ -712,27 +712,21 @@ class DefaultReplyer:
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 1,
)
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
# TODO: 修复!
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
chat_talking_prompt_short = build_readable_messages(
temp_msg_list_before_short,
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
@ -744,12 +738,12 @@ class DefaultReplyer:
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
)
# 任务名称中英文映射
@ -810,25 +804,9 @@ class DefaultReplyer:
else:
reply_target_block = ""
# if is_group_chat:
# chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
# else:
# chat_target_name = "对方"
# if self.chat_target_info:
# chat_target_name = (
# self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
# )
# chat_target_1 = await global_prompt_manager.format_prompt(
# "chat_target_private1", sender_name=chat_target_name
# )
# chat_target_2 = await global_prompt_manager.format_prompt(
# "chat_target_private2", sender_name=chat_target_name
# )
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
temp_msg_list_before_long, user_id, sender
message_list_before_now_long, user_id, sender
)
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
@ -879,7 +857,7 @@ class DefaultReplyer:
reason: str,
reply_to: str,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[int]]: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
@ -902,20 +880,16 @@ class DefaultReplyer:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
chat_talking_prompt_half = build_readable_messages(
temp_msg_list_before_now_half,
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# 并行执行2个构建任务
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather(
(expression_habits_block, _), relation_info = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(sender, target),
)

View File

@ -1,4 +1,4 @@
import time # 导入 time 模块以获取当前时间
import time
import random
import re
@ -6,14 +6,17 @@ from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
install(extra_lines=3)
logger = get_logger("chat_message_builder")
def replace_user_references_sync(
@ -349,7 +352,9 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@ -390,16 +395,16 @@ def num_new_messages_since_with_users(
def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
messages: List[MessageAndActionModel],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: Optional[List[DatabaseMessages]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
# sourcery skip: use-getitem-for-re-match-groups
"""
内部辅助函数构建可读消息字符串和原始消息详情列表
@ -418,7 +423,7 @@ def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str, bool]] = []
detailed_messages_raw: List[Tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@ -426,25 +431,26 @@ def _build_readable_messages_internal(
current_pic_counter = pic_counter
# 创建时间戳到消息ID的映射用于在消息前添加[id]标识符
timestamp_to_id = {}
timestamp_to_id_mapping: Dict[float, str] = {}
if message_id_list:
for item in message_id_list:
message = item.get("message", {})
timestamp = message.get("time")
for msg in message_id_list:
timestamp = msg.time
if timestamp is not None:
timestamp_to_id[timestamp] = item.get("id", "")
timestamp_to_id_mapping[timestamp] = msg.message_id
def process_pic_ids(content: str) -> str:
def process_pic_ids(content: Optional[str]) -> str:
"""处理内容中的图片ID将其替换为[图片x]格式"""
nonlocal current_pic_counter
if content is None:
logger.warning("Content is None when processing pic IDs.")
raise ValueError("Content is None")
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match):
def replace_pic_id(match: re.Match) -> str:
nonlocal current_pic_counter
nonlocal pic_counter
pic_id = match.group(1)
if pic_id not in pic_id_mapping:
pic_id_mapping[pic_id] = f"图片{current_pic_counter}"
current_pic_counter += 1
@ -453,42 +459,23 @@ def _build_readable_messages_internal(
return re.sub(pic_pattern, replace_pic_id, content)
# 1 & 2: 获取发送者信息并提取消息组件
for msg in messages:
# 检查是否是动作记录
if msg.get("is_action_record", False):
is_action = True
timestamp: float = msg.get("time") # type: ignore
content = msg.get("display_message", "")
# 1: 获取发送者信息并提取消息组件
for message in messages:
if message.is_action_record:
# 对于动作记录也处理图片ID
content = process_pic_ids(content)
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
content = process_pic_ids(message.display_message)
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
continue
# 检查并修复缺少的user_info字段
if "user_info" not in msg:
# 创建user_info字段
msg["user_info"] = {
"platform": msg.get("user_platform", ""),
"user_id": msg.get("user_id", ""),
"user_nickname": msg.get("user_nickname", ""),
"user_cardname": msg.get("user_cardname", ""),
}
platform = message.user_platform
user_id = message.user_id
user_nickname = message.user_nickname
user_cardname = message.user_cardname
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp: float = msg.get("time") # type: ignore
content: str
if msg.get("display_message"):
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "") # 默认空字符串
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
# 向下兼容
if "" in content:
content = content.replace("", "")
if "" in content:
@ -504,52 +491,32 @@ def _build_readable_messages_internal(
person = Person(platform=platform, user_id=user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
person_name = (
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
)
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person.person_name or user_id # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
if user_cardname:
person_name = f"昵称:{user_cardname}"
elif user_nickname:
person_name = f"{user_nickname}"
else:
person_name = "某人"
# 使用独立函数处理用户引用格式
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name):
detailed_messages_raw.append((timestamp, person_name, content, False))
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content and random.random() < 0.6:
content = content.replace(target_str, "")
if content != "":
message_details_raw.append((timestamp, person_name, content, False))
if not message_details_raw:
if not detailed_messages_raw:
return "", [], pic_id_mapping, current_pic_counter
message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
detailed_messages_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
detailed_message: List[Tuple[float, str, str, bool]] = []
# 为每条消息添加一个标记,指示它是否是动作记录
message_details_with_flags = []
for timestamp, name, content, is_action in message_details_raw:
message_details_with_flags.append((timestamp, name, content, is_action))
# 应用截断逻辑 (如果 truncate 为 True)
message_details: List[Tuple[float, str, str, bool]] = []
n_messages = len(message_details_with_flags)
if truncate and n_messages > 0:
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
# 2. 应用消息截断逻辑
messages_count = len(detailed_messages_raw)
if truncate and messages_count > 0:
for i, (timestamp, name, content, is_action) in enumerate(detailed_messages_raw):
# 对于动作记录,不进行截断
if is_action:
message_details.append((timestamp, name, content, is_action))
detailed_message.append((timestamp, name, content, is_action))
continue
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
percentile = i / messages_count # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
original_len = len(content)
limit = -1 # 默认不截断
@ -562,116 +529,42 @@ def _build_readable_messages_internal(
elif percentile < 0.7: # 60% 到 80% 之前的消息 (即中间的 20%)
limit = 200
replace_content = "......(内容太长了)"
elif percentile < 1.0: # 80% 到 100% 之前的消息 (即较新的 20%)
elif percentile <= 1.0: # 80% 到 100% 之前的消息 (即较新的 20%)
limit = 400
replace_content = "......太长了)"
replace_content = "......内容太长了)"
truncated_content = content
if 0 < limit < original_len:
truncated_content = f"{content[:limit]}{replace_content}"
message_details.append((timestamp, name, truncated_content, is_action))
detailed_message.append((timestamp, name, truncated_content, is_action))
else:
# 如果不截断,直接使用原始列表
message_details = message_details_with_flags
detailed_message = detailed_messages_raw
# 3: 合并连续消息 (如果 merge_messages 为 True)
merged_messages = []
if merge_messages and message_details:
# 初始化第一个合并块
current_merge = {
"name": message_details[0][1],
"start_time": message_details[0][0],
"end_time": message_details[0][0],
"content": [message_details[0][2]],
"is_action": message_details[0][3],
}
# 3: 格式化为字符串
output_lines: List[str] = []
for i in range(1, len(message_details)):
timestamp, name, content, is_action = message_details[i]
for timestamp, name, content, is_action in detailed_message:
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 对于动作记录,不进行合并
if is_action or current_merge["is_action"]:
# 保存当前的合并块
merged_messages.append(current_merge)
# 创建新的块
current_merge = {
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action,
}
continue
# 查找消息id如果有并构建id_prefix
message_id = timestamp_to_id_mapping.get(timestamp)
id_prefix = f"[{message_id}]" if message_id else ""
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
current_merge["content"].append(content)
current_merge["end_time"] = timestamp # 更新最后消息时间
else:
# 保存上一个合并块
merged_messages.append(current_merge)
# 开始新的合并块
current_merge = {
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action,
}
# 添加最后一个合并块
merged_messages.append(current_merge)
elif message_details: # 如果不合并消息,则每个消息都是一个独立的块
for timestamp, name, content, is_action in message_details:
merged_messages.append(
{
"name": name,
"start_time": timestamp, # 起始和结束时间相同
"end_time": timestamp,
"content": [content], # 内容只有一个元素
"is_action": is_action,
}
)
# 4 & 5: 格式化为字符串
output_lines = []
for _i, merged in enumerate(merged_messages):
# 使用指定的 timestamp_mode 格式化时间
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
# 查找对应的消息ID
message_id = timestamp_to_id.get(merged["start_time"], "")
id_prefix = f"[{message_id}] " if message_id else ""
# 检查是否是动作记录
if merged["is_action"]:
if is_action:
# 对于动作记录,使用特殊格式
output_lines.append(f"{id_prefix}{readable_time}, {merged['content'][0]}")
output_lines.append(f"{id_prefix}{readable_time}, {content}")
else:
header = f"{id_prefix}{readable_time}, {merged['name']} :"
output_lines.append(header)
# 将内容合并,并添加缩进
for line in merged["content"]:
stripped_line = line.strip()
if stripped_line: # 过滤空行
# 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留
if stripped_line.endswith(""):
stripped_line = stripped_line[:-1]
# 如果内容被截断,结尾已经是 ...(内容太长),不再添加分号
if not stripped_line.endswith("(内容太长)"):
output_lines.append(f"{stripped_line}")
else:
output_lines.append(stripped_line) # 直接添加截断后的内容
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
# 移除可能的多余换行,然后合并
formatted_string = "".join(output_lines).strip()
# 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器
return (
formatted_string,
[(t, n, c) for t, n, c, is_action in message_details if not is_action],
[(t, n, c) for t, n, c, is_action in detailed_message if not is_action],
pic_id_mapping,
current_pic_counter,
)
@ -755,9 +648,8 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
async def build_readable_messages_with_list(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
@ -766,7 +658,7 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
@ -776,15 +668,14 @@ async def build_readable_messages_with_list(
def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> Tuple[str, List[Dict[str, Any]]]:
) -> Tuple[str, List[DatabaseMessages]]:
"""
将消息列表转换为可读的文本格式并返回原始(时间戳, 昵称, 内容)列表
允许通过参数控制格式化行为
@ -794,7 +685,6 @@ def build_readable_messages_with_id(
formatted_string = build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
timestamp_mode=timestamp_mode,
truncate=truncate,
show_actions=show_actions,
@ -807,15 +697,14 @@ def build_readable_messages_with_id(
def build_readable_messages(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: Optional[List[DatabaseMessages]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式
@ -831,19 +720,32 @@ def build_readable_messages(
truncate: 是否截断长消息
show_actions: 是否显示动作记录
"""
# WIP HERE and BELOW ----------------------------------------------
# 创建messages的深拷贝避免修改原始列表
if not messages:
return ""
copy_messages = [msg.copy() for msg in messages]
copy_messages: List[MessageAndActionModel] = [
MessageAndActionModel(
msg.time,
msg.user_info.user_id,
msg.user_info.platform,
msg.user_info.user_nickname,
msg.user_info.user_cardname,
msg.processed_plain_text,
msg.display_message,
msg.chat_info.platform,
)
for msg in messages
]
if show_actions and copy_messages:
# 获取所有消息的时间范围
min_time = min(msg.get("time", 0) for msg in copy_messages)
max_time = max(msg.get("time", 0) for msg in copy_messages)
min_time = min(msg.time or 0 for msg in copy_messages)
max_time = max(msg.time or 0 for msg in copy_messages)
# 从第一条消息中获取chat_id
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
chat_id = messages[0].chat_id if messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
@ -863,34 +765,34 @@ def build_readable_messages(
)
# 合并两部分动作记录
actions = list(actions_in_range) + list(action_after_latest)
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
if action.action_build_into_prompt:
action_msg = {
"time": action.time,
"user_id": global_config.bot.qq_account, # 使用机器人的QQ账号
"user_nickname": global_config.bot.nickname, # 使用机器人的昵称
"user_cardname": "", # 机器人没有群名片
"processed_plain_text": f"{action.action_prompt_display}",
"display_message": f"{action.action_prompt_display}",
"chat_info_platform": action.chat_info_platform,
"is_action_record": True, # 添加标识字段
"action_name": action.action_name, # 保存动作名称
}
action_msg = MessageAndActionModel(
time=float(action.time), # type: ignore
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
user_platform=global_config.bot.platform, # 使用机器人的平台
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
user_cardname="", # 机器人没有群名片
processed_plain_text=f"{action.action_prompt_display}",
display_message=f"{action.action_prompt_display}",
chat_info_platform=str(action.chat_info_platform),
is_action_record=True, # 添加标识字段
action_name=str(action.action_name), # 保存动作名称
)
copy_messages.append(action_msg)
# 重新按时间排序
copy_messages.sort(key=lambda x: x.get("time", 0))
copy_messages.sort(key=lambda x: x.time or 0)
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
copy_messages,
replace_bot_name,
merge_messages,
timestamp_mode,
truncate,
show_pic=show_pic,
@ -905,8 +807,8 @@ def build_readable_messages(
return formatted_string
else:
# 按 read_mark 分割消息
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
messages_before_mark = [msg for msg in copy_messages if (msg.time or 0) <= read_mark]
messages_after_mark = [msg for msg in copy_messages if (msg.time or 0) > read_mark]
# 共享的图片映射字典和计数器
pic_id_mapping = {}
@ -916,7 +818,6 @@ def build_readable_messages(
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
messages_before_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
truncate,
pic_id_mapping,
@ -927,7 +828,6 @@ def build_readable_messages(
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
messages_after_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
False,
pic_id_mapping,
@ -960,13 +860,13 @@ def build_readable_messages(
return "".join(result_parts)
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
"""
构建匿名可读消息将不同人的名称转为唯一占位符ABC...bot自己用SELF
处理 回复<aaa:bbb> @<aaa:bbb> 字段将bbb映射为匿名占位符
"""
if not messages:
print("111111111111没有消息,无法构建匿名消息")
logger.warning("没有消息,无法构建匿名消息")
return ""
person_map = {}
@ -1017,14 +917,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
for msg in messages:
try:
platform: str = msg.get("chat_info_platform") # type: ignore
user_id = msg.get("user_id")
_timestamp = msg.get("time")
content: str = ""
if msg.get("display_message"):
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "")
platform = msg.chat_info.platform
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
if "" in content:
content = content.replace("", "")
@ -1101,3 +996,22 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
    """Convert a list of DatabaseMessages into flat MessageAndActionModel records.

    Only the fields needed for prompt formatting are carried over; the
    action-record specific fields keep their defaults
    (is_action_record=False, action_name=None).
    """
    converted: List[MessageAndActionModel] = []
    for db_msg in message:
        converted.append(
            MessageAndActionModel(
                time=db_msg.time,
                user_id=db_msg.user_info.user_id,
                user_platform=db_msg.user_info.platform,
                user_nickname=db_msg.user_info.user_nickname,
                user_cardname=db_msg.user_info.user_cardname,
                processed_plain_text=db_msg.processed_plain_text,
                display_message=db_msg.display_message,
                chat_info_platform=db_msg.chat_info.platform,
            )
        )
    return converted

View File

@ -12,6 +12,7 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
@ -113,6 +114,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
try:
embedding, _ = await llm.get_embedding(text)
@ -151,10 +153,13 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
if (
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
who_chat_in_group.append(
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
)
return who_chat_in_group
@ -640,9 +645,9 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
target_info = TargetPersonInfo(
platform=platform,
user_id=user_id,
user_nickname=user_info.user_nickname, # type: ignore
user_nickname=user_info.user_nickname, # type: ignore
person_id=None,
person_name=None
person_name=None,
)
# Try to fetch person info
@ -669,17 +674,17 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
return is_group_chat, chat_target_info
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
def assign_message_ids(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
"""
为消息列表中的每个消息分配唯一的简短随机ID
Args:
messages: 消息列表
Returns:
包含 {'id': str, 'message': any} 格式的字典列表
List[DatabaseMessages]: 分配了唯一ID的消息列表(写入message_id属性)
"""
result = []
result: List[DatabaseMessages] = list(messages) # 复制原始消息列表
used_ids = set()
len_i = len(messages)
if len_i > 100:
@ -688,95 +693,86 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
else:
a = 1
b = 9
for i, message in enumerate(messages):
for i, _ in enumerate(result):
# 生成唯一的简短ID
while True:
# 使用索引+随机数生成简短ID
random_suffix = random.randint(a, b)
message_id = f"m{i+1}{random_suffix}"
message_id = f"m{i + 1}{random_suffix}"
if message_id not in used_ids:
used_ids.add(message_id)
break
result.append({
'id': message_id,
'message': message
})
result[i].message_id = message_id
return result
def assign_message_ids_flexible(
messages: list,
prefix: str = "msg",
id_length: int = 6,
use_timestamp: bool = False
) -> list:
"""
为消息列表中的每个消息分配唯一的简短随机ID增强版
Args:
messages: 消息列表
prefix: ID前缀默认为"msg"
id_length: ID的总长度不包括前缀默认为6
use_timestamp: 是否在ID中包含时间戳默认为False
Returns:
包含 {'id': str, 'message': any} 格式的字典列表
"""
result = []
used_ids = set()
for i, message in enumerate(messages):
# 生成唯一的ID
while True:
if use_timestamp:
# 使用时间戳的后几位 + 随机字符
timestamp_suffix = str(int(time.time() * 1000))[-3:]
remaining_length = id_length - 3
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
message_id = f"{prefix}{timestamp_suffix}{random_chars}"
else:
# 使用索引 + 随机字符
index_str = str(i + 1)
remaining_length = max(1, id_length - len(index_str))
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
message_id = f"{prefix}{index_str}{random_chars}"
if message_id not in used_ids:
used_ids.add(message_id)
break
result.append({
'id': message_id,
'message': message
})
return result
# def assign_message_ids_flexible(
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
# ) -> list:
# """
# 为消息列表中的每个消息分配唯一的简短随机ID增强版
# Args:
# messages: 消息列表
# prefix: ID前缀默认为"msg"
# id_length: ID的总长度不包括前缀默认为6
# use_timestamp: 是否在ID中包含时间戳默认为False
# Returns:
# 包含 {'id': str, 'message': any} 格式的字典列表
# """
# result = []
# used_ids = set()
# for i, message in enumerate(messages):
# # 生成唯一的ID
# while True:
# if use_timestamp:
# # 使用时间戳的后几位 + 随机字符
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
# remaining_length = id_length - 3
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
# else:
# # 使用索引 + 随机字符
# index_str = str(i + 1)
# remaining_length = max(1, id_length - len(index_str))
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{index_str}{random_chars}"
# if message_id not in used_ids:
# used_ids.add(message_id)
# break
# result.append({"id": message_id, "message": message})
# return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""
统一的关键词解析函数支持多种格式的关键词字符串解析
支持的格式
1. 字符串列表格式'["utils.py", "修改", "代码", "动作"]'
2. 斜杠分隔格式'utils.py/修改/代码/动作'
@ -784,25 +780,25 @@ def parse_keywords_string(keywords_input) -> list[str]:
4. 空格分隔格式'utils.py 修改 代码 动作'
5. 已经是列表的情况["utils.py", "修改", "代码", "动作"]
6. JSON格式字符串'{"keywords": ["utils.py", "修改", "代码", "动作"]}'
Args:
keywords_input: 关键词输入可以是字符串或列表
Returns:
list[str]: 解析后的关键词列表去除空白项
"""
if not keywords_input:
return []
# 如果已经是列表,直接处理
if isinstance(keywords_input, list):
return [str(k).strip() for k in keywords_input if str(k).strip()]
# 转换为字符串处理
keywords_str = str(keywords_input).strip()
if not keywords_str:
return []
try:
# 尝试作为JSON对象解析支持 {"keywords": [...]} 格式)
json_data = json.loads(keywords_str)
@ -815,7 +811,7 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [str(k).strip() for k in json_data if str(k).strip()]
except (json.JSONDecodeError, ValueError):
pass
try:
# 尝试使用 ast.literal_eval 解析支持Python字面量格式
parsed = ast.literal_eval(keywords_str)
@ -823,15 +819,15 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [str(k).strip() for k in parsed if str(k).strip()]
except (ValueError, SyntaxError):
pass
# 尝试不同的分隔符
separators = ['/', ',', ' ', '|', ';']
separators = ["/", ",", " ", "|", ";"]
for separator in separators:
if separator in keywords_str:
keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()]
if len(keywords_list) > 1: # 确保分割有效
return keywords_list
# 如果没有分隔符,返回单个关键词
return [keywords_str] if keywords_str else []
return [keywords_str] if keywords_str else []

View File

@ -1,26 +1,28 @@
from typing import Dict, Any
import copy
from typing import Any
class AbstractClassFlag:
pass
class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""
将对象或容器中的 AbstractClassFlag 子类类对象 AbstractClassFlag 实例
将对象或容器中的 BaseDataModel 子类类对象 BaseDataModel 实例
递归转换为普通 dict不修改原对象
- 对于类对象isinstance(value, type) issubclass(..., AbstractClassFlag)
- 对于类对象isinstance(value, type) issubclass(..., BaseDataModel)
读取类的 __dict__ 中非 dunder 项并递归转换
- 对于实例isinstance(value, AbstractClassFlag)读取 vars(instance) 并递归转换
- 对于实例isinstance(value, BaseDataModel)读取 vars(instance) 并递归转换
"""
def _transform(value: Any) -> Any:
# 值是类对象且为 AbstractClassFlag 的子类
if isinstance(value, type) and issubclass(value, AbstractClassFlag):
# 值是类对象且为 BaseDataModel 的子类
if isinstance(value, type) and issubclass(value, BaseDataModel):
return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)}
# 值是 AbstractClassFlag 的实例
if isinstance(value, AbstractClassFlag):
# 值是 BaseDataModel 的实例
if isinstance(value, BaseDataModel):
return {k: _transform(v) for k, v in vars(value).items()}
# 常见容器类型,递归处理

View File

@ -1,36 +1,41 @@
from typing import Optional, Dict, Any
from dataclasses import dataclass, field, fields, MISSING
from typing import Optional, Any
from dataclasses import dataclass, field
from . import BaseDataModel
from . import AbstractClassFlag
@dataclass
class DatabaseUserInfo(AbstractClassFlag):
class DatabaseUserInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
def __post_init__(self):
assert isinstance(self.platform, str), "platform must be a string"
assert isinstance(self.user_id, str), "user_id must be a string"
assert isinstance(self.user_nickname, str), "user_nickname must be a string"
assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None"
# def __post_init__(self):
# assert isinstance(self.platform, str), "platform must be a string"
# assert isinstance(self.user_id, str), "user_id must be a string"
# assert isinstance(self.user_nickname, str), "user_nickname must be a string"
# assert isinstance(self.user_cardname, str) or self.user_cardname is None, (
# "user_cardname must be a string or None"
# )
@dataclass
class DatabaseGroupInfo(AbstractClassFlag):
class DatabaseGroupInfo(BaseDataModel):
group_id: str = field(default_factory=str)
group_name: str = field(default_factory=str)
group_platform: Optional[str] = None
def __post_init__(self):
assert isinstance(self.group_id, str), "group_id must be a string"
assert isinstance(self.group_name, str), "group_name must be a string"
assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None"
# def __post_init__(self):
# assert isinstance(self.group_id, str), "group_id must be a string"
# assert isinstance(self.group_name, str), "group_name must be a string"
# assert isinstance(self.group_platform, str) or self.group_platform is None, (
# "group_platform must be a string or None"
# )
@dataclass
class DatabaseChatInfo(AbstractClassFlag):
class DatabaseChatInfo(BaseDataModel):
stream_id: str = field(default_factory=str)
platform: str = field(default_factory=str)
create_time: float = field(default_factory=float)
@ -38,123 +43,117 @@ class DatabaseChatInfo(AbstractClassFlag):
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
group_info: Optional[DatabaseGroupInfo] = None
def __post_init__(self):
assert isinstance(self.stream_id, str), "stream_id must be a string"
assert isinstance(self.platform, str), "platform must be a string"
assert isinstance(self.create_time, float), "create_time must be a float"
assert isinstance(self.last_active_time, float), "last_active_time must be a float"
assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None"
# def __post_init__(self):
# assert isinstance(self.stream_id, str), "stream_id must be a string"
# assert isinstance(self.platform, str), "platform must be a string"
# assert isinstance(self.create_time, float), "create_time must be a float"
# assert isinstance(self.last_active_time, float), "last_active_time must be a float"
# assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
# assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, (
# "group_info must be a DatabaseGroupInfo instance or None"
# )
@dataclass(init=False)
class DatabaseMessages(AbstractClassFlag):
# chat_info: DatabaseChatInfo
# user_info: DatabaseUserInfo
# group_info: Optional[DatabaseGroupInfo] = None
class DatabaseMessages(BaseDataModel):
def __init__(
self,
message_id: str = "",
time: float = 0.0,
chat_id: str = "",
reply_to: Optional[str] = None,
interest_value: Optional[float] = None,
key_words: Optional[str] = None,
key_words_lite: Optional[str] = None,
is_mentioned: Optional[bool] = None,
processed_plain_text: Optional[str] = None,
display_message: Optional[str] = None,
priority_mode: Optional[str] = None,
priority_info: Optional[str] = None,
additional_config: Optional[str] = None,
is_emoji: bool = False,
is_picid: bool = False,
is_command: bool = False,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
user_id: str = "",
user_nickname: str = "",
user_cardname: Optional[str] = None,
user_platform: str = "",
chat_info_group_id: Optional[str] = None,
chat_info_group_name: Optional[str] = None,
chat_info_group_platform: Optional[str] = None,
chat_info_user_id: str = "",
chat_info_user_nickname: str = "",
chat_info_user_cardname: Optional[str] = None,
chat_info_user_platform: str = "",
chat_info_stream_id: str = "",
chat_info_platform: str = "",
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
**kwargs: Any,
):
self.message_id = message_id
self.time = time
self.chat_id = chat_id
self.reply_to = reply_to
self.interest_value = interest_value
message_id: str = field(default_factory=str)
time: float = field(default_factory=float)
chat_id: str = field(default_factory=str)
reply_to: Optional[str] = None
interest_value: Optional[float] = None
self.key_words = key_words
self.key_words_lite = key_words_lite
self.is_mentioned = is_mentioned
key_words: Optional[str] = None
key_words_lite: Optional[str] = None
is_mentioned: Optional[bool] = None
self.processed_plain_text = processed_plain_text
self.display_message = display_message
# 从 chat_info 扁平化而来的字段
# chat_info_stream_id: str = field(default_factory=str)
# chat_info_platform: str = field(default_factory=str)
# chat_info_user_platform: str = field(default_factory=str)
# chat_info_user_id: str = field(default_factory=str)
# chat_info_user_nickname: str = field(default_factory=str)
# chat_info_user_cardname: Optional[str] = None
# chat_info_group_platform: Optional[str] = None
# chat_info_group_id: Optional[str] = None
# chat_info_group_name: Optional[str] = None
# chat_info_create_time: float = field(default_factory=float)
# chat_info_last_active_time: float = field(default_factory=float)
self.priority_mode = priority_mode
self.priority_info = priority_info
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
# user_platform: str = field(default_factory=str)
# user_id: str = field(default_factory=str)
# user_nickname: str = field(default_factory=str)
# user_cardname: Optional[str] = None
self.additional_config = additional_config
self.is_emoji = is_emoji
self.is_picid = is_picid
self.is_command = is_command
self.is_notify = is_notify
processed_plain_text: Optional[str] = None # 处理后的纯文本消息
display_message: Optional[str] = None # 显示的消息
priority_mode: Optional[str] = None
priority_info: Optional[str] = None
additional_config: Optional[str] = None
is_emoji: bool = False
is_picid: bool = False
is_command: bool = False
is_notify: bool = False
selected_expressions: Optional[str] = None
# def __post_init__(self):
# if self.chat_info_group_id and self.chat_info_group_name:
# self.group_info = DatabaseGroupInfo(
# group_id=self.chat_info_group_id,
# group_name=self.chat_info_group_name,
# group_platform=self.chat_info_group_platform,
# )
# chat_user_info = DatabaseUserInfo(
# user_id=self.chat_info_user_id,
# user_nickname=self.chat_info_user_nickname,
# user_cardname=self.chat_info_user_cardname,
# platform=self.chat_info_user_platform,
# )
# self.chat_info = DatabaseChatInfo(
# stream_id=self.chat_info_stream_id,
# platform=self.chat_info_platform,
# create_time=self.chat_info_create_time,
# last_active_time=self.chat_info_last_active_time,
# user_info=chat_user_info,
# group_info=self.group_info,
# )
def __init__(self, **kwargs: Any):
defined = {f.name: f for f in fields(self.__class__)}
for name, f in defined.items():
if name in kwargs:
setattr(self, name, kwargs.pop(name))
elif f.default is not MISSING:
setattr(self, name, f.default)
else:
raise TypeError(f"缺失必需字段: {name}")
self.selected_expressions = selected_expressions
self.group_info: Optional[DatabaseGroupInfo] = None
self.user_info = DatabaseUserInfo(
user_id=kwargs.get("user_id"), # type: ignore
user_nickname=kwargs.get("user_nickname"), # type: ignore
user_cardname=kwargs.get("user_cardname"), # type: ignore
platform=kwargs.get("user_platform"), # type: ignore
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
platform=user_platform,
)
if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"):
if chat_info_group_id and chat_info_group_name:
self.group_info = DatabaseGroupInfo(
group_id=kwargs.get("chat_info_group_id"), # type: ignore
group_name=kwargs.get("chat_info_group_name"), # type: ignore
group_platform=kwargs.get("chat_info_group_platform"), # type: ignore
group_id=chat_info_group_id,
group_name=chat_info_group_name,
group_platform=chat_info_group_platform,
)
chat_user_info = DatabaseUserInfo(
user_id=kwargs.get("chat_info_user_id"), # type: ignore
user_nickname=kwargs.get("chat_info_user_nickname"), # type: ignore
user_cardname=kwargs.get("chat_info_user_cardname"), # type: ignore
platform=kwargs.get("chat_info_user_platform"), # type: ignore
)
self.chat_info = DatabaseChatInfo(
stream_id=kwargs.get("chat_info_stream_id"), # type: ignore
platform=kwargs.get("chat_info_platform"), # type: ignore
create_time=kwargs.get("chat_info_create_time"), # type: ignore
last_active_time=kwargs.get("chat_info_last_active_time"), # type: ignore
user_info=chat_user_info,
stream_id=chat_info_stream_id,
platform=chat_info_platform,
create_time=chat_info_create_time,
last_active_time=chat_info_last_active_time,
user_info=DatabaseUserInfo(
user_id=chat_info_user_id,
user_nickname=chat_info_user_nickname,
user_cardname=chat_info_user_cardname,
platform=chat_info_user_platform,
),
group_info=self.group_info,
)
if kwargs:
for key, value in kwargs.items():
setattr(self, key, value)
# def __post_init__(self):
# assert isinstance(self.message_id, str), "message_id must be a string"
# assert isinstance(self.time, float), "time must be a float"
# assert isinstance(self.chat_id, str), "chat_id must be a string"
# assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None"
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
# "interest_value must be a float or None"
# )

View File

@ -1,8 +1,10 @@
from dataclasses import dataclass, field
from typing import Optional
from . import BaseDataModel
@dataclass
class TargetPersonInfo:
class TargetPersonInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)

View File

@ -0,0 +1,17 @@
from typing import Optional
from dataclasses import dataclass, field
from . import BaseDataModel
@dataclass
class MessageAndActionModel(BaseDataModel):
    """Flattened view of a chat message or a bot action record, used when
    formatting message history into readable prompt text."""

    time: float = 0.0  # message timestamp (unix seconds)
    user_id: str = ""
    user_platform: str = ""
    user_nickname: str = ""
    user_cardname: Optional[str] = None  # group card name, if any
    processed_plain_text: Optional[str] = None  # cleaned plain-text content
    display_message: Optional[str] = None  # rendered display form of the message
    chat_info_platform: str = ""
    is_action_record: bool = False  # True when this entry represents a bot action
    action_name: Optional[str] = None  # set only for action records

View File

@ -46,6 +46,8 @@ class PersonalityConfig(ConfigBase):
reply_style: str = ""
"""表达风格"""
plan_style: str = ""
compress_personality: bool = True
"""是否压缩人格压缩后会精简人格信息节省token消耗并提高回复性能但是会丢失一些信息如果人设不长可以关闭"""

View File

@ -159,14 +159,23 @@ class ClientRegistry:
return decorator
def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
"""
获取注册的API客户端实例
Args:
api_provider: APIProvider实例
force_new: 是否强制创建新实例用于解决事件循环问题
Returns:
BaseClient: 注册的API客户端实例
"""
# 如果强制创建新实例,直接创建不使用缓存
if force_new:
if client_class := self.client_registry.get(api_provider.client_type):
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = client_class(api_provider)

View File

@ -44,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
# gemini_thinking参数默认范围
# 不同模型的思考预算范围配置
THINKING_BUDGET_LIMITS = {
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
}
gemini_safe_settings = [
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
@ -328,6 +336,48 @@ class GeminiClient(BaseClient):
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
# 思维预算特殊值
THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定
THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用)
@staticmethod
def clamp_thinking_budget(tb: int, model_id: str):
"""
按模型限制思考预算范围仅支持指定的模型支持带数字后缀的新版本
"""
limits = None
matched_key = None
# 优先尝试精确匹配
if model_id in THINKING_BUDGET_LIMITS:
limits = THINKING_BUDGET_LIMITS[model_id]
matched_key = model_id
else:
# 按 key 长度倒序,保证更长的(更具体的,如 -lite优先
sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
for key in sorted_keys:
# 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
if model_id == key or model_id.startswith(key + "-"):
limits = THINKING_BUDGET_LIMITS[key]
matched_key = key
break
# 特殊值处理
if tb == GeminiClient.THINKING_BUDGET_AUTO:
return GeminiClient.THINKING_BUDGET_AUTO
if tb == GeminiClient.THINKING_BUDGET_DISABLED:
if limits and limits.get("can_disable", False):
return GeminiClient.THINKING_BUDGET_DISABLED
return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO
# 已知模型裁剪到范围
if limits:
return max(limits["min"], min(tb, limits["max"]))
# 未知模型,返回动态模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
return GeminiClient.THINKING_BUDGET_AUTO
async def get_response(
self,
model_info: ModelInfo,
@ -373,6 +423,19 @@ class GeminiClient(BaseClient):
messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
tb = GeminiClient.THINKING_BUDGET_AUTO
#空处理
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}"
)
# 裁剪到模型支持的范围
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
"max_output_tokens": max_tokens,
@ -380,11 +443,7 @@ class GeminiClient(BaseClient):
"response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
extra_params["thinking_budget"]
if extra_params and "thinking_budget" in extra_params
else int(max_tokens / 2) # 默认思考预算为最大token数的一半防止空回复
),
thinking_budget=tb,
),
"safety_settings": gemini_safe_settings, # 防止空回复问题
}

View File

@ -388,6 +388,7 @@ class OpenaiClient(BaseClient):
base_url=api_provider.base_url,
api_key=api_provider.api_key,
max_retries=0,
timeout=api_provider.timeout,
)
async def get_response(
@ -520,6 +521,11 @@ class OpenaiClient(BaseClient):
extra_body=extra_params,
)
except APIConnectionError as e:
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, '__cause__') and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException

View File

@ -195,7 +195,7 @@ class LLMRequest:
if not content:
if raise_when_empty:
logger.warning("生成的响应为空")
logger.warning(f"生成的响应为空, 请求类型: {self.request_type}")
raise RuntimeError("生成的响应为空")
content = "生成的响应为空,请检查模型配置或输入内容是否正确"
@ -248,7 +248,11 @@ class LLMRequest:
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
client = client_registry.get_client_class_instance(api_provider)
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
force_new_client = (self.request_type == "embedding")
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用

View File

@ -163,13 +163,9 @@ class ChatAction:
limit=15,
limit_mode="last",
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
@ -230,13 +226,9 @@ class ChatAction:
limit=10,
limit_mode="last",
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,

View File

@ -166,13 +166,10 @@ class ChatMood:
limit=10,
limit_mode="last",
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
@ -248,13 +245,10 @@ class ChatMood:
limit=5,
limit_mode="last",
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,

View File

@ -17,6 +17,10 @@ from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.chat.express.expression_selector import expression_selector
from .s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.data_models.database_data_model import DatabaseMessages
from typing import List
logger = get_logger("prompt")
@ -58,7 +62,7 @@ def init_prompt():
""",
"s4u_prompt", # New template for private CHAT chat
)
Prompt(
"""
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
@ -95,14 +99,13 @@ class PromptBuilder:
def __init__(self):
self.prompt_built = ""
self.activate_messages = ""
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm(
selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
chat_stream.stream_id, chat_history, max_num=12, target_message=target
)
@ -122,7 +125,6 @@ class PromptBuilder:
if style_habits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
return expression_habits_block
async def build_relation_info(self, chat_stream) -> str:
@ -148,9 +150,7 @@ class PromptBuilder:
person_ids.append(person_id)
# 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = [
Person(person_id=person_id).build_relationship() for person_id in person_ids
]
relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
if relation_info := "".join(relation_info_list):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info
@ -160,7 +160,7 @@ class PromptBuilder:
async def build_memory_block(self, text: str) -> str:
# 待更新记忆系统
return ""
related_memory = await hippocampus_manager.get_memory_from_text(
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
@ -176,38 +176,37 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300,
)
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list = []
background_dialogue_list = []
core_dialogue_list: List[DatabaseMessages] = []
background_dialogue_list: List[DatabaseMessages] = []
bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id)
# TODO: 修复之!
for msg in message_list_before_now:
try:
msg_user_id = str(msg.user_info.user_id)
if msg_user_id == bot_id:
if msg.reply_to and talk_type == msg.reply_to:
core_dialogue_list.append(msg.__dict__)
core_dialogue_list.append(msg)
elif msg.reply_to and talk_type != msg.reply_to:
background_dialogue_list.append(msg.__dict__)
background_dialogue_list.append(msg)
# else:
# background_dialogue_list.append(msg_dict)
# background_dialogue_list.append(msg_dict)
elif msg_user_id == target_user_id:
core_dialogue_list.append(msg.__dict__)
core_dialogue_list.append(msg)
else:
background_dialogue_list.append(msg.__dict__)
background_dialogue_list.append(msg)
except Exception as e:
logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")
background_dialogue_prompt = ""
if background_dialogue_list:
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:]
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
background_dialogue_prompt_str = build_readable_messages(
context_msgs,
timestamp_mode="normal_no_YMD",
@ -217,10 +216,10 @@ class PromptBuilder:
core_msg_str = ""
if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:]
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get("user_id")
start_speaking_user_id = first_msg.user_info.user_id
if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n"
@ -229,13 +228,13 @@ class PromptBuilder:
last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
all_msg_seg_list = []
for msg in core_dialogue_list[1:]:
speaker = msg.get("user_id")
speaker = msg.user_info.user_id
if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
else:
msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str)
@ -252,46 +251,40 @@ class PromptBuilder:
for msg in all_msg_seg_list:
core_msg_str += msg
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=20,
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt]
all_dialogue_prompt_str = build_readable_messages(
tmp_msgs,
all_dialogue_history,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
def build_gift_info(self, message: MessageRecvS4U):
if message.is_gift:
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
else:
if message.is_fake_gift:
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
return ""
def build_sc_info(self, message: MessageRecvS4U):
super_chat_manager = get_super_chat_manager()
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
async def build_prompt_normal(
self,
message: MessageRecvS4U,
message_txt: str,
) -> str:
chat_stream = message.chat_stream
person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
person_name = person.person_name
@ -302,28 +295,31 @@ class PromptBuilder:
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
else:
sender_name = f"用户({message.chat_stream.user_info.user_id})"
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
self.build_relation_info(chat_stream), self.build_memory_block(message_txt), self.build_expression_habits(chat_stream, message_txt, sender_name)
self.build_relation_info(chat_stream),
self.build_memory_block(message_txt),
self.build_expression_habits(chat_stream, message_txt, sender_name),
)
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
chat_stream, message
)
core_dialogue_prompt, background_dialogue_prompt,all_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
gift_info = self.build_gift_info(message)
sc_info = self.build_sc_info(message)
screen_info = screen_manager.get_screen_str()
internal_state = internal_manager.get_internal_state_str()
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
template_name = "s4u_prompt"
if not message.is_internal:
prompt = await global_prompt_manager.format_prompt(
template_name,
@ -356,7 +352,7 @@ class PromptBuilder:
mind=message.processed_plain_text,
mood_state=mood.mood_state,
)
# print(prompt)
return prompt

View File

@ -99,13 +99,10 @@ class ChatMood:
limit=int(global_config.chat.max_context_size / 3),
limit_mode="last",
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
@ -151,13 +148,10 @@ class ChatMood:
limit=15,
limit_mode="last",
)
# TODO: 修复
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_talking_prompt = build_readable_messages(
tmp_msgs,
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,

View File

@ -5,7 +5,7 @@ import time
import random
from json_repair import repair_json
from typing import Union
from typing import Union, Optional
from src.common.logger import get_logger
from src.common.database.database import db
@ -253,8 +253,8 @@ class Person:
# 初始化默认值
self.nickname = ""
self.person_name = None
self.name_reason = None
self.person_name: Optional[str] = None
self.name_reason: Optional[str] = None
self.know_times = 0
self.know_since = None
self.last_know = None

View File

@ -1,18 +1,21 @@
import json
import traceback
from json_repair import repair_json
from datetime import datetime
from typing import List
from src.common.logger import get_logger
from .person_info import Person
import random
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
import json
from json_repair import repair_json
from datetime import datetime
from typing import List, Dict, Any
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import traceback
from .person_info import Person
logger = get_logger("relation")
def init_prompt():
Prompt(
"""
@ -45,8 +48,7 @@ def init_prompt():
""",
"attitude_to_me_prompt",
)
Prompt(
"""
你的名字是{bot_name}{bot_name}的别名是{alias_str}
@ -80,104 +82,102 @@ def init_prompt():
"neuroticism_prompt",
)
class RelationshipManager:
def __init__(self):
self.relationship_llm = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="relationship.person"
)
)
async def get_attitude_to_me(self, readable_messages, timestamp, person: Person):
alias_str = ", ".join(global_config.bot.alias_names)
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 解析当前态度值
current_attitude_score = person.attitude_to_me
total_confidence = person.attitude_to_me_confidence
prompt = await global_prompt_manager.format_prompt(
"attitude_to_me_prompt",
bot_name = global_config.bot.nickname,
alias_str = alias_str,
person_name = person.person_name,
nickname = person.nickname,
readable_messages = readable_messages,
current_time = current_time,
bot_name=global_config.bot.nickname,
alias_str=alias_str,
person_name=person.person_name,
nickname=person.nickname,
readable_messages=readable_messages,
current_time=current_time,
)
attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
attitude = repair_json(attitude)
attitude_data = json.loads(attitude)
if not attitude_data or (isinstance(attitude_data, list) and len(attitude_data) == 0):
return ""
# 确保 attitude_data 是字典格式
if not isinstance(attitude_data, dict):
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(attitude_data)}, 内容: {attitude_data}")
return ""
attitude_score = attitude_data["attitude"]
confidence = pow(attitude_data["confidence"],2)
confidence = pow(attitude_data["confidence"], 2)
new_confidence = total_confidence + confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence
person.attitude_to_me = new_attitude_score
person.attitude_to_me_confidence = new_confidence
return person
async def get_neuroticism(self, readable_messages, timestamp, person: Person):
alias_str = ", ".join(global_config.bot.alias_names)
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 解析当前态度值
current_neuroticism_score = person.neuroticism
total_confidence = person.neuroticism_confidence
prompt = await global_prompt_manager.format_prompt(
"neuroticism_prompt",
bot_name = global_config.bot.nickname,
alias_str = alias_str,
person_name = person.person_name,
nickname = person.nickname,
readable_messages = readable_messages,
current_time = current_time,
bot_name=global_config.bot.nickname,
alias_str=alias_str,
person_name=person.person_name,
nickname=person.nickname,
readable_messages=readable_messages,
current_time=current_time,
)
neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
# logger.info(f"prompt: {prompt}")
# logger.info(f"neuroticism: {neuroticism}")
neuroticism = repair_json(neuroticism)
neuroticism_data = json.loads(neuroticism)
if not neuroticism_data or (isinstance(neuroticism_data, list) and len(neuroticism_data) == 0):
return ""
# 确保 neuroticism_data 是字典格式
if not isinstance(neuroticism_data, dict):
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(neuroticism_data)}, 内容: {neuroticism_data}")
return ""
neuroticism_score = neuroticism_data["neuroticism"]
confidence = pow(neuroticism_data["confidence"],2)
confidence = pow(neuroticism_data["confidence"], 2)
new_confidence = total_confidence + confidence
new_neuroticism_score = (current_neuroticism_score * total_confidence + neuroticism_score * confidence)/new_confidence
new_neuroticism_score = (
current_neuroticism_score * total_confidence + neuroticism_score * confidence
) / new_confidence
person.neuroticism = new_neuroticism_score
person.neuroticism_confidence = new_confidence
return person
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
return person
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
"""更新用户印象
Args:
@ -202,12 +202,11 @@ class RelationshipManager:
# 遍历消息,构建映射
for msg in user_messages:
if msg.get("user_id") == "system":
if msg.user_info.user_id == "system":
continue
try:
user_id = msg.get("user_id")
platform = msg.get("chat_info_platform")
user_id = msg.user_info.user_id
platform = msg.chat_info.platform
assert isinstance(user_id, str) and isinstance(platform, str)
msg_person = Person(user_id=user_id, platform=platform)
@ -242,19 +241,16 @@ class RelationshipManager:
# 确保 original_name 和 mapped_name 都不为 None
if original_name is not None and mapped_name is not None:
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
# await self.get_points(
# readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
# readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person)
await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person)
person.know_times = know_times + 1
person.last_know = timestamp
person.sync_to_database()
person.sync_to_database()
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
"""计算基于时间的权重系数"""
@ -280,6 +276,7 @@ class RelationshipManager:
logger.error(f"计算时间权重失败: {e}")
return 0.5 # 发生错误时返回中等权重
init_prompt()
relationship_manager = None
@ -290,4 +287,3 @@ def get_relationship_manager():
if relationship_manager is None:
relationship_manager = RelationshipManager()
return relationship_manager

View File

@ -127,7 +127,7 @@ async def generate_reply(
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
choosen_actions=choosen_actions,
chosen_actions=choosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,

View File

@ -294,7 +294,9 @@ def get_messages_before_time_in_chat(
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[DatabaseMessages]:
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""
获取指定用户在指定时间戳之前的消息
@ -410,9 +412,8 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
@ -434,14 +435,13 @@ def build_readable_messages_to_str(
格式化后的可读字符串
"""
return build_readable_messages(
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
)
async def build_readable_messages_with_details(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
@ -458,7 +458,7 @@ async def build_readable_messages_with_details(
Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
"""
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:

View File

@ -2,13 +2,14 @@ import random
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
from src.plugin_system import BaseAction, ActionActivationType
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config
@ -84,11 +85,8 @@ class EmojiAction(BaseAction):
messages_text = ""
if recent_messages:
# 使用message_api构建可读的消息字符串
# TODO: 修复
from src.common.data_models import temporarily_transform_class_to_dict
tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages]
messages_text = message_api.build_readable_messages(
messages=tmp_msgs,
messages=recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,
show_actions=False,

View File

@ -1,5 +1,5 @@
[inner]
version = "6.4.6"
version = "6.5.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@ -29,6 +29,8 @@ identity = "年龄为19岁,是女孩子,身高为160cm,有黑色的短发"
# 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容
reply_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要浮夸,不要夸张修辞。"
plan_style = "当你刚刚发送了消息没有人回复时不要选择action如果有别的动作非回复满足条件可以选择当你一次发送了太多消息为了避免打扰聊天节奏不要选择动作"
compress_personality = false # 是否压缩人格压缩后会精简人格信息节省token消耗并提高回复性能但是会丢失一些信息如果人设不长可以关闭
compress_identity = true # 是否压缩身份压缩后会精简身份信息节省token消耗并提高回复性能但是会丢失一些信息如果不长可以关闭

View File

@ -5,7 +5,7 @@ version = "1.3.0"
[[api_providers]] # API服务提供商可以配置多个
name = "DeepSeek" # API服务商名称可随意命名在models的api-provider中需使用这个命名
base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
api_key = "your-api-key-here" # API密钥请替换为实际的API密钥
client_type = "openai" # 请求客户端(可选,默认值为"openai"使用gimini等Google系模型时请配置为"gemini"
max_retry = 2 # 最大重试次数单个模型API调用失败最多重试的次数

View File

@ -1,73 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试del_memory函数的脚本
"""
import sys
import os
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from person_info.person_info import Person
def test_del_memory():
"""测试del_memory函数"""
print("开始测试del_memory函数...")
# 创建一个测试用的Person实例不连接数据库
person = Person.__new__(Person)
person.person_id = "test_person"
person.memory_points = [
"性格:这个人很友善:5.0",
"性格:这个人很友善:4.0",
"爱好:喜欢打游戏:3.0",
"爱好:喜欢打游戏:2.0",
"工作:是一名程序员:1.0",
"性格:这个人很友善:6.0"
]
print(f"原始记忆点数量: {len(person.memory_points)}")
print("原始记忆点:")
for i, memory in enumerate(person.memory_points):
print(f" {i+1}. {memory}")
# 测试删除"性格"分类中"这个人很友善"的记忆
print("\n测试1: 删除'性格'分类中'这个人很友善'的记忆")
deleted_count = person.del_memory("性格", "这个人很友善")
print(f"删除了 {deleted_count} 个记忆点")
print("删除后的记忆点:")
for i, memory in enumerate(person.memory_points):
print(f" {i+1}. {memory}")
# 测试删除"爱好"分类中"喜欢打游戏"的记忆
print("\n测试2: 删除'爱好'分类中'喜欢打游戏'的记忆")
deleted_count = person.del_memory("爱好", "喜欢打游戏")
print(f"删除了 {deleted_count} 个记忆点")
print("删除后的记忆点:")
for i, memory in enumerate(person.memory_points):
print(f" {i+1}. {memory}")
# 测试相似度匹配
print("\n测试3: 测试相似度匹配")
person.memory_points = [
"性格:这个人非常友善:5.0",
"性格:这个人很友善:4.0",
"性格:这个人友善:3.0"
]
print("原始记忆点:")
for i, memory in enumerate(person.memory_points):
print(f" {i+1}. {memory}")
# 删除"这个人很友善"(应该匹配"这个人很友善"和"这个人友善"
deleted_count = person.del_memory("性格", "这个人很友善", similarity_threshold=0.8)
print(f"删除了 {deleted_count} 个记忆点")
print("删除后的记忆点:")
for i, memory in enumerate(person.memory_points):
print(f" {i+1}. {memory}")
print("\n测试完成!")
if __name__ == "__main__":
test_del_memory()

View File

@ -1,124 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试修复后的memory_points处理
"""
import sys
import os
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from person_info.person_info import Person
def test_memory_points_with_none():
"""测试包含None值的memory_points处理"""
print("测试包含None值的memory_points处理...")
# 创建一个测试Person实例
person = Person(person_id="test_user_123")
# 模拟包含None值的memory_points
person.memory_points = [
"喜好:喜欢咖啡:1.0",
None, # 模拟None值
"性格:开朗:1.0",
None, # 模拟另一个None值
"兴趣:编程:1.0"
]
print(f"原始memory_points: {person.memory_points}")
# 测试get_all_category方法
try:
categories = person.get_all_category()
print(f"获取到的分类: {categories}")
print("✓ get_all_category方法正常工作")
except Exception as e:
print(f"✗ get_all_category方法出错: {e}")
return False
# 测试get_memory_list_by_category方法
try:
memories = person.get_memory_list_by_category("喜好")
print(f"获取到的喜好记忆: {memories}")
print("✓ get_memory_list_by_category方法正常工作")
except Exception as e:
print(f"✗ get_memory_list_by_category方法出错: {e}")
return False
# 测试del_memory方法
try:
deleted_count = person.del_memory("喜好", "喜欢咖啡")
print(f"删除的记忆点数量: {deleted_count}")
print(f"删除后的memory_points: {person.memory_points}")
print("✓ del_memory方法正常工作")
except Exception as e:
print(f"✗ del_memory方法出错: {e}")
return False
return True
def test_memory_points_empty():
"""测试空的memory_points处理"""
print("\n测试空的memory_points处理...")
person = Person(person_id="test_user_456")
person.memory_points = []
try:
categories = person.get_all_category()
print(f"空列表的分类: {categories}")
print("✓ 空列表处理正常")
except Exception as e:
print(f"✗ 空列表处理出错: {e}")
return False
try:
memories = person.get_memory_list_by_category("测试分类")
print(f"空列表的记忆: {memories}")
print("✓ 空列表分类查询正常")
except Exception as e:
print(f"✗ 空列表分类查询出错: {e}")
return False
return True
def test_memory_points_all_none():
"""测试全部为None的memory_points处理"""
print("\n测试全部为None的memory_points处理...")
person = Person(person_id="test_user_789")
person.memory_points = [None, None, None]
try:
categories = person.get_all_category()
print(f"全None列表的分类: {categories}")
print("✓ 全None列表处理正常")
except Exception as e:
print(f"✗ 全None列表处理出错: {e}")
return False
try:
memories = person.get_memory_list_by_category("测试分类")
print(f"全None列表的记忆: {memories}")
print("✓ 全None列表分类查询正常")
except Exception as e:
print(f"✗ 全None列表分类查询出错: {e}")
return False
return True
if __name__ == "__main__":
print("开始测试修复后的memory_points处理...")
success = True
success &= test_memory_points_with_none()
success &= test_memory_points_empty()
success &= test_memory_points_all_none()
if success:
print("\n🎉 所有测试通过memory_points的None值处理已修复。")
else:
print("\n❌ 部分测试失败,需要进一步检查。")