diff --git a/README.md b/README.md index 3a9e14f8..11c71c2a 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ **🍔MaiCore 是一个基于大语言模型的可交互智能体** -- 💭 **智能对话系统**:基于 LLM 的自然语言交互,支持normal和focus统一化处理。 -- 🔌 **强大插件系统**:全面重构的插件架构,支持完整的管理API和权限控制。 +- 💭 **智能对话系统**:基于 LLM 的自然语言交互,聊天时机控制。 +- 🔌 **强大插件系统**:全面重构的插件架构,更多API。 - 🤔 **实时思维系统**:模拟人类思考过程。 - 🧠 **表达学习功能**:学习群友的说话风格和表达方式 - 💝 **情感表达系统**:情绪系统和表情包系统。 @@ -46,7 +46,7 @@ ## 🔥 更新和安装 -**最新版本: v0.9.1** ([更新日志](changelogs/changelog.md)) +**最新版本: v0.10.0** ([更新日志](changelogs/changelog.md)) 可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器 @@ -56,7 +56,6 @@ - `classical`: 旧版本(停止维护) ### 最新版本部署教程 -- [从0.6/0.7升级须知](https://docs.mai-mai.org/faq/maibot/update_to_07.html) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) > [!WARNING] @@ -64,7 +63,6 @@ > - 项目处于活跃开发阶段,功能和 API 可能随时调整。 > - 文档未完善,有问题可以提交 Issue 或者 Discussion。 > - QQ 机器人存在被限制风险,请自行了解,谨慎使用。 -> - 由于持续迭代,可能存在一些已知或未知的 bug。 > - 由于程序处于开发中,可能消耗较多 token。 ## 💬 讨论 diff --git a/changelogs/changelog.md b/changelogs/changelog.md index 00cb7ca9..1b4d18e3 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -1,16 +1,74 @@ # Changelog ## [0.10.0] - 2025-7-1 -### 主要功能更改 +### 🌟 主要功能更改 +- 优化的回复生成,现在的回复对上下文把控更加精准 +- 新的回复逻辑控制,现在合并了normal和focus模式,更加统一 +- 优化表达方式系统,现在学习和使用更加精准 +- 新的关系系统,现在的关系构建更精准也更克制 - 工具系统重构,现在合并到了插件系统中 - 彻底重构了整个LLM Request了,现在支持模型轮询和更多灵活的参数 - 同时重构了整个模型配置系统,升级需要重新配置llm配置文件 -- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API - - 具体相比于之前的更改可以查看[changes.md](./changes.md) +- **警告所有插件开发者:插件系统即将迎来不稳定时期,随时会发动更改。** + +#### 🔧 工具系统重构 +- **工具系统整合**: 工具系统现在完全合并到插件系统中,提供统一的扩展能力 +- **工具启用控制**: 支持配置是否启用特定工具,提供更人性化的直接调用方式 +- **配置文件读取**: 工具现在支持读取配置文件,增强配置灵活性 + +#### 🚀 LLM系统全面重构 +- **LLM Request重构**: 彻底重构了整个LLM Request系统,现在支持模型轮询和更多灵活的参数 +- **模型配置升级**: 同时重构了整个模型配置系统,升级需要重新配置llm配置文件 +- **任务类型支持**: 新增任务类型和能力字段至模型配置,增强模型初始化逻辑 +- **异常处理增强**: 增强LLMRequest类的异常处理,添加统一的模型异常处理方法 + +#### 🔌 插件系统稳定化 +- **插件系统重构完成**: 随着LLM Request的重构,插件系统彻底重构完成,进入稳定状态 +- **API扩展**: 仅增加新的API,保持向后兼容性 +- **插件管理优化**: 让插件管理配置真正有用,提升管理体验 + +#### 💾 记忆系统优化 +- **及时构建**: 记忆系统再优化,现在及时构建,并且不会重复构建 +- **精确提取**: 记忆提取更精确,提升记忆质量 + +#### 🎭 表达方式系统 +- **表达方式记录**: 记录使用的表达方式,提供更好的学习追踪 +- **学习优化**: 优化表达方式提取,修复表达学习出错问题 +- **配置优化**: 优化表达方式配置和逻辑,提升系统稳定性 + +#### 🔄 聊天系统统一 +- **normal和focus合并**: 彻底合并normal和focus,完全基于planner决定target message +- **no_reply内置**: 将no_reply功能移动到主循环中,简化系统架构 +- **回复优化**: 优化reply,填补缺失值,让麦麦可以回复自己的消息 +- **频率控制API**: 加入聊天频率控制相关API,提供更精细的控制 + +#### 日志系统改进 +- **日志颜色优化**: 修改了log的颜色,更加护眼 +- **日志清理优化**: 修复了日志清理先等24h的问题,提升系统性能 +- **计时定位**: 通过计时定位LLM异常延时,提升问题排查效率 + +### 🐛 问题修复 + +#### 代码质量提升 +- **lint问题修复**: 修复了lint爆炸的问题,代码更加规范了 +- **导入优化**: 修复导入爆炸和文档错误,优化代码结构 + +#### 系统稳定性 +- **循环导入**: 修复了import时循环导入的问题 +- **并行动作**: 修复并行动作炸裂问题,提升并发处理能力 +- **空响应处理**: 空响应就raise,避免系统异常 + +#### 功能修复 +- **API问题**: 修复api问题,提升系统可用性 +- **notice问题**: 为组件方法提供新参数,暂时解决notice问题 +- **关系构建**: 修复不认识的用户构建关系问题 +- **流式解析**: 修复流式解析越界问题,避免空choices的SSE帧错误 + +#### 配置和兼容性 +- **默认值**: 添加默认值,提升配置灵活性 +- **类型问题**: 修复类型问题,提升代码健壮性 +- **配置加载**: 优化配置加载逻辑,提升系统启动稳定性 -### 细节优化 -- 修复了lint爆炸的问题,代码更加规范了 -- 修改了log的颜色,更加护眼 ## [0.9.1] - 2025-7-26 diff --git a/docker-compose.yml b/docker-compose.yml index e4519d30..3bcf0e54 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -84,6 +84,7 @@ services: # - ./data/MaiMBot:/data/MaiMBot # networks: # - maim_bot + volumes: site-packages: networks: diff --git a/requirements.txt b/requirements.txt index 999bd5fd..721cf95f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,3 +47,4 @@ reportportal-client scikit-learn seaborn structlog +google.genai diff --git a/scripts/import_openie.py b/scripts/import_openie.py index fe9f5269..c4367892 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -6,6 +6,7 @@ import sys import os +import asyncio from time import sleep sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k return True -def main(): # sourcery skip: dict-comprehension +async def main_async(): # sourcery skip: dict-comprehension # 新增确认提示 print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") @@ -239,6 +240,29 @@ def main(): # sourcery skip: dict-comprehension return None +def main(): + """主函数 - 设置新的事件循环并运行异步主函数""" + # 检查是否有现有的事件循环 + try: + loop = asyncio.get_running_loop() + if loop.is_closed(): + # 如果事件循环已关闭,创建新的 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # 没有运行的事件循环,创建新的 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # 在新的事件循环中运行异步主函数 + loop.run_until_complete(main_async()) + finally: + # 确保事件循环被正确关闭 + if not loop.is_closed(): + loop.close() + + if __name__ == "__main__": # logger.info(f"111111111111111111111111{ROOT_PATH}") main() diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 7f55bc0d..b781dc16 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -1,6 +1,7 @@ import asyncio import time import traceback +import math import random from typing import List, Optional, Dict, Any, Tuple from rich.traceback import install @@ -8,6 +9,7 @@ from collections import deque from src.config.config import global_config from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer @@ -15,21 +17,19 @@ from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.chat_loop.hfc_utils import CycleDetail -from src.person_info.relationship_builder_manager import relationship_builder_manager +from src.chat.chat_loop.hfc_utils import send_typing, stop_typing +from src.chat.memory_system.Hippocampus import hippocampus_manager +from src.chat.frequency_control.talk_frequency_control import talk_frequency_control +from src.chat.frequency_control.focus_value_control import focus_value_control from src.chat.express.expression_learner import expression_learner_manager +from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.person_info import Person from src.plugin_system.base.component_types import ChatMode, EventType from src.plugin_system.core import events_manager from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.mais4u.mai_think import mai_thinking_manager -import math from src.mais4u.s4u_config import s4u_config -# no_action逻辑已集成到heartFC_chat.py中,不再需要导入 -from src.chat.chat_loop.hfc_utils import send_typing, stop_typing -# 导入记忆系统 -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.chat.frequency_control.talk_frequency_control import talk_frequency_control -from src.chat.frequency_control.focus_value_control import focus_value_control + ERROR_LOOP_INFO = { "loop_plan_info": { @@ -62,10 +62,7 @@ class HeartFChatting: 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 """ - def __init__( - self, - chat_id: str, - ): + def __init__(self, chat_id: str): """ HeartFChatting 初始化函数 @@ -83,7 +80,7 @@ class HeartFChatting: self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) - + self.talk_frequency_control = talk_frequency_control.get_talk_frequency_control(self.stream_id) self.focus_value_control = focus_value_control.get_focus_value_control(self.stream_id) @@ -104,7 +101,7 @@ class HeartFChatting: self.plan_timeout_count = 0 self.last_read_time = time.time() - 1 - + self.focus_energy = 1 self.no_action_consecutive = 0 # 最近三次no_action的新消息兴趣度记录 @@ -166,27 +163,26 @@ class HeartFChatting: # 获取动作类型,兼容新旧格式 action_type = "未知动作" - if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail: + if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail: loop_plan_info = self._current_cycle_detail.loop_plan_info if isinstance(loop_plan_info, dict): - action_result = loop_plan_info.get('action_result', {}) + action_result = loop_plan_info.get("action_result", {}) if isinstance(action_result, dict): # 旧格式:action_result是字典 - action_type = action_result.get('action_type', '未知动作') + action_type = action_result.get("action_type", "未知动作") elif isinstance(action_result, list) and action_result: # 新格式:action_result是actions列表 - action_type = action_result[0].get('action_type', '未知动作') + action_type = action_result[0].get("action_type", "未知动作") elif isinstance(loop_plan_info, list) and loop_plan_info: # 直接是actions列表的情况 - action_type = loop_plan_info[0].get('action_type', '未知动作') + action_type = loop_plan_info[0].get("action_type", "未知动作") logger.info( f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore - f"选择动作: {action_type}" - + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") + f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") ) - + def _determine_form_type(self) -> None: """判断使用哪种形式的no_action""" # 如果连续no_action次数少于3次,使用waiting形式 @@ -195,42 +191,44 @@ class HeartFChatting: else: # 计算最近三次记录的兴趣度总和 total_recent_interest = sum(self.recent_interest_records) - + # 计算调整后的阈值 adjusted_threshold = 1 / self.talk_frequency_control.get_current_talk_frequency() - - logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}") - + + logger.info( + f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}" + ) + # 如果兴趣度总和小于阈值,进入breaking形式 if total_recent_interest < adjusted_threshold: logger.info(f"{self.log_prefix} 兴趣度不足,进入休息") self.focus_energy = random.randint(3, 6) else: logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息") - self.focus_energy = 1 - - async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]: + self.focus_energy = 1 + + async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]: """ 判断是否应该处理消息 - + Args: new_message: 新消息列表 mode: 当前聊天模式 - + Returns: bool: 是否应该处理消息 """ new_message_count = len(new_message) talk_frequency = self.talk_frequency_control.get_current_talk_frequency() - + modified_exit_count_threshold = self.focus_energy * 0.5 / talk_frequency modified_exit_interest_threshold = 1.5 / talk_frequency total_interest = 0.0 - for msg_dict in new_message: - interest_value = msg_dict.get("interest_value") - if interest_value is not None and msg_dict.get("processed_plain_text", ""): + for msg in new_message: + interest_value = msg.interest_value + if interest_value is not None and msg.processed_plain_text: total_interest += float(interest_value) - + if new_message_count >= modified_exit_count_threshold: self.recent_interest_records.append(total_interest) logger.info( @@ -244,9 +242,11 @@ class HeartFChatting: if new_message_count > 0: # 只在兴趣值变化时输出log if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest: - logger.info(f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}") + logger.info( + f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}" + ) self._last_accumulated_interest = total_interest - + if total_interest >= modified_exit_interest_threshold: # 记录兴趣度到列表 self.recent_interest_records.append(total_interest) @@ -261,29 +261,25 @@ class HeartFChatting: f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累计兴趣{total_interest:.1f},继续等待..." ) await asyncio.sleep(0.5) - - return False,0.0 + return False, 0.0 async def _loopbody(self): recent_messages_dict = message_api.get_messages_by_time_in_chat( chat_id=self.stream_id, start_time=self.last_read_time, end_time=time.time(), - limit = 10, + limit=10, limit_mode="latest", filter_mai=True, filter_command=True, - ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict] + ) # 统一的消息处理逻辑 - should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict) + should_process, interest_value = await self._should_process_messages(recent_messages_dict) if should_process: self.last_read_time = time.time() - await self._observe(interest_value = interest_value) + await self._observe(interest_value=interest_value) else: # Normal模式:消息数量不足,等待 @@ -298,22 +294,21 @@ class HeartFChatting: cycle_timers: Dict[str, float], thinking_id, actions, - selected_expressions:List[int] = None, + selected_expressions: List[int] = None, ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: - with Timer("回复发送", cycle_timers): reply_text = await self._send_response( reply_set=response_set, message_data=action_message, selected_expressions=selected_expressions, ) - + # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 platform = action_message.get("chat_info_platform") if platform is None: platform = getattr(self.chat_stream, "platform", "unknown") - - person = Person(platform = platform ,user_id = action_message.get("user_id", "")) + + person = Person(platform=platform, user_id=action_message.get("user_id", "")) person_name = person.person_name action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" @@ -342,12 +337,10 @@ class HeartFChatting: return loop_info, reply_text, cycle_timers - async def _observe(self,interest_value:float = 0.0) -> bool: - + async def _observe(self, interest_value: float = 0.0) -> bool: action_type = "no_action" reply_text = "" # 初始化reply_text变量,避免UnboundLocalError - # 使用sigmoid函数将interest_value转换为概率 # 当interest_value为0时,概率接近0(使用Focus模式) # 当interest_value很高时,概率接近1(使用Normal模式) @@ -361,12 +354,14 @@ class HeartFChatting: x0 = 1.0 # 控制曲线中心点 return 1.0 / (1.0 + math.exp(-k * (interest_val - x0))) - normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / self.talk_frequency_control.get_current_talk_frequency() + normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency() # 根据概率决定使用哪种模式 if random.random() < normal_mode_probability: mode = ChatMode.NORMAL - logger.info(f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability*100:.0f}%概率下选择回复") + logger.info( + f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability * 100:.0f}%概率下选择回复" + ) else: mode = ChatMode.FOCUS @@ -387,10 +382,9 @@ class HeartFChatting: await hippocampus_manager.build_memory_for_chat(self.stream_id) except Exception as e: logger.error(f"{self.log_prefix} 记忆构建失败: {e}") - if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS: - #如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考 + # 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考 actions = [ { "action_type": "no_action", @@ -420,23 +414,21 @@ class HeartFChatting: ): return False with Timer("规划器", cycle_timers): - actions, _= await self.action_planner.plan( + actions, _ = await self.action_planner.plan( mode=mode, loop_start_time=self.last_read_time, available_actions=available_actions, ) - - # 3. 并行执行所有动作 - async def execute_action(action_info,actions): + async def execute_action(action_info, actions): """执行单个动作的通用函数""" try: if action_info["action_type"] == "no_action": # 直接处理no_action逻辑,不再通过动作系统 reason = action_info.get("reasoning", "选择不回复") logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") - + # 存储no_action信息到数据库 await database_api.store_action_info( chat_stream=self.chat_stream, @@ -447,13 +439,8 @@ class HeartFChatting: action_data={"reason": reason}, action_name="no_action", ) - - return { - "action_type": "no_action", - "success": True, - "reply_text": "", - "command": "" - } + + return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} elif action_info["action_type"] != "reply": # 执行普通动作 with Timer("动作执行", cycle_timers): @@ -463,20 +450,19 @@ class HeartFChatting: action_info["action_data"], cycle_timers, thinking_id, - action_info["action_message"] + action_info["action_message"], ) return { "action_type": action_info["action_type"], "success": success, "reply_text": reply_text, - "command": command + "command": command, } else: - try: success, response_set, prompt_selected_expressions = await generator_api.generate_reply( chat_stream=self.chat_stream, - reply_message = action_info["action_message"], + reply_message=action_info["action_message"], available_actions=available_actions, choosen_actions=actions, reply_reason=action_info.get("reasoning", ""), @@ -485,29 +471,21 @@ class HeartFChatting: from_plugin=False, return_expressions=True, ) - + if prompt_selected_expressions and len(prompt_selected_expressions) > 1: - _,selected_expressions = prompt_selected_expressions + _, selected_expressions = prompt_selected_expressions else: selected_expressions = [] if not success or not response_set: - logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败") - return { - "action_type": "reply", - "success": False, - "reply_text": "", - "loop_info": None - } - + logger.info( + f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败" + ) + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} + except asyncio.CancelledError: logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消") - return { - "action_type": "reply", - "success": False, - "reply_text": "", - "loop_info": None - } + return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None} loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( response_set=response_set, @@ -521,7 +499,7 @@ class HeartFChatting: "action_type": "reply", "success": True, "reply_text": reply_text, - "loop_info": loop_info + "loop_info": loop_info, } except Exception as e: logger.error(f"{self.log_prefix} 执行动作时出错: {e}") @@ -531,26 +509,26 @@ class HeartFChatting: "success": False, "reply_text": "", "loop_info": None, - "error": str(e) + "error": str(e), } - - action_tasks = [asyncio.create_task(execute_action(action,actions)) for action in actions] - + + action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions] + # 并行执行所有任务 results = await asyncio.gather(*action_tasks, return_exceptions=True) - + # 处理执行结果 reply_loop_info = None reply_text_from_reply = "" action_success = False action_reply_text = "" action_command = "" - + for i, result in enumerate(results): if isinstance(result, BaseException): logger.error(f"{self.log_prefix} 动作执行异常: {result}") continue - + _cur_action = actions[i] if result["action_type"] != "reply": action_success = result["success"] @@ -590,7 +568,6 @@ class HeartFChatting: }, } reply_text = action_reply_text - if s4u_config.enable_s4u: await stop_typing() @@ -602,7 +579,7 @@ class HeartFChatting: # await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", "")) action_type = actions[0]["action_type"] if actions else "no_action" - + # 管理no_action计数器:当执行了非no_action动作时,重置计数器 if action_type != "no_action": # no_action逻辑已集成到heartFC_chat.py中,直接重置计数器 @@ -610,7 +587,7 @@ class HeartFChatting: self.no_action_consecutive = 0 logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_action计数器") return True - + if action_type == "no_action": self.no_action_consecutive += 1 self._determine_form_type() @@ -692,11 +669,12 @@ class HeartFChatting: traceback.print_exc() return False, "", "" - async def _send_response(self, - reply_set, - message_data, - selected_expressions:List[int] = None, - ) -> str: + async def _send_response( + self, + reply_set, + message_data, + selected_expressions: List[int] = None, + ) -> str: new_message_count = message_api.count_new_messages( chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() ) @@ -714,7 +692,7 @@ class HeartFChatting: await send_api.text_to_stream( text=data, stream_id=self.chat_stream.stream_id, - reply_message = message_data, + reply_message=message_data, set_reply=need_reply, typing=False, selected_expressions=selected_expressions, @@ -724,7 +702,7 @@ class HeartFChatting: await send_api.text_to_stream( text=data, stream_id=self.chat_stream.stream_id, - reply_message = message_data, + reply_message=message_data, set_reply=False, typing=True, selected_expressions=selected_expressions, diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 10669b14..47a50865 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -709,36 +709,36 @@ class EmojiManager: return emoji return None # 如果循环结束还没找到,则返回 None - async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]: - """根据哈希值获取已注册表情包的描述 + async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]: + """根据哈希值获取已注册表情包的情感标签列表 Args: emoji_hash: 表情包的哈希值 Returns: - Optional[str]: 表情包描述,如果未找到则返回None + Optional[List[str]]: 情感标签列表,如果未找到则返回None """ try: # 先从内存中查找 emoji = await self.get_emoji_from_manager(emoji_hash) if emoji and emoji.emotion: - logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...") - return ",".join(emoji.emotion) + logger.info(f"[缓存命中] 从内存获取表情包情感标签: {emoji.emotion}...") + return emoji.emotion # 如果内存中没有,从数据库查找 self._ensure_db() try: emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) if emoji_record and emoji_record.emotion: - logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") - return emoji_record.emotion + logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") + return emoji_record.emotion.split(',') except Exception as e: - logger.error(f"从数据库查询表情包描述时出错: {e}") + logger.error(f"从数据库查询表情包情感标签时出错: {e}") return None except Exception as e: - logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") + logger.error(f"获取表情包情感标签失败 (Hash: {emoji_hash}): {str(e)}") return None async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index e5b5eb04..cc29d6f2 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -8,6 +8,7 @@ from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger from src.common.database.database_model import Expression +from src.common.data_models.database_data_model import DatabaseMessages from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages @@ -346,21 +347,17 @@ class ExpressionLearner: current_time = time.time() # 获取上次学习时间 - random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive( + random_msg = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_learning_time, timestamp_end=current_time, limit=num, ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None - # print(random_msg) if not random_msg or random_msg == []: return None # 转化成str - chat_id: str = random_msg[0]["chat_id"] + chat_id: str = random_msg[0].chat_id # random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal") random_msg_str: str = await build_anonymous_messages(random_msg) # print(f"random_msg_str:{random_msg_str}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d0f6e774..dec5b595 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -117,30 +117,36 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - """获取字符串的嵌入向量,处理异步调用""" + """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" + # 创建新的事件循环并在完成后立即关闭 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: - # 尝试获取当前事件循环 - asyncio.get_running_loop() - # 如果在事件循环中,使用线程池执行 - import concurrent.futures - - def run_in_thread(): - return asyncio.run(get_embedding(s)) - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) - result = future.result() - if result is None: - logger.error(f"获取嵌入失败: {s}") - return [] - return result - except RuntimeError: - # 没有运行的事件循环,直接运行 - result = asyncio.run(get_embedding(s)) - if result is None: + # 创建新的LLMRequest实例 + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + + # 使用新的事件循环运行异步方法 + embedding, _ = loop.run_until_complete(llm.get_embedding(s)) + + if embedding and len(embedding) > 0: + return embedding + else: logger.error(f"获取嵌入失败: {s}") return [] - return result + + except Exception as e: + logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + return [] + finally: + # 确保事件循环被正确关闭 + try: + loop.close() + except Exception: + pass def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 @@ -181,8 +187,14 @@ class EmbeddingStore: for i, s in enumerate(chunk_strs): try: - # 直接使用异步函数 - embedding = asyncio.run(llm.get_embedding(s)) + # 在线程中创建独立的事件循环 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + embedding = loop.run_until_complete(llm.get_embedding(s)) + finally: + loop.close() + if embedding and len(embedding) > 0: chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 else: diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index a3a5741d..04ceccb2 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -9,7 +9,6 @@ import networkx as nx import numpy as np from typing import List, Tuple, Set, Coroutine, Any, Dict from collections import Counter -from itertools import combinations import traceback from rich.traceback import install @@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive, ) # 导入 build_readable_messages + + # 添加cosine_similarity函数 def cosine_similarity(v1, v2): """计算余弦相似度""" @@ -51,18 +52,9 @@ def calculate_information_content(text): return entropy - - - logger = get_logger("memory") - - - - - - class MemoryGraph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 @@ -96,7 +88,7 @@ class MemoryGraph: if "memory_items" in self.G.nodes[concept]: # 获取现有的记忆项(已经是str格式) existing_memory = self.G.nodes[concept]["memory_items"] - + # 如果现有记忆不为空,则使用LLM整合新旧记忆 if existing_memory and hippocampus_instance and hippocampus_instance.model_small: try: @@ -170,16 +162,16 @@ class MemoryGraph: second_layer_items.append(memory_items) return first_layer_items, second_layer_items - + async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str: """ 使用LLM整合新旧记忆内容 - + Args: existing_memory: 现有的记忆内容(字符串格式,可能包含多条记忆) new_memory: 新的记忆内容 llm_model: LLM模型实例 - + Returns: str: 整合后的记忆内容 """ @@ -203,8 +195,10 @@ class MemoryGraph: 整合后的记忆:""" # 调用LLM进行整合 - content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt) - + content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async( + integration_prompt + ) + if content and content.strip(): integrated_content = content.strip() logger.debug(f"LLM记忆整合成功,模型: {model_name}") @@ -212,7 +206,7 @@ class MemoryGraph: else: logger.warning("LLM返回的整合结果为空,使用默认连接方式") return f"{existing_memory} | {new_memory}" - + except Exception as e: logger.error(f"LLM记忆整合过程中出错: {e}") return f"{existing_memory} | {new_memory}" @@ -238,7 +232,11 @@ class MemoryGraph: if memory_items: # 删除整个节点 self.G.remove_node(topic) - return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}" + return ( + f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." + if len(memory_items) > 50 + else f"删除了节点 {topic} 的完整记忆: {memory_items}" + ) else: # 如果没有记忆项,删除该节点 self.G.remove_node(topic) @@ -263,38 +261,40 @@ class Hippocampus: self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 self.entorhinal_cortex.sync_memory_from_db() - self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify") + self.model_small = LLMRequest( + model_set=model_config.model_task_config.utils_small, request_type="memory.modify" + ) def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" return list(self.memory_graph.G.nodes()) - + def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float: """ 计算考虑节点权重的激活值 - + Args: current_activation: 当前激活值 edge_strength: 边的强度 target_node: 目标节点名称 - + Returns: float: 计算后的激活值 """ # 基础激活值计算 base_activation = current_activation - (1 / edge_strength) - + if base_activation <= 0: return 0.0 - + # 获取目标节点的权重 if target_node in self.memory_graph.G: node_data = self.memory_graph.G.nodes[target_node] node_weight = node_data.get("weight", 1.0) - + # 权重加成:每次整合增加10%激活值,最大加成200% weight_multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0) - + return base_activation * weight_multiplier else: return base_activation @@ -332,9 +332,7 @@ class Hippocampus: f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" f"如果确定找不出主题或者没有明显主题,返回。" ) - - - + return prompt @staticmethod @@ -403,7 +401,7 @@ class Hippocampus: memories.sort(key=lambda x: x[2], reverse=True) return memories - async def get_keywords_from_text(self, text: str) -> list: + async def get_keywords_from_text(self, text: str) -> Tuple[List[str], List]: """从文本中提取关键词。 Args: @@ -418,16 +416,13 @@ class Hippocampus: # 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算 text_length = len(text) topic_num: int | list[int] = 0 - - + words = jieba.cut(text) keywords_lite = [word for word in words if len(word) > 1] keywords_lite = list(set(keywords_lite)) if keywords_lite: logger.debug(f"提取关键词极简版: {keywords_lite}") - - if text_length <= 12: topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本) elif text_length <= 20: @@ -455,7 +450,7 @@ class Hippocampus: if keywords: logger.debug(f"提取关键词: {keywords}") - return keywords,keywords_lite + return keywords, keywords_lite async def get_memory_from_topic( self, @@ -570,20 +565,17 @@ class Hippocampus: for node, activation in remember_map.items(): logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", "") - # 直接使用完整的记忆内容 - if memory_items: + if memory_items := node_data.get("memory_items", ""): logger.debug("节点包含完整记忆") # 计算记忆与关键词的相似度 memory_words = set(jieba.cut(memory_items)) text_words = set(keywords) - all_words = memory_words | text_words - if all_words: + if all_words := memory_words | text_words: # 计算相似度(虽然这里没有使用,但保持逻辑一致性) v1 = [1 if word in memory_words else 0 for word in all_words] v2 = [1 if word in text_words else 0 for word in all_words] _ = cosine_similarity(v1, v2) # 计算但不使用,用_表示 - + # 添加完整记忆到结果中 all_memories.append((node, memory_items, activation)) else: @@ -613,7 +605,9 @@ class Hippocampus: return result - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]: + async def get_activate_from_text( + self, text: str, max_depth: int = 3, fast_retrieval: bool = False + ) -> tuple[float, list[str], list[str]]: """从文本中提取关键词并获取相关记忆。 Args: @@ -627,13 +621,13 @@ class Hippocampus: float: 激活节点数与总节点数的比值 list[str]: 有效的关键词 """ - keywords,keywords_lite = await self.get_keywords_from_text(text) + keywords, keywords_lite = await self.get_keywords_from_text(text) # 过滤掉不存在于记忆图中的关键词 valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] if not valid_keywords: # logger.info("没有找到有效的关键词节点") - return 0, keywords,keywords_lite + return 0, keywords, keywords_lite logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") @@ -700,7 +694,7 @@ class Hippocampus: activation_ratio = activation_ratio * 50 logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") - return activation_ratio, keywords,keywords_lite + return activation_ratio, keywords, keywords_lite # 负责海马体与其他部分的交互 @@ -730,7 +724,7 @@ class EntorhinalCortex: continue memory_items = data.get("memory_items", "") - + # 直接检查字符串是否为空,不需要分割成列表 if not memory_items or memory_items.strip() == "": self.memory_graph.G.remove_node(concept) @@ -865,7 +859,9 @@ class EntorhinalCortex: end_time = time.time() logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒") - logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边") + logger.info( + f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边" + ) async def resync_memory_to_db(self): """清空数据库并重新同步所有记忆数据""" @@ -888,7 +884,7 @@ class EntorhinalCortex: nodes_data = [] for concept, data in memory_nodes: memory_items = data.get("memory_items", "") - + # 直接检查字符串是否为空,不需要分割成列表 if not memory_items or memory_items.strip() == "": self.memory_graph.G.remove_node(concept) @@ -960,7 +956,7 @@ class EntorhinalCortex: # 清空当前图 self.memory_graph.G.clear() - + # 统计加载情况 total_nodes = 0 loaded_nodes = 0 @@ -969,7 +965,7 @@ class EntorhinalCortex: # 从数据库加载所有节点 nodes = list(GraphNodes.select()) total_nodes = len(nodes) - + for node in nodes: concept = node.concept try: @@ -978,7 +974,7 @@ class EntorhinalCortex: logger.warning(f"节点 {concept} 的memory_items为空,跳过") skipped_nodes += 1 continue - + # 直接使用memory_items memory_items = node.memory_items.strip() @@ -999,11 +995,15 @@ class EntorhinalCortex: last_modified = node.last_modified or current_time # 获取权重属性 - weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0 - + weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0 + # 添加节点到图中 self.memory_graph.G.add_node( - concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified + concept, + memory_items=memory_items, + weight=weight, + created_time=created_time, + last_modified=last_modified, ) loaded_nodes += 1 except Exception as e: @@ -1044,9 +1044,11 @@ class EntorhinalCortex: if need_update: logger.info("[数据库] 已为缺失的时间字段进行补充") - + # 输出加载统计信息 - logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个") + logger.info( + f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个" + ) # 负责整合,遗忘,合并记忆 @@ -1054,10 +1056,12 @@ class ParahippocampalGyrus: def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - - self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify") - async def memory_compress(self, messages: list, compress_rate=0.1): + self.memory_modify_model = LLMRequest( + model_set=model_config.model_task_config.utils, request_type="memory.modify" + ) + + async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1): """压缩和总结消息内容,生成记忆主题和摘要。 Args: @@ -1083,7 +1087,6 @@ class ParahippocampalGyrus: # build_readable_messages 只返回一个字符串,不需要解包 input_text = build_readable_messages( messages, - merge_messages=True, # 合并连续消息 timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 replace_bot_name=False, # 保留原始用户名 ) @@ -1163,7 +1166,7 @@ class ParahippocampalGyrus: similar_topics.sort(key=lambda x: x[1], reverse=True) similar_topics = similar_topics[:3] similar_topics_dict[topic] = similar_topics - + if global_config.debug.show_prompt: logger.info(f"prompt: {topic_what_prompt}") logger.info(f"压缩后的记忆: {compressed_memory}") @@ -1259,14 +1262,14 @@ class ParahippocampalGyrus: # --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 --- last_modified = node_data.get("last_modified", current_time) node_weight = node_data.get("weight", 1.0) - + # 条件1:检查是否长时间未修改 (使用配置的遗忘时间) time_threshold = 3600 * global_config.memory.memory_forget_time - + # 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘 # 权重为1时使用默认阈值,权重越高阈值越大(越难遗忘) adjusted_threshold = time_threshold * node_weight - + if current_time - last_modified > adjusted_threshold and memory_items: # 既然每个节点现在是完整记忆,直接删除整个节点 try: @@ -1315,8 +1318,6 @@ class ParahippocampalGyrus: logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒") - - class HippocampusManager: def __init__(self): self._hippocampus: Hippocampus = None # type: ignore @@ -1361,29 +1362,32 @@ class HippocampusManager: """为指定chat_id构建记忆(在heartFC_chat.py中调用)""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - + try: # 检查是否需要构建记忆 logger.info(f"为 {chat_id} 构建记忆") if memory_segment_manager.check_and_build_memory_for_chat(chat_id): logger.info(f"为 {chat_id} 构建记忆,需要构建记忆") messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 50) - + build_probability = 0.3 * global_config.memory.memory_build_frequency - + if messages and random.random() < build_probability: logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}") - + # 调用记忆压缩和构建 - compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress( + ( + compressed_memory, + similar_topics_dict, + ) = await self._hippocampus.parahippocampal_gyrus.memory_compress( messages, global_config.memory.memory_compress_rate ) - + # 添加记忆节点 current_time = time.time() for topic, memory in compressed_memory: await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus) - + # 连接相似主题 if topic in similar_topics_dict: similar_topics = similar_topics_dict[topic] @@ -1391,23 +1395,23 @@ class HippocampusManager: if topic != similar_topic: strength = int(similarity * 10) self._hippocampus.memory_graph.G.add_edge( - topic, similar_topic, + topic, + similar_topic, strength=strength, created_time=current_time, - last_modified=current_time + last_modified=current_time, ) - + # 同步到数据库 await self._hippocampus.entorhinal_cortex.sync_memory_to_db() logger.info(f"为 {chat_id} 构建记忆完成") return True - + except Exception as e: logger.error(f"为 {chat_id} 构建记忆失败: {e}") return False - - return False + return False async def get_memory_from_topic( self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 @@ -1424,16 +1428,20 @@ class HippocampusManager: response = [] return response - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]: + async def get_activate_from_text( + self, text: str, max_depth: int = 3, fast_retrieval: bool = False + ) -> tuple[float, list[str], list[str]]: """从文本中获取激活值的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") try: - response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) + response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text( + text, max_depth, fast_retrieval + ) except Exception as e: logger.error(f"文本产生激活值失败: {e}") logger.error(traceback.format_exc()) - return 0.0, [],[] + return 0.0, [], [] def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: """从关键词获取相关记忆的公共接口""" @@ -1455,81 +1463,78 @@ hippocampus_manager = HippocampusManager() # 在Hippocampus类中添加新的记忆构建管理器 class MemoryBuilder: """记忆构建器 - + 为每个chat_id维护消息缓存和触发机制,类似ExpressionLearner """ - + def __init__(self, chat_id: str): self.chat_id = chat_id self.last_update_time: float = time.time() self.last_processed_time: float = 0.0 - + def should_trigger_memory_build(self) -> bool: """检查是否应该触发记忆构建""" current_time = time.time() - + # 检查时间间隔 time_diff = current_time - self.last_update_time - if time_diff < 600 /global_config.memory.memory_build_frequency: + if time_diff < 600 / global_config.memory.memory_build_frequency: return False - + # 检查消息数量 - + recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_update_time, timestamp_end=current_time, ) - + logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}") - - if not recent_messages or len(recent_messages) < 30/global_config.memory.memory_build_frequency : + + if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency: return False - + return True - - def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]: + + def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]: """获取用于记忆构建的消息""" current_time = time.time() - - + messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_update_time, timestamp_end=current_time, limit=threshold, ) - tmp_msg = [msg.__dict__ for msg in messages] if messages else [] if messages: # 更新最后处理时间 self.last_processed_time = current_time self.last_update_time = current_time - return tmp_msg or [] - + return messages or [] class MemorySegmentManager: """记忆段管理器 - + 管理所有chat_id的MemoryBuilder实例,自动检查和触发记忆构建 """ - + def __init__(self): self.builders: Dict[str, MemoryBuilder] = {} - + def get_or_create_builder(self, chat_id: str) -> MemoryBuilder: """获取或创建指定chat_id的MemoryBuilder""" if chat_id not in self.builders: self.builders[chat_id] = MemoryBuilder(chat_id) return self.builders[chat_id] - + def check_and_build_memory_for_chat(self, chat_id: str) -> bool: """检查指定chat_id是否需要构建记忆,如果需要则返回True""" builder = self.get_or_create_builder(chat_id) return builder.should_trigger_memory_build() - - def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]: + + def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]: """获取指定chat_id用于记忆构建的消息""" if chat_id not in self.builders: return [] @@ -1538,4 +1543,3 @@ class MemorySegmentManager: # 创建全局实例 memory_segment_manager = MemorySegmentManager() - diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py index 7c773530..ce7daef5 100644 --- a/src/chat/memory_system/memory_activator.py +++ b/src/chat/memory_system/memory_activator.py @@ -1,17 +1,17 @@ import json +import random from json_repair import repair_json from typing import List, Tuple - -from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.utils.utils import parse_keywords_string from src.chat.utils.chat_message_builder import build_readable_messages -import random +from src.chat.memory_system.Hippocampus import hippocampus_manager +from src.llm_models.utils_model import LLMRequest logger = get_logger("memory_activator") @@ -75,19 +75,20 @@ class MemoryActivator: request_type="memory.selection", ) - - async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]: + async def activate_memory_with_chat_history( + self, target_message, chat_history: List[DatabaseMessages] + ) -> List[Tuple[str, str]]: """ 激活记忆 """ # 如果记忆系统被禁用,直接返回空列表 if not global_config.memory.enable_memory: return [] - + keywords_list = set() - - for msg in chat_history_prompt: - keywords = parse_keywords_string(msg.get("key_words", "")) + + for msg in chat_history: + keywords = parse_keywords_string(msg.key_words) if keywords: if len(keywords_list) < 30: # 最多容纳30个关键词 @@ -95,24 +96,22 @@ class MemoryActivator: logger.debug(f"提取关键词: {keywords_list}") else: break - + if not keywords_list: logger.debug("没有提取到关键词,返回空记忆列表") return [] - + # 从海马体获取相关记忆 related_memory = await hippocampus_manager.get_memory_from_topic( valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3 ) - + # logger.info(f"当前记忆关键词: {keywords_list}") logger.debug(f"获取到的记忆: {related_memory}") - + if not related_memory: logger.debug("海马体没有返回相关记忆") return [] - - used_ids = set() candidate_memories = [] @@ -120,12 +119,7 @@ class MemoryActivator: # 为每个记忆分配随机ID并过滤相关记忆 for memory in related_memory: keyword, content = memory - found = False - for kw in keywords_list: - if kw in content: - found = True - break - + found = any(kw in content for kw in keywords_list) if found: # 随机分配一个不重复的2位数id while True: @@ -138,95 +132,83 @@ class MemoryActivator: if not candidate_memories: logger.info("没有找到相关的候选记忆") return [] - + # 如果只有少量记忆,直接返回 if len(candidate_memories) <= 2: logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回") # 转换为 (keyword, content) 格式 return [(mem["keyword"], mem["content"]) for mem in candidate_memories] - - # 使用 LLM 选择合适的记忆 - selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories) - - return selected_memories - async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]: + return await self._select_memories_with_llm(target_message, chat_history, candidate_memories) + + async def _select_memories_with_llm( + self, target_message, chat_history: List[DatabaseMessages], candidate_memories + ) -> List[Tuple[str, str]]: """ 使用 LLM 选择合适的记忆 - + Args: target_message: 目标消息 chat_history_prompt: 聊天历史 candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content - + Returns: List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content) """ try: # 构建聊天历史字符串 obs_info_text = build_readable_messages( - chat_history_prompt, + chat_history, replace_bot_name=True, - merge_messages=False, timestamp_mode="relative", read_mark=0.0, show_actions=True, ) - - + # 构建记忆信息字符串 memory_lines = [] for memory in candidate_memories: memory_id = memory["memory_id"] keyword = memory["keyword"] content = memory["content"] - + # 将 content 列表转换为字符串 if isinstance(content, list): content_str = " | ".join(str(item) for item in content) else: content_str = str(content) - + memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}") - + memory_info = "\n".join(memory_lines) - + # 获取并格式化 prompt prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt") formatted_prompt = prompt_template.format( - obs_info_text=obs_info_text, - target_message=target_message, - memory_info=memory_info + obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info ) - - - + # 调用 LLM response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async( - formatted_prompt, - temperature=0.3, - max_tokens=150 + formatted_prompt, temperature=0.3, max_tokens=150 ) - + if global_config.debug.show_prompt: logger.info(f"记忆选择 prompt: {formatted_prompt}") logger.info(f"LLM 记忆选择响应: {response}") else: logger.debug(f"记忆选择 prompt: {formatted_prompt}") logger.debug(f"LLM 记忆选择响应: {response}") - + # 解析响应获取选择的记忆编号 try: fixed_json = repair_json(response) - + # 解析为 Python 对象 result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json - - # 提取 memory_ids 字段 - memory_ids_str = result.get("memory_ids", "") - - # 解析逗号分隔的编号 - if memory_ids_str: + + # 提取 memory_ids 字段并解析逗号分隔的编号 + if memory_ids_str := result.get("memory_ids", ""): memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()] # 过滤掉空字符串和无效编号 valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3] @@ -236,26 +218,24 @@ class MemoryActivator: except Exception as e: logger.error(f"解析记忆选择响应失败: {e}", exc_info=True) selected_memory_ids = [] - + # 根据编号筛选记忆 selected_memories = [] memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories} - - for memory_id in selected_memory_ids: - if memory_id in memory_id_to_memory: - selected_memories.append(memory_id_to_memory[memory_id]) - + + selected_memories = [ + memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory + ] logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}") logger.info(f"最终选择的记忆数量: {len(selected_memories)}") - + # 转换为 (keyword, content) 格式 return [(mem["keyword"], mem["content"]) for mem in selected_memories] - + except Exception as e: logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True) # 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式 return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]] - init_prompt() diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index aa63aa8f..03c72ffc 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -70,13 +70,10 @@ class ActionModifier: timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 10), ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half] + chat_content = build_readable_messages( - temp_msg_list_before_now_half, + message_list_before_now_half, replace_bot_name=True, - merge_messages=False, timestamp_mode="relative", read_mark=0.0, show_actions=True, diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 4b0320ff..5e0695c3 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -33,6 +33,9 @@ def init_prompt(): {time_block} {identity_block} 你现在需要根据聊天内容,选择的合适的action来参与聊天。 +请你根据以下行事风格来决定action: +{plan_style} + {chat_context_description},以下是具体的聊天内容 {chat_content_block} @@ -280,11 +283,8 @@ class ActionPlanner: timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.6), ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_content_block, message_id_list = build_readable_messages_with_id( - messages=temp_msg_list_before_now, + messages=message_list_before_now, timestamp_mode="normal_no_YMD", read_mark=self.last_obs_time_mark, truncate=True, @@ -388,6 +388,7 @@ class ActionPlanner: action_options_text=action_options_block, moderation_prompt=moderation_prompt_block, identity_block=identity_block, + plan_style = global_config.personality.plan_style ) return prompt, message_id_list except Exception as e: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index adba061a..87ff7bdb 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple from datetime import datetime from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages from src.config.config import global_config, model_config from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest @@ -156,7 +157,7 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - choosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[Dict[str, Any]]] = None, enable_tool: bool = True, from_plugin: bool = True, stream_id: Optional[str] = None, @@ -171,7 +172,7 @@ class DefaultReplyer: extra_info: 额外信息,用于补充上下文 reply_reason: 回复原因 available_actions: 可用的动作信息字典 - choosen_actions: 已选动作 + chosen_actions: 已选动作 enable_tool: 是否启用工具调用 from_plugin: 是否来自插件 @@ -189,7 +190,7 @@ class DefaultReplyer: prompt, selected_expressions = await self.build_prompt_reply_context( extra_info=extra_info, available_actions=available_actions, - choosen_actions=choosen_actions, + chosen_actions=chosen_actions, enable_tool=enable_tool, reply_message=reply_message, reply_reason=reply_reason, @@ -296,7 +297,7 @@ class DefaultReplyer: if not sender: return "" - + if sender == global_config.bot.nickname: return "" @@ -352,7 +353,7 @@ class DefaultReplyer: return f"{expression_habits_title}\n{expression_habits_block}", selected_ids - async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str: + async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str: """构建记忆块 Args: @@ -369,7 +370,7 @@ class DefaultReplyer: instant_memory = None running_memories = await self.memory_activator.activate_memory_with_chat_history( - target_message=target, chat_history_prompt=chat_history + target_message=target, chat_history=chat_history ) if global_config.memory.enable_instant_memory: @@ -433,7 +434,7 @@ class DefaultReplyer: logger.error(f"工具信息获取失败: {e}") return "" - def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: + def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]: """解析回复目标消息 Args: @@ -514,7 +515,7 @@ class DefaultReplyer: return name, result, duration def build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str ) -> Tuple[str, str]: """ 构建 s4u 风格的分离对话 prompt @@ -530,16 +531,16 @@ class DefaultReplyer: bot_id = str(global_config.bot.qq_account) # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 - for msg_dict in message_list_before_now: + for msg in message_list_before_now: try: - msg_user_id = str(msg_dict.get("user_id")) - reply_to = msg_dict.get("reply_to", "") + msg_user_id = str(msg.user_info.user_id) + reply_to = msg.reply_to _platform, reply_to_user_id = self._parse_reply_target(reply_to) if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id: # bot 和目标用户的对话 - core_dialogue_list.append(msg_dict) + core_dialogue_list.append(msg) except Exception as e: - logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}") + logger.error(f"处理消息记录时出错: {msg}, 错误: {e}") # 构建背景对话 prompt all_dialogue_prompt = "" @@ -574,7 +575,6 @@ class DefaultReplyer: core_dialogue_prompt_str = build_readable_messages( core_dialogue_list, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, @@ -640,7 +640,7 @@ class DefaultReplyer: action_descriptions = "" if available_actions: - action_descriptions = "你可以做以下这些动作:\n" + action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n" for action_name, action_info in available_actions.items(): action_description = action_info.description action_descriptions += f"- {action_name}: {action_description}\n" @@ -658,7 +658,7 @@ class DefaultReplyer: choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" if choosen_action_descriptions: - action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n" + action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n" action_descriptions += choosen_action_descriptions return action_descriptions @@ -668,7 +668,7 @@ class DefaultReplyer: extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - choosen_actions: Optional[List[Dict[str, Any]]] = None, + chosen_actions: Optional[List[Dict[str, Any]]] = None, enable_tool: bool = True, reply_message: Optional[Dict[str, Any]] = None, ) -> Tuple[str, List[int]]: @@ -679,7 +679,7 @@ class DefaultReplyer: extra_info: 额外信息,用于补充上下文 reply_reason: 回复原因 available_actions: 可用动作 - choosen_actions: 已选动作 + chosen_actions: 已选动作 enable_timeout: 是否启用超时处理 enable_tool: 是否启用工具调用 reply_message: 回复的原始消息 @@ -712,27 +712,21 @@ class DefaultReplyer: target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 1, ) - temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long] - # TODO: 修复! message_list_before_short = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short] chat_talking_prompt_short = build_readable_messages( - temp_msg_list_before_short, + message_list_before_short, replace_bot_name=True, - merge_messages=False, timestamp_mode="relative", read_mark=0.0, show_actions=True, @@ -744,12 +738,12 @@ class DefaultReplyer: self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" ), self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"), - self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"), + self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"), self._time_and_run_task( self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" ), self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"), - self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"), + self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"), ) # 任务名称中英文映射 @@ -810,25 +804,9 @@ class DefaultReplyer: else: reply_target_block = "" - # if is_group_chat: - # chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") - # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") - # else: - # chat_target_name = "对方" - # if self.chat_target_info: - # chat_target_name = ( - # self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方" - # ) - # chat_target_1 = await global_prompt_manager.format_prompt( - # "chat_target_private1", sender_name=chat_target_name - # ) - # chat_target_2 = await global_prompt_manager.format_prompt( - # "chat_target_private2", sender_name=chat_target_name - # ) - # 构建分离的对话 prompt core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts( - temp_msg_list_before_long, user_id, sender + message_list_before_now_long, user_id, sender ) if global_config.bot.qq_account == user_id and platform == global_config.bot.platform: @@ -879,7 +857,7 @@ class DefaultReplyer: reason: str, reply_to: str, reply_message: Optional[Dict[str, Any]] = None, - ) -> Tuple[str, List[int]]: # sourcery skip: merge-else-if-into-elif, remove-redundant-if + ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) @@ -902,20 +880,16 @@ class DefaultReplyer: timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half] chat_talking_prompt_half = build_readable_messages( - temp_msg_list_before_now_half, + message_list_before_now_half, replace_bot_name=True, - merge_messages=False, timestamp_mode="relative", read_mark=0.0, show_actions=True, ) # 并行执行2个构建任务 - (expression_habits_block, selected_expressions), relation_info = await asyncio.gather( + (expression_habits_block, _), relation_info = await asyncio.gather( self.build_expression_habits(chat_talking_prompt_half, target), self.build_relation_info(sender, target), ) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 64e81557..51ecb46d 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,4 +1,4 @@ -import time # 导入 time 模块以获取当前时间 +import time import random import re @@ -6,14 +6,17 @@ from typing import List, Dict, Any, Tuple, Optional, Callable from rich.traceback import install from src.config.config import global_config +from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.message_data_model import MessageAndActionModel from src.common.database.database_model import ActionRecords from src.common.database.database_model import Images from src.person_info.person_info import Person, get_person_id from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids install(extra_lines=3) +logger = get_logger("chat_message_builder") def replace_user_references_sync( @@ -349,7 +352,9 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]: +def get_raw_msg_before_timestamp_with_users( + timestamp: float, person_ids: list, limit: int = 0 +) -> List[DatabaseMessages]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ @@ -390,16 +395,16 @@ def num_new_messages_since_with_users( def _build_readable_messages_internal( - messages: List[Dict[str, Any]], + messages: List[MessageAndActionModel], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, pic_id_mapping: Optional[Dict[str, str]] = None, pic_counter: int = 1, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, + message_id_list: Optional[List[DatabaseMessages]] = None, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: + # sourcery skip: use-getitem-for-re-match-groups """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -418,7 +423,7 @@ def _build_readable_messages_internal( if not messages: return "", [], pic_id_mapping or {}, pic_counter - message_details_raw: List[Tuple[float, str, str, bool]] = [] + detailed_messages_raw: List[Tuple[float, str, str, bool]] = [] # 使用传入的映射字典,如果没有则创建新的 if pic_id_mapping is None: @@ -426,25 +431,26 @@ def _build_readable_messages_internal( current_pic_counter = pic_counter # 创建时间戳到消息ID的映射,用于在消息前添加[id]标识符 - timestamp_to_id = {} + timestamp_to_id_mapping: Dict[float, str] = {} if message_id_list: - for item in message_id_list: - message = item.get("message", {}) - timestamp = message.get("time") + for msg in message_id_list: + timestamp = msg.time if timestamp is not None: - timestamp_to_id[timestamp] = item.get("id", "") + timestamp_to_id_mapping[timestamp] = msg.message_id - def process_pic_ids(content: str) -> str: + def process_pic_ids(content: Optional[str]) -> str: """处理内容中的图片ID,将其替换为[图片x]格式""" - nonlocal current_pic_counter + if content is None: + logger.warning("Content is None when processing pic IDs.") + raise ValueError("Content is None") # 匹配 [picid:xxxxx] 格式 pic_pattern = r"\[picid:([^\]]+)\]" - def replace_pic_id(match): + def replace_pic_id(match: re.Match) -> str: nonlocal current_pic_counter + nonlocal pic_counter pic_id = match.group(1) - if pic_id not in pic_id_mapping: pic_id_mapping[pic_id] = f"图片{current_pic_counter}" current_pic_counter += 1 @@ -453,42 +459,23 @@ def _build_readable_messages_internal( return re.sub(pic_pattern, replace_pic_id, content) - # 1 & 2: 获取发送者信息并提取消息组件 - for msg in messages: - # 检查是否是动作记录 - if msg.get("is_action_record", False): - is_action = True - timestamp: float = msg.get("time") # type: ignore - content = msg.get("display_message", "") + # 1: 获取发送者信息并提取消息组件 + for message in messages: + if message.is_action_record: # 对于动作记录,也处理图片ID - content = process_pic_ids(content) - message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action)) + content = process_pic_ids(message.display_message) + detailed_messages_raw.append((message.time, message.user_nickname, content, True)) continue - # 检查并修复缺少的user_info字段 - if "user_info" not in msg: - # 创建user_info字段 - msg["user_info"] = { - "platform": msg.get("user_platform", ""), - "user_id": msg.get("user_id", ""), - "user_nickname": msg.get("user_nickname", ""), - "user_cardname": msg.get("user_cardname", ""), - } + platform = message.user_platform + user_id = message.user_id + user_nickname = message.user_nickname + user_cardname = message.user_cardname - user_info = msg.get("user_info", {}) - platform = user_info.get("platform") - user_id = user_info.get("user_id") - - user_nickname = user_info.get("user_nickname") - user_cardname = user_info.get("user_cardname") - - timestamp: float = msg.get("time") # type: ignore - content: str - if msg.get("display_message"): - content = msg.get("display_message", "") - else: - content = msg.get("processed_plain_text", "") # 默认空字符串 + timestamp = message.time + content = message.display_message or message.processed_plain_text or "" + # 向下兼容 if "ᶠ" in content: content = content.replace("ᶠ", "") if "ⁿ" in content: @@ -504,52 +491,32 @@ def _build_readable_messages_internal( person = Person(platform=platform, user_id=user_id) # 根据 replace_bot_name 参数决定是否替换机器人名称 - person_name: str + person_name = ( + person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人") + ) if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" - else: - person_name = person.person_name or user_id # type: ignore - - # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 - if not person_name: - if user_cardname: - person_name = f"昵称:{user_cardname}" - elif user_nickname: - person_name = f"{user_nickname}" - else: - person_name = "某人" # 使用独立函数处理用户引用格式 - content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name) + if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name): + detailed_messages_raw.append((timestamp, person_name, content, False)) - target_str = "这是QQ的一个功能,用于提及某人,但没那么明显" - if target_str in content and random.random() < 0.6: - content = content.replace(target_str, "") - - if content != "": - message_details_raw.append((timestamp, person_name, content, False)) - - if not message_details_raw: + if not detailed_messages_raw: return "", [], pic_id_mapping, current_pic_counter - message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面 + detailed_messages_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面 + detailed_message: List[Tuple[float, str, str, bool]] = [] - # 为每条消息添加一个标记,指示它是否是动作记录 - message_details_with_flags = [] - for timestamp, name, content, is_action in message_details_raw: - message_details_with_flags.append((timestamp, name, content, is_action)) - - # 应用截断逻辑 (如果 truncate 为 True) - message_details: List[Tuple[float, str, str, bool]] = [] - n_messages = len(message_details_with_flags) - if truncate and n_messages > 0: - for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags): + # 2. 应用消息截断逻辑 + messages_count = len(detailed_messages_raw) + if truncate and messages_count > 0: + for i, (timestamp, name, content, is_action) in enumerate(detailed_messages_raw): # 对于动作记录,不进行截断 if is_action: - message_details.append((timestamp, name, content, is_action)) + detailed_message.append((timestamp, name, content, is_action)) continue - percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1) + percentile = i / messages_count # 计算消息在列表中的位置百分比 (0 <= percentile < 1) original_len = len(content) limit = -1 # 默认不截断 @@ -562,116 +529,42 @@ def _build_readable_messages_internal( elif percentile < 0.7: # 60% 到 80% 之前的消息 (即中间的 20%) limit = 200 replace_content = "......(内容太长了)" - elif percentile < 1.0: # 80% 到 100% 之前的消息 (即较新的 20%) + elif percentile <= 1.0: # 80% 到 100% 之前的消息 (即较新的 20%) limit = 400 - replace_content = "......(太长了)" + replace_content = "......(内容太长了)" truncated_content = content if 0 < limit < original_len: truncated_content = f"{content[:limit]}{replace_content}" - message_details.append((timestamp, name, truncated_content, is_action)) + detailed_message.append((timestamp, name, truncated_content, is_action)) else: # 如果不截断,直接使用原始列表 - message_details = message_details_with_flags + detailed_message = detailed_messages_raw - # 3: 合并连续消息 (如果 merge_messages 为 True) - merged_messages = [] - if merge_messages and message_details: - # 初始化第一个合并块 - current_merge = { - "name": message_details[0][1], - "start_time": message_details[0][0], - "end_time": message_details[0][0], - "content": [message_details[0][2]], - "is_action": message_details[0][3], - } + # 3: 格式化为字符串 + output_lines: List[str] = [] - for i in range(1, len(message_details)): - timestamp, name, content, is_action = message_details[i] + for timestamp, name, content, is_action in detailed_message: + readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode) - # 对于动作记录,不进行合并 - if is_action or current_merge["is_action"]: - # 保存当前的合并块 - merged_messages.append(current_merge) - # 创建新的块 - current_merge = { - "name": name, - "start_time": timestamp, - "end_time": timestamp, - "content": [content], - "is_action": is_action, - } - continue + # 查找消息id(如果有)并构建id_prefix + message_id = timestamp_to_id_mapping.get(timestamp) + id_prefix = f"[{message_id}]" if message_id else "" - # 如果是同一个人发送的连续消息且时间间隔小于等于60秒 - if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60): - current_merge["content"].append(content) - current_merge["end_time"] = timestamp # 更新最后消息时间 - else: - # 保存上一个合并块 - merged_messages.append(current_merge) - # 开始新的合并块 - current_merge = { - "name": name, - "start_time": timestamp, - "end_time": timestamp, - "content": [content], - "is_action": is_action, - } - # 添加最后一个合并块 - merged_messages.append(current_merge) - elif message_details: # 如果不合并消息,则每个消息都是一个独立的块 - for timestamp, name, content, is_action in message_details: - merged_messages.append( - { - "name": name, - "start_time": timestamp, # 起始和结束时间相同 - "end_time": timestamp, - "content": [content], # 内容只有一个元素 - "is_action": is_action, - } - ) - - # 4 & 5: 格式化为字符串 - output_lines = [] - - for _i, merged in enumerate(merged_messages): - # 使用指定的 timestamp_mode 格式化时间 - readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) - - # 查找对应的消息ID - message_id = timestamp_to_id.get(merged["start_time"], "") - id_prefix = f"[{message_id}] " if message_id else "" - - # 检查是否是动作记录 - if merged["is_action"]: + if is_action: # 对于动作记录,使用特殊格式 - output_lines.append(f"{id_prefix}{readable_time}, {merged['content'][0]}") + output_lines.append(f"{id_prefix}{readable_time}, {content}") else: - header = f"{id_prefix}{readable_time}, {merged['name']} :" - output_lines.append(header) - # 将内容合并,并添加缩进 - for line in merged["content"]: - stripped_line = line.strip() - if stripped_line: # 过滤空行 - # 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留 - if stripped_line.endswith("。"): - stripped_line = stripped_line[:-1] - # 如果内容被截断,结尾已经是 ...(内容太长),不再添加分号 - if not stripped_line.endswith("(内容太长)"): - output_lines.append(f"{stripped_line}") - else: - output_lines.append(stripped_line) # 直接添加截断后的内容 + output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}") output_lines.append("\n") # 在每个消息块后添加换行,保持可读性 - # 移除可能的多余换行,然后合并 formatted_string = "".join(output_lines).strip() # 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器 return ( formatted_string, - [(t, n, c) for t, n, c, is_action in message_details if not is_action], + [(t, n, c) for t, n, c, is_action in detailed_message if not is_action], pic_id_mapping, current_pic_counter, ) @@ -755,9 +648,8 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: async def build_readable_messages_with_list( - messages: List[Dict[str, Any]], + messages: List[DatabaseMessages], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, ) -> Tuple[str, List[Tuple[float, str, str]]]: @@ -766,7 +658,7 @@ async def build_readable_messages_with_list( 允许通过参数控制格式化行为。 """ formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( - messages, replace_bot_name, merge_messages, timestamp_mode, truncate + convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate ) if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): @@ -776,15 +668,14 @@ async def build_readable_messages_with_list( def build_readable_messages_with_id( - messages: List[Dict[str, Any]], + messages: List[DatabaseMessages], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", read_mark: float = 0.0, truncate: bool = False, show_actions: bool = False, show_pic: bool = True, -) -> Tuple[str, List[Dict[str, Any]]]: +) -> Tuple[str, List[DatabaseMessages]]: """ 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 @@ -794,7 +685,6 @@ def build_readable_messages_with_id( formatted_string = build_readable_messages( messages=messages, replace_bot_name=replace_bot_name, - merge_messages=merge_messages, timestamp_mode=timestamp_mode, truncate=truncate, show_actions=show_actions, @@ -807,15 +697,14 @@ def build_readable_messages_with_id( def build_readable_messages( - messages: List[Dict[str, Any]], + messages: List[DatabaseMessages], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", read_mark: float = 0.0, truncate: bool = False, show_actions: bool = False, show_pic: bool = True, - message_id_list: Optional[List[Dict[str, Any]]] = None, + message_id_list: Optional[List[DatabaseMessages]] = None, ) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 @@ -831,19 +720,32 @@ def build_readable_messages( truncate: 是否截断长消息 show_actions: 是否显示动作记录 """ + # WIP HERE and BELOW ---------------------------------------------- # 创建messages的深拷贝,避免修改原始列表 if not messages: return "" - copy_messages = [msg.copy() for msg in messages] + copy_messages: List[MessageAndActionModel] = [ + MessageAndActionModel( + msg.time, + msg.user_info.user_id, + msg.user_info.platform, + msg.user_info.user_nickname, + msg.user_info.user_cardname, + msg.processed_plain_text, + msg.display_message, + msg.chat_info.platform, + ) + for msg in messages + ] if show_actions and copy_messages: # 获取所有消息的时间范围 - min_time = min(msg.get("time", 0) for msg in copy_messages) - max_time = max(msg.get("time", 0) for msg in copy_messages) + min_time = min(msg.time or 0 for msg in copy_messages) + max_time = max(msg.time or 0 for msg in copy_messages) # 从第一条消息中获取chat_id - chat_id = copy_messages[0].get("chat_id") if copy_messages else None + chat_id = messages[0].chat_id if messages else None # 获取这个时间范围内的动作记录,并匹配chat_id actions_in_range = ( @@ -863,34 +765,34 @@ def build_readable_messages( ) # 合并两部分动作记录 - actions = list(actions_in_range) + list(action_after_latest) + actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest) # 将动作记录转换为消息格式 for action in actions: # 只有当build_into_prompt为True时才添加动作记录 if action.action_build_into_prompt: - action_msg = { - "time": action.time, - "user_id": global_config.bot.qq_account, # 使用机器人的QQ账号 - "user_nickname": global_config.bot.nickname, # 使用机器人的昵称 - "user_cardname": "", # 机器人没有群名片 - "processed_plain_text": f"{action.action_prompt_display}", - "display_message": f"{action.action_prompt_display}", - "chat_info_platform": action.chat_info_platform, - "is_action_record": True, # 添加标识字段 - "action_name": action.action_name, # 保存动作名称 - } + action_msg = MessageAndActionModel( + time=float(action.time), # type: ignore + user_id=global_config.bot.qq_account, # 使用机器人的QQ账号 + user_platform=global_config.bot.platform, # 使用机器人的平台 + user_nickname=global_config.bot.nickname, # 使用机器人的用户名 + user_cardname="", # 机器人没有群名片 + processed_plain_text=f"{action.action_prompt_display}", + display_message=f"{action.action_prompt_display}", + chat_info_platform=str(action.chat_info_platform), + is_action_record=True, # 添加标识字段 + action_name=str(action.action_name), # 保存动作名称 + ) copy_messages.append(action_msg) # 重新按时间排序 - copy_messages.sort(key=lambda x: x.get("time", 0)) + copy_messages.sort(key=lambda x: x.time or 0) if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( copy_messages, replace_bot_name, - merge_messages, timestamp_mode, truncate, show_pic=show_pic, @@ -905,8 +807,8 @@ def build_readable_messages( return formatted_string else: # 按 read_mark 分割消息 - messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark] - messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark] + messages_before_mark = [msg for msg in copy_messages if (msg.time or 0) <= read_mark] + messages_after_mark = [msg for msg in copy_messages if (msg.time or 0) > read_mark] # 共享的图片映射字典和计数器 pic_id_mapping = {} @@ -916,7 +818,6 @@ def build_readable_messages( formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( messages_before_mark, replace_bot_name, - merge_messages, timestamp_mode, truncate, pic_id_mapping, @@ -927,7 +828,6 @@ def build_readable_messages( formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( messages_after_mark, replace_bot_name, - merge_messages, timestamp_mode, False, pic_id_mapping, @@ -960,13 +860,13 @@ def build_readable_messages( return "".join(result_parts) -async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: +async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str: """ 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 """ if not messages: - print("111111111111没有消息,无法构建匿名消息") + logger.warning("没有消息,无法构建匿名消息") return "" person_map = {} @@ -1017,14 +917,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: for msg in messages: try: - platform: str = msg.get("chat_info_platform") # type: ignore - user_id = msg.get("user_id") - _timestamp = msg.get("time") - content: str = "" - if msg.get("display_message"): - content = msg.get("display_message", "") - else: - content = msg.get("processed_plain_text", "") + platform = msg.chat_info.platform + user_id = msg.user_info.user_id + content = msg.display_message or msg.processed_plain_text or "" if "ᶠ" in content: content = content.replace("ᶠ", "") @@ -1101,3 +996,22 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 + + +def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]: + """ + 将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。 + """ + return [ + MessageAndActionModel( + time=msg.time, + user_id=msg.user_info.user_id, + user_platform=msg.user_info.platform, + user_nickname=msg.user_info.user_nickname, + user_cardname=msg.user_info.user_cardname, + processed_plain_text=msg.processed_plain_text, + display_message=msg.display_message, + chat_info_platform=msg.chat_info.platform, + ) + for msg in message + ] diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index d0976e9c..3528fe4b 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -12,6 +12,7 @@ from typing import Optional, Tuple, Dict, List, Any from src.common.logger import get_logger from src.common.data_models.info_data_model import TargetPersonInfo +from src.common.data_models.database_data_model import DatabaseMessages from src.common.message_repository import find_messages, count_messages from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv @@ -113,6 +114,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: """获取文本的embedding向量""" + # 每次都创建新的LLMRequest实例以避免事件循环冲突 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: embedding, _ = await llm.get_embedding(text) @@ -151,10 +153,13 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li if ( (db_msg.user_info.platform, db_msg.user_info.user_id) != sender and db_msg.user_info.user_id != global_config.bot.qq_account - and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group + and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) + not in who_chat_in_group and len(who_chat_in_group) < 5 ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目 - who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)) + who_chat_in_group.append( + (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) + ) return who_chat_in_group @@ -640,9 +645,9 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: target_info = TargetPersonInfo( platform=platform, user_id=user_id, - user_nickname=user_info.user_nickname, # type: ignore + user_nickname=user_info.user_nickname, # type: ignore person_id=None, - person_name=None + person_name=None, ) # Try to fetch person info @@ -669,17 +674,17 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: return is_group_chat, chat_target_info -def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]: +def assign_message_ids(messages: List[DatabaseMessages]) -> List[DatabaseMessages]: """ 为消息列表中的每个消息分配唯一的简短随机ID - + Args: messages: 消息列表 - + Returns: - 包含 {'id': str, 'message': any} 格式的字典列表 + List[DatabaseMessages]: 分配了唯一ID的消息列表(写入message_id属性) """ - result = [] + result: List[DatabaseMessages] = list(messages) # 复制原始消息列表 used_ids = set() len_i = len(messages) if len_i > 100: @@ -688,95 +693,86 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]: else: a = 1 b = 9 - - for i, message in enumerate(messages): + + for i, _ in enumerate(result): # 生成唯一的简短ID while True: # 使用索引+随机数生成简短ID random_suffix = random.randint(a, b) - message_id = f"m{i+1}{random_suffix}" - + message_id = f"m{i + 1}{random_suffix}" + if message_id not in used_ids: used_ids.add(message_id) break - - result.append({ - 'id': message_id, - 'message': message - }) - + result[i].message_id = message_id + return result -def assign_message_ids_flexible( - messages: list, - prefix: str = "msg", - id_length: int = 6, - use_timestamp: bool = False -) -> list: - """ - 为消息列表中的每个消息分配唯一的简短随机ID(增强版) - - Args: - messages: 消息列表 - prefix: ID前缀,默认为"msg" - id_length: ID的总长度(不包括前缀),默认为6 - use_timestamp: 是否在ID中包含时间戳,默认为False - - Returns: - 包含 {'id': str, 'message': any} 格式的字典列表 - """ - result = [] - used_ids = set() - - for i, message in enumerate(messages): - # 生成唯一的ID - while True: - if use_timestamp: - # 使用时间戳的后几位 + 随机字符 - timestamp_suffix = str(int(time.time() * 1000))[-3:] - remaining_length = id_length - 3 - random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) - message_id = f"{prefix}{timestamp_suffix}{random_chars}" - else: - # 使用索引 + 随机字符 - index_str = str(i + 1) - remaining_length = max(1, id_length - len(index_str)) - random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) - message_id = f"{prefix}{index_str}{random_chars}" - - if message_id not in used_ids: - used_ids.add(message_id) - break - - result.append({ - 'id': message_id, - 'message': message - }) - - return result +# def assign_message_ids_flexible( +# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False +# ) -> list: +# """ +# 为消息列表中的每个消息分配唯一的简短随机ID(增强版) + +# Args: +# messages: 消息列表 +# prefix: ID前缀,默认为"msg" +# id_length: ID的总长度(不包括前缀),默认为6 +# use_timestamp: 是否在ID中包含时间戳,默认为False + +# Returns: +# 包含 {'id': str, 'message': any} 格式的字典列表 +# """ +# result = [] +# used_ids = set() + +# for i, message in enumerate(messages): +# # 生成唯一的ID +# while True: +# if use_timestamp: +# # 使用时间戳的后几位 + 随机字符 +# timestamp_suffix = str(int(time.time() * 1000))[-3:] +# remaining_length = id_length - 3 +# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) +# message_id = f"{prefix}{timestamp_suffix}{random_chars}" +# else: +# # 使用索引 + 随机字符 +# index_str = str(i + 1) +# remaining_length = max(1, id_length - len(index_str)) +# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) +# message_id = f"{prefix}{index_str}{random_chars}" + +# if message_id not in used_ids: +# used_ids.add(message_id) +# break + +# result.append({"id": message_id, "message": message}) + +# return result # 使用示例: # messages = ["Hello", "World", "Test message"] -# +# # # 基础版本 # result1 = assign_message_ids(messages) # # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}] -# +# # # 增强版本 - 自定义前缀和长度 # result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8) # # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}] -# +# # # 增强版本 - 使用时间戳 # result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True) # # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}] + def parse_keywords_string(keywords_input) -> list[str]: # sourcery skip: use-contextlib-suppress """ 统一的关键词解析函数,支持多种格式的关键词字符串解析 - + 支持的格式: 1. 字符串列表格式:'["utils.py", "修改", "代码", "动作"]' 2. 斜杠分隔格式:'utils.py/修改/代码/动作' @@ -784,25 +780,25 @@ def parse_keywords_string(keywords_input) -> list[str]: 4. 空格分隔格式:'utils.py 修改 代码 动作' 5. 已经是列表的情况:["utils.py", "修改", "代码", "动作"] 6. JSON格式字符串:'{"keywords": ["utils.py", "修改", "代码", "动作"]}' - + Args: keywords_input: 关键词输入,可以是字符串或列表 - + Returns: list[str]: 解析后的关键词列表,去除空白项 """ if not keywords_input: return [] - + # 如果已经是列表,直接处理 if isinstance(keywords_input, list): return [str(k).strip() for k in keywords_input if str(k).strip()] - + # 转换为字符串处理 keywords_str = str(keywords_input).strip() if not keywords_str: return [] - + try: # 尝试作为JSON对象解析(支持 {"keywords": [...]} 格式) json_data = json.loads(keywords_str) @@ -815,7 +811,7 @@ def parse_keywords_string(keywords_input) -> list[str]: return [str(k).strip() for k in json_data if str(k).strip()] except (json.JSONDecodeError, ValueError): pass - + try: # 尝试使用 ast.literal_eval 解析(支持Python字面量格式) parsed = ast.literal_eval(keywords_str) @@ -823,15 +819,15 @@ def parse_keywords_string(keywords_input) -> list[str]: return [str(k).strip() for k in parsed if str(k).strip()] except (ValueError, SyntaxError): pass - + # 尝试不同的分隔符 - separators = ['/', ',', ' ', '|', ';'] - + separators = ["/", ",", " ", "|", ";"] + for separator in separators: if separator in keywords_str: keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()] if len(keywords_list) > 1: # 确保分割有效 return keywords_list - + # 如果没有分隔符,返回单个关键词 - return [keywords_str] if keywords_str else [] \ No newline at end of file + return [keywords_str] if keywords_str else [] diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index c73f1a9e..222ff59c 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -1,26 +1,28 @@ -from typing import Dict, Any +import copy +from typing import Any -class AbstractClassFlag: - pass - +class BaseDataModel: + def deepcopy(self): + return copy.deepcopy(self) def temporarily_transform_class_to_dict(obj: Any) -> Any: + # sourcery skip: assign-if-exp, reintroduce-else """ - 将对象或容器中的 AbstractClassFlag 子类(类对象)或 AbstractClassFlag 实例 + 将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例 递归转换为普通 dict,不修改原对象。 - - 对于类对象(isinstance(value, type) 且 issubclass(..., AbstractClassFlag)), + - 对于类对象(isinstance(value, type) 且 issubclass(..., BaseDataModel)), 读取类的 __dict__ 中非 dunder 项并递归转换。 - - 对于实例(isinstance(value, AbstractClassFlag)),读取 vars(instance) 并递归转换。 + - 对于实例(isinstance(value, BaseDataModel)),读取 vars(instance) 并递归转换。 """ def _transform(value: Any) -> Any: - # 值是类对象且为 AbstractClassFlag 的子类 - if isinstance(value, type) and issubclass(value, AbstractClassFlag): + # 值是类对象且为 BaseDataModel 的子类 + if isinstance(value, type) and issubclass(value, BaseDataModel): return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)} - # 值是 AbstractClassFlag 的实例 - if isinstance(value, AbstractClassFlag): + # 值是 BaseDataModel 的实例 + if isinstance(value, BaseDataModel): return {k: _transform(v) for k, v in vars(value).items()} # 常见容器类型,递归处理 diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 77da7f99..59761d09 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,36 +1,41 @@ -from typing import Optional, Dict, Any -from dataclasses import dataclass, field, fields, MISSING +from typing import Optional, Any +from dataclasses import dataclass, field + +from . import BaseDataModel -from . import AbstractClassFlag @dataclass -class DatabaseUserInfo(AbstractClassFlag): +class DatabaseUserInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) user_cardname: Optional[str] = None - - def __post_init__(self): - assert isinstance(self.platform, str), "platform must be a string" - assert isinstance(self.user_id, str), "user_id must be a string" - assert isinstance(self.user_nickname, str), "user_nickname must be a string" - assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None" + + # def __post_init__(self): + # assert isinstance(self.platform, str), "platform must be a string" + # assert isinstance(self.user_id, str), "user_id must be a string" + # assert isinstance(self.user_nickname, str), "user_nickname must be a string" + # assert isinstance(self.user_cardname, str) or self.user_cardname is None, ( + # "user_cardname must be a string or None" + # ) @dataclass -class DatabaseGroupInfo(AbstractClassFlag): +class DatabaseGroupInfo(BaseDataModel): group_id: str = field(default_factory=str) group_name: str = field(default_factory=str) group_platform: Optional[str] = None - - def __post_init__(self): - assert isinstance(self.group_id, str), "group_id must be a string" - assert isinstance(self.group_name, str), "group_name must be a string" - assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None" + + # def __post_init__(self): + # assert isinstance(self.group_id, str), "group_id must be a string" + # assert isinstance(self.group_name, str), "group_name must be a string" + # assert isinstance(self.group_platform, str) or self.group_platform is None, ( + # "group_platform must be a string or None" + # ) @dataclass -class DatabaseChatInfo(AbstractClassFlag): +class DatabaseChatInfo(BaseDataModel): stream_id: str = field(default_factory=str) platform: str = field(default_factory=str) create_time: float = field(default_factory=float) @@ -38,123 +43,117 @@ class DatabaseChatInfo(AbstractClassFlag): user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) group_info: Optional[DatabaseGroupInfo] = None - def __post_init__(self): - assert isinstance(self.stream_id, str), "stream_id must be a string" - assert isinstance(self.platform, str), "platform must be a string" - assert isinstance(self.create_time, float), "create_time must be a float" - assert isinstance(self.last_active_time, float), "last_active_time must be a float" - assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" - assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None" + # def __post_init__(self): + # assert isinstance(self.stream_id, str), "stream_id must be a string" + # assert isinstance(self.platform, str), "platform must be a string" + # assert isinstance(self.create_time, float), "create_time must be a float" + # assert isinstance(self.last_active_time, float), "last_active_time must be a float" + # assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" + # assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, ( + # "group_info must be a DatabaseGroupInfo instance or None" + # ) @dataclass(init=False) -class DatabaseMessages(AbstractClassFlag): - # chat_info: DatabaseChatInfo - # user_info: DatabaseUserInfo - # group_info: Optional[DatabaseGroupInfo] = None +class DatabaseMessages(BaseDataModel): + def __init__( + self, + message_id: str = "", + time: float = 0.0, + chat_id: str = "", + reply_to: Optional[str] = None, + interest_value: Optional[float] = None, + key_words: Optional[str] = None, + key_words_lite: Optional[str] = None, + is_mentioned: Optional[bool] = None, + processed_plain_text: Optional[str] = None, + display_message: Optional[str] = None, + priority_mode: Optional[str] = None, + priority_info: Optional[str] = None, + additional_config: Optional[str] = None, + is_emoji: bool = False, + is_picid: bool = False, + is_command: bool = False, + is_notify: bool = False, + selected_expressions: Optional[str] = None, + user_id: str = "", + user_nickname: str = "", + user_cardname: Optional[str] = None, + user_platform: str = "", + chat_info_group_id: Optional[str] = None, + chat_info_group_name: Optional[str] = None, + chat_info_group_platform: Optional[str] = None, + chat_info_user_id: str = "", + chat_info_user_nickname: str = "", + chat_info_user_cardname: Optional[str] = None, + chat_info_user_platform: str = "", + chat_info_stream_id: str = "", + chat_info_platform: str = "", + chat_info_create_time: float = 0.0, + chat_info_last_active_time: float = 0.0, + **kwargs: Any, + ): + self.message_id = message_id + self.time = time + self.chat_id = chat_id + self.reply_to = reply_to + self.interest_value = interest_value - message_id: str = field(default_factory=str) - time: float = field(default_factory=float) - chat_id: str = field(default_factory=str) - reply_to: Optional[str] = None - interest_value: Optional[float] = None + self.key_words = key_words + self.key_words_lite = key_words_lite + self.is_mentioned = is_mentioned - key_words: Optional[str] = None - key_words_lite: Optional[str] = None - is_mentioned: Optional[bool] = None + self.processed_plain_text = processed_plain_text + self.display_message = display_message - # 从 chat_info 扁平化而来的字段 - # chat_info_stream_id: str = field(default_factory=str) - # chat_info_platform: str = field(default_factory=str) - # chat_info_user_platform: str = field(default_factory=str) - # chat_info_user_id: str = field(default_factory=str) - # chat_info_user_nickname: str = field(default_factory=str) - # chat_info_user_cardname: Optional[str] = None - # chat_info_group_platform: Optional[str] = None - # chat_info_group_id: Optional[str] = None - # chat_info_group_name: Optional[str] = None - # chat_info_create_time: float = field(default_factory=float) - # chat_info_last_active_time: float = field(default_factory=float) + self.priority_mode = priority_mode + self.priority_info = priority_info - # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) - # user_platform: str = field(default_factory=str) - # user_id: str = field(default_factory=str) - # user_nickname: str = field(default_factory=str) - # user_cardname: Optional[str] = None + self.additional_config = additional_config + self.is_emoji = is_emoji + self.is_picid = is_picid + self.is_command = is_command + self.is_notify = is_notify - processed_plain_text: Optional[str] = None # 处理后的纯文本消息 - display_message: Optional[str] = None # 显示的消息 - - priority_mode: Optional[str] = None - priority_info: Optional[str] = None - - additional_config: Optional[str] = None - is_emoji: bool = False - is_picid: bool = False - is_command: bool = False - is_notify: bool = False - - selected_expressions: Optional[str] = None - - # def __post_init__(self): - - # if self.chat_info_group_id and self.chat_info_group_name: - # self.group_info = DatabaseGroupInfo( - # group_id=self.chat_info_group_id, - # group_name=self.chat_info_group_name, - # group_platform=self.chat_info_group_platform, - # ) - - # chat_user_info = DatabaseUserInfo( - # user_id=self.chat_info_user_id, - # user_nickname=self.chat_info_user_nickname, - # user_cardname=self.chat_info_user_cardname, - # platform=self.chat_info_user_platform, - # ) - # self.chat_info = DatabaseChatInfo( - # stream_id=self.chat_info_stream_id, - # platform=self.chat_info_platform, - # create_time=self.chat_info_create_time, - # last_active_time=self.chat_info_last_active_time, - # user_info=chat_user_info, - # group_info=self.group_info, - # ) - def __init__(self, **kwargs: Any): - defined = {f.name: f for f in fields(self.__class__)} - for name, f in defined.items(): - if name in kwargs: - setattr(self, name, kwargs.pop(name)) - elif f.default is not MISSING: - setattr(self, name, f.default) - else: - raise TypeError(f"缺失必需字段: {name}") + self.selected_expressions = selected_expressions + self.group_info: Optional[DatabaseGroupInfo] = None self.user_info = DatabaseUserInfo( - user_id=kwargs.get("user_id"), # type: ignore - user_nickname=kwargs.get("user_nickname"), # type: ignore - user_cardname=kwargs.get("user_cardname"), # type: ignore - platform=kwargs.get("user_platform"), # type: ignore + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + platform=user_platform, ) - if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"): + if chat_info_group_id and chat_info_group_name: self.group_info = DatabaseGroupInfo( - group_id=kwargs.get("chat_info_group_id"), # type: ignore - group_name=kwargs.get("chat_info_group_name"), # type: ignore - group_platform=kwargs.get("chat_info_group_platform"), # type: ignore + group_id=chat_info_group_id, + group_name=chat_info_group_name, + group_platform=chat_info_group_platform, ) - chat_user_info = DatabaseUserInfo( - user_id=kwargs.get("chat_info_user_id"), # type: ignore - user_nickname=kwargs.get("chat_info_user_nickname"), # type: ignore - user_cardname=kwargs.get("chat_info_user_cardname"), # type: ignore - platform=kwargs.get("chat_info_user_platform"), # type: ignore - ) - self.chat_info = DatabaseChatInfo( - stream_id=kwargs.get("chat_info_stream_id"), # type: ignore - platform=kwargs.get("chat_info_platform"), # type: ignore - create_time=kwargs.get("chat_info_create_time"), # type: ignore - last_active_time=kwargs.get("chat_info_last_active_time"), # type: ignore - user_info=chat_user_info, + stream_id=chat_info_stream_id, + platform=chat_info_platform, + create_time=chat_info_create_time, + last_active_time=chat_info_last_active_time, + user_info=DatabaseUserInfo( + user_id=chat_info_user_id, + user_nickname=chat_info_user_nickname, + user_cardname=chat_info_user_cardname, + platform=chat_info_user_platform, + ), group_info=self.group_info, ) - + + if kwargs: + for key, value in kwargs.items(): + setattr(self, key, value) + + # def __post_init__(self): + # assert isinstance(self.message_id, str), "message_id must be a string" + # assert isinstance(self.time, float), "time must be a float" + # assert isinstance(self.chat_id, str), "chat_id must be a string" + # assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None" + # assert isinstance(self.interest_value, float) or self.interest_value is None, ( + # "interest_value must be a float or None" + # ) diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index f9a5d569..ae3678d1 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,8 +1,10 @@ from dataclasses import dataclass, field from typing import Optional +from . import BaseDataModel + @dataclass -class TargetPersonInfo: +class TargetPersonInfo(BaseDataModel): platform: str = field(default_factory=str) user_id: str = field(default_factory=str) user_nickname: str = field(default_factory=str) diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py new file mode 100644 index 00000000..0fa87ba0 --- /dev/null +++ b/src/common/data_models/message_data_model.py @@ -0,0 +1,17 @@ +from typing import Optional +from dataclasses import dataclass, field + +from . import BaseDataModel + +@dataclass +class MessageAndActionModel(BaseDataModel): + time: float = field(default_factory=float) + user_id: str = field(default_factory=str) + user_platform: str = field(default_factory=str) + user_nickname: str = field(default_factory=str) + user_cardname: Optional[str] = None + processed_plain_text: Optional[str] = None + display_message: Optional[str] = None + chat_info_platform: str = field(default_factory=str) + is_action_record: bool = field(default=False) + action_name: Optional[str] = None diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 5e26a76e..6df79149 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -46,6 +46,8 @@ class PersonalityConfig(ConfigBase): reply_style: str = "" """表达风格""" + + plan_style: str = "" compress_personality: bool = True """是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭""" diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 97c34546..807f6484 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -159,14 +159,23 @@ class ClientRegistry: return decorator - def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient: + def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient: """ 获取注册的API客户端实例 Args: api_provider: APIProvider实例 + force_new: 是否强制创建新实例(用于解决事件循环问题) Returns: BaseClient: 注册的API客户端实例 """ + # 如果强制创建新实例,直接创建不使用缓存 + if force_new: + if client_class := self.client_registry.get(api_provider.client_type): + return client_class(api_provider) + else: + raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") + + # 正常的缓存逻辑 if api_provider.name not in self.client_instance_cache: if client_class := self.client_registry.get(api_provider.client_type): self.client_instance_cache[api_provider.name] = client_class(api_provider) diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index db6f085e..93a41a3d 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -44,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("Gemini客户端") +# gemini_thinking参数(默认范围) +# 不同模型的思考预算范围配置 +THINKING_BUDGET_LIMITS = { + "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True}, + "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True}, + "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False}, +} + gemini_safe_settings = [ SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), @@ -328,6 +336,48 @@ class GeminiClient(BaseClient): api_key=api_provider.api_key, ) # 这里和openai不一样,gemini会自己决定自己是否需要retry + # 思维预算特殊值 + THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定 + THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用) + + @staticmethod + def clamp_thinking_budget(tb: int, model_id: str): + """ + 按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本) + """ + limits = None + matched_key = None + + # 优先尝试精确匹配 + if model_id in THINKING_BUDGET_LIMITS: + limits = THINKING_BUDGET_LIMITS[model_id] + matched_key = model_id + else: + # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先 + sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True) + for key in sorted_keys: + # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界) + if model_id == key or model_id.startswith(key + "-"): + limits = THINKING_BUDGET_LIMITS[key] + matched_key = key + break + + # 特殊值处理 + if tb == GeminiClient.THINKING_BUDGET_AUTO: + return GeminiClient.THINKING_BUDGET_AUTO + if tb == GeminiClient.THINKING_BUDGET_DISABLED: + if limits and limits.get("can_disable", False): + return GeminiClient.THINKING_BUDGET_DISABLED + return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO + + # 已知模型裁剪到范围 + if limits: + return max(limits["min"], min(tb, limits["max"])) + + # 未知模型,返回动态模式 + logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。") + return GeminiClient.THINKING_BUDGET_AUTO + async def get_response( self, model_info: ModelInfo, @@ -373,6 +423,19 @@ class GeminiClient(BaseClient): messages = _convert_messages(message_list) # 将tool_options转换为Gemini API所需的格式 tools = _convert_tool_options(tool_options) if tool_options else None + + tb = GeminiClient.THINKING_BUDGET_AUTO + #空处理 + if extra_params and "thinking_budget" in extra_params: + try: + tb = int(extra_params["thinking_budget"]) + except (ValueError, TypeError): + logger.warning( + f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}" + ) + # 裁剪到模型支持的范围 + tb = self.clamp_thinking_budget(tb, model_info.model_identifier) + # 将response_format转换为Gemini API所需的格式 generation_config_dict = { "max_output_tokens": max_tokens, @@ -380,11 +443,7 @@ class GeminiClient(BaseClient): "response_modalities": ["TEXT"], "thinking_config": ThinkingConfig( include_thoughts=True, - thinking_budget=( - extra_params["thinking_budget"] - if extra_params and "thinking_budget" in extra_params - else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复 - ), + thinking_budget=tb, ), "safety_settings": gemini_safe_settings, # 防止空回复问题 } diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index c580899a..bba00f94 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -388,6 +388,7 @@ class OpenaiClient(BaseClient): base_url=api_provider.base_url, api_key=api_provider.api_key, max_retries=0, + timeout=api_provider.timeout, ) async def get_response( @@ -520,6 +521,11 @@ class OpenaiClient(BaseClient): extra_body=extra_params, ) except APIConnectionError as e: + # 添加详细的错误信息以便调试 + logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") + logger.error(f"错误类型: {type(e)}") + if hasattr(e, '__cause__') and e.__cause__: + logger.error(f"底层错误: {str(e.__cause__)}") raise NetworkConnectionError() from e except APIStatusError as e: # 重封装APIError为RespNotOkException diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index e8e4db5f..1125e9fd 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -195,7 +195,7 @@ class LLMRequest: if not content: if raise_when_empty: - logger.warning("生成的响应为空") + logger.warning(f"生成的响应为空, 请求类型: {self.request_type}") raise RuntimeError("生成的响应为空") content = "生成的响应为空,请检查模型配置或输入内容是否正确" @@ -248,7 +248,11 @@ class LLMRequest: ) model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) - client = client_registry.get_client_class_instance(api_provider) + + # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + force_new_client = (self.request_type == "embedding") + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) + logger.debug(f"选择请求模型: {model_info.name}") total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index 6dd681ea..83e6818f 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -163,13 +163,9 @@ class ChatAction: limit=15, limit_mode="last", ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, @@ -230,13 +226,9 @@ class ChatAction: limit=10, limit_mode="last", ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index 51b53f11..da54acd0 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -166,13 +166,10 @@ class ChatMood: limit=10, limit_mode="last", ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] + chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, @@ -248,13 +245,10 @@ class ChatMood: limit=5, limit_mode="last", ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] + chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index d735d7c2..86447e27 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -17,6 +17,10 @@ from src.mais4u.mais4u_chat.screen_manager import screen_manager from src.chat.express.expression_selector import expression_selector from .s4u_mood_manager import mood_manager from src.mais4u.mais4u_chat.internal_manager import internal_manager +from src.common.data_models.database_data_model import DatabaseMessages + +from typing import List + logger = get_logger("prompt") @@ -58,7 +62,7 @@ def init_prompt(): """, "s4u_prompt", # New template for private CHAT chat ) - + Prompt( """ 你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播 @@ -95,14 +99,13 @@ class PromptBuilder: def __init__(self): self.prompt_built = "" self.activate_messages = "" - - async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target): + async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target): style_habits = [] # 使用从处理器传来的选中表达方式 # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm( + selected_expressions, _ = await expression_selector.select_suitable_expressions_llm( chat_stream.stream_id, chat_history, max_num=12, target_message=target ) @@ -122,7 +125,6 @@ class PromptBuilder: if style_habits_str.strip(): expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n" - return expression_habits_block async def build_relation_info(self, chat_stream) -> str: @@ -148,9 +150,7 @@ class PromptBuilder: person_ids.append(person_id) # 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为 - relation_info_list = [ - Person(person_id=person_id).build_relationship() for person_id in person_ids - ] + relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids] if relation_info := "".join(relation_info_list): relation_prompt = await global_prompt_manager.format_prompt( "relation_prompt", relation_info=relation_info @@ -160,7 +160,7 @@ class PromptBuilder: async def build_memory_block(self, text: str) -> str: # 待更新记忆系统 return "" - + related_memory = await hippocampus_manager.get_memory_from_text( text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False ) @@ -176,38 +176,37 @@ class PromptBuilder: message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), + # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if limit=300, ) - talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" - core_dialogue_list = [] - background_dialogue_list = [] + core_dialogue_list: List[DatabaseMessages] = [] + background_dialogue_list: List[DatabaseMessages] = [] bot_id = str(global_config.bot.qq_account) target_user_id = str(message.chat_stream.user_info.user_id) - # TODO: 修复之! for msg in message_list_before_now: try: msg_user_id = str(msg.user_info.user_id) if msg_user_id == bot_id: if msg.reply_to and talk_type == msg.reply_to: - core_dialogue_list.append(msg.__dict__) + core_dialogue_list.append(msg) elif msg.reply_to and talk_type != msg.reply_to: - background_dialogue_list.append(msg.__dict__) + background_dialogue_list.append(msg) # else: - # background_dialogue_list.append(msg_dict) + # background_dialogue_list.append(msg_dict) elif msg_user_id == target_user_id: - core_dialogue_list.append(msg.__dict__) + core_dialogue_list.append(msg) else: - background_dialogue_list.append(msg.__dict__) + background_dialogue_list.append(msg) except Exception as e: logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}") background_dialogue_prompt = "" if background_dialogue_list: - context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:] + context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :] background_dialogue_prompt_str = build_readable_messages( context_msgs, timestamp_mode="normal_no_YMD", @@ -217,10 +216,10 @@ class PromptBuilder: core_msg_str = "" if core_dialogue_list: - core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:] + core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :] first_msg = core_dialogue_list[0] - start_speaking_user_id = first_msg.get("user_id") + start_speaking_user_id = first_msg.user_info.user_id if start_speaking_user_id == bot_id: last_speaking_user_id = bot_id msg_seg_str = "你的发言:\n" @@ -229,13 +228,13 @@ class PromptBuilder: last_speaking_user_id = start_speaking_user_id msg_seg_str = "对方的发言:\n" - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n" all_msg_seg_list = [] for msg in core_dialogue_list[1:]: - speaker = msg.get("user_id") + speaker = msg.user_info.user_id if speaker == last_speaking_user_id: - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n" else: msg_seg_str = f"{msg_seg_str}\n" all_msg_seg_list.append(msg_seg_str) @@ -252,46 +251,40 @@ class PromptBuilder: for msg in all_msg_seg_list: core_msg_str += msg - - all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat( + all_dialogue_history = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), limit=20, ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt] + all_dialogue_prompt_str = build_readable_messages( - tmp_msgs, + all_dialogue_history, timestamp_mode="normal_no_YMD", show_pic=False, ) - - return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str + return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str def build_gift_info(self, message: MessageRecvS4U): if message.is_gift: - return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户" + return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户" else: if message.is_fake_gift: return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)" - + return "" def build_sc_info(self, message: MessageRecvS4U): super_chat_manager = get_super_chat_manager() return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id) - async def build_prompt_normal( self, message: MessageRecvS4U, message_txt: str, ) -> str: - chat_stream = message.chat_stream - + person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id) person_name = person.person_name @@ -302,28 +295,31 @@ class PromptBuilder: sender_name = f"[{message.chat_stream.user_info.user_nickname}]" else: sender_name = f"用户({message.chat_stream.user_info.user_id})" - - + relation_info_block, memory_block, expression_habits_block = await asyncio.gather( - self.build_relation_info(chat_stream), self.build_memory_block(message_txt), self.build_expression_habits(chat_stream, message_txt, sender_name) + self.build_relation_info(chat_stream), + self.build_memory_block(message_txt), + self.build_expression_habits(chat_stream, message_txt, sender_name), + ) + + core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts( + chat_stream, message ) - core_dialogue_prompt, background_dialogue_prompt,all_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message) - gift_info = self.build_gift_info(message) - + sc_info = self.build_sc_info(message) - + screen_info = screen_manager.get_screen_str() - + internal_state = internal_manager.get_internal_state_str() time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - + mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id) template_name = "s4u_prompt" - + if not message.is_internal: prompt = await global_prompt_manager.format_prompt( template_name, @@ -356,7 +352,7 @@ class PromptBuilder: mind=message.processed_plain_text, mood_state=mood.mood_state, ) - + # print(prompt) return prompt diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 4d501beb..b64188b4 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -99,13 +99,10 @@ class ChatMood: limit=int(global_config.chat.max_context_size / 3), limit_mode="last", ) - # TODO: 修复! - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] + chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, @@ -151,13 +148,10 @@ class ChatMood: limit=15, limit_mode="last", ) - # TODO: 修复 - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now] + chat_talking_prompt = build_readable_messages( - tmp_msgs, + message_list_before_now, replace_bot_name=True, - merge_messages=False, timestamp_mode="normal_no_YMD", read_mark=0.0, truncate=True, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 61683796..0fe759bd 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -5,7 +5,7 @@ import time import random from json_repair import repair_json -from typing import Union +from typing import Union, Optional from src.common.logger import get_logger from src.common.database.database import db @@ -253,8 +253,8 @@ class Person: # 初始化默认值 self.nickname = "" - self.person_name = None - self.name_reason = None + self.person_name: Optional[str] = None + self.name_reason: Optional[str] = None self.know_times = 0 self.know_since = None self.last_know = None diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 67958399..916162a8 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,18 +1,21 @@ +import json +import traceback + +from json_repair import repair_json +from datetime import datetime +from typing import List + from src.common.logger import get_logger -from .person_info import Person -import random +from src.common.data_models.database_data_model import DatabaseMessages from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.chat.utils.chat_message_builder import build_readable_messages -import json -from json_repair import repair_json -from datetime import datetime -from typing import List, Dict, Any from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -import traceback +from .person_info import Person logger = get_logger("relation") + def init_prompt(): Prompt( """ @@ -45,8 +48,7 @@ def init_prompt(): """, "attitude_to_me_prompt", ) - - + Prompt( """ 你的名字是{bot_name},{bot_name}的别名是{alias_str}。 @@ -80,104 +82,102 @@ def init_prompt(): "neuroticism_prompt", ) + class RelationshipManager: def __init__(self): self.relationship_llm = LLMRequest( model_set=model_config.model_task_config.utils, request_type="relationship.person" - ) - + ) + async def get_attitude_to_me(self, readable_messages, timestamp, person: Person): alias_str = ", ".join(global_config.bot.alias_names) current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") # 解析当前态度值 current_attitude_score = person.attitude_to_me total_confidence = person.attitude_to_me_confidence - + prompt = await global_prompt_manager.format_prompt( "attitude_to_me_prompt", - bot_name = global_config.bot.nickname, - alias_str = alias_str, - person_name = person.person_name, - nickname = person.nickname, - readable_messages = readable_messages, - current_time = current_time, + bot_name=global_config.bot.nickname, + alias_str=alias_str, + person_name=person.person_name, + nickname=person.nickname, + readable_messages=readable_messages, + current_time=current_time, ) - + attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt) - - attitude = repair_json(attitude) attitude_data = json.loads(attitude) - + if not attitude_data or (isinstance(attitude_data, list) and len(attitude_data) == 0): return "" - + # 确保 attitude_data 是字典格式 if not isinstance(attitude_data, dict): logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(attitude_data)}, 内容: {attitude_data}") return "" - + attitude_score = attitude_data["attitude"] - confidence = pow(attitude_data["confidence"],2) - + confidence = pow(attitude_data["confidence"], 2) + new_confidence = total_confidence + confidence - new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence - + new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence + person.attitude_to_me = new_attitude_score person.attitude_to_me_confidence = new_confidence - + return person - + async def get_neuroticism(self, readable_messages, timestamp, person: Person): alias_str = ", ".join(global_config.bot.alias_names) current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") # 解析当前态度值 current_neuroticism_score = person.neuroticism total_confidence = person.neuroticism_confidence - + prompt = await global_prompt_manager.format_prompt( "neuroticism_prompt", - bot_name = global_config.bot.nickname, - alias_str = alias_str, - person_name = person.person_name, - nickname = person.nickname, - readable_messages = readable_messages, - current_time = current_time, + bot_name=global_config.bot.nickname, + alias_str=alias_str, + person_name=person.person_name, + nickname=person.nickname, + readable_messages=readable_messages, + current_time=current_time, ) - - neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt) + neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt) # logger.info(f"prompt: {prompt}") # logger.info(f"neuroticism: {neuroticism}") - neuroticism = repair_json(neuroticism) neuroticism_data = json.loads(neuroticism) - + if not neuroticism_data or (isinstance(neuroticism_data, list) and len(neuroticism_data) == 0): return "" - + # 确保 neuroticism_data 是字典格式 if not isinstance(neuroticism_data, dict): logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(neuroticism_data)}, 内容: {neuroticism_data}") return "" - + neuroticism_score = neuroticism_data["neuroticism"] - confidence = pow(neuroticism_data["confidence"],2) - + confidence = pow(neuroticism_data["confidence"], 2) + new_confidence = total_confidence + confidence - - new_neuroticism_score = (current_neuroticism_score * total_confidence + neuroticism_score * confidence)/new_confidence - + + new_neuroticism_score = ( + current_neuroticism_score * total_confidence + neuroticism_score * confidence + ) / new_confidence + person.neuroticism = new_neuroticism_score person.neuroticism_confidence = new_confidence - - return person - - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]): + return person + + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]): """更新用户印象 Args: @@ -202,12 +202,11 @@ class RelationshipManager: # 遍历消息,构建映射 for msg in user_messages: - if msg.get("user_id") == "system": + if msg.user_info.user_id == "system": continue try: - - user_id = msg.get("user_id") - platform = msg.get("chat_info_platform") + user_id = msg.user_info.user_id + platform = msg.chat_info.platform assert isinstance(user_id, str) and isinstance(platform, str) msg_person = Person(user_id=user_id, platform=platform) @@ -242,19 +241,16 @@ class RelationshipManager: # 确保 original_name 和 mapped_name 都不为 None if original_name is not None and mapped_name is not None: readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}") - + # await self.get_points( - # readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person) + # readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person) await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person) await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person) person.know_times = know_times + 1 person.last_know = timestamp - - person.sync_to_database() - - + person.sync_to_database() def calculate_time_weight(self, point_time: str, current_time: str) -> float: """计算基于时间的权重系数""" @@ -280,6 +276,7 @@ class RelationshipManager: logger.error(f"计算时间权重失败: {e}") return 0.5 # 发生错误时返回中等权重 + init_prompt() relationship_manager = None @@ -290,4 +287,3 @@ def get_relationship_manager(): if relationship_manager is None: relationship_manager = RelationshipManager() return relationship_manager - diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index b693350b..3ffbc715 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -127,7 +127,7 @@ async def generate_reply( success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context( extra_info=extra_info, available_actions=available_actions, - choosen_actions=choosen_actions, + chosen_actions=choosen_actions, enable_tool=enable_tool, reply_message=reply_message, reply_reason=reply_reason, diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 2645474f..7a83f07f 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -294,7 +294,9 @@ def get_messages_before_time_in_chat( return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) -def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[DatabaseMessages]: +def get_messages_before_time_for_users( + timestamp: float, person_ids: List[str], limit: int = 0 +) -> List[DatabaseMessages]: """ 获取指定用户在指定时间戳之前的消息 @@ -410,9 +412,8 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa def build_readable_messages_to_str( - messages: List[Dict[str, Any]], + messages: List[DatabaseMessages], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", read_mark: float = 0.0, truncate: bool = False, @@ -434,14 +435,13 @@ def build_readable_messages_to_str( 格式化后的可读字符串 """ return build_readable_messages( - messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions + messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions ) async def build_readable_messages_with_details( - messages: List[Dict[str, Any]], + messages: List[DatabaseMessages], replace_bot_name: bool = True, - merge_messages: bool = False, timestamp_mode: str = "relative", truncate: bool = False, ) -> Tuple[str, List[Tuple[float, str, str]]]: @@ -458,7 +458,7 @@ async def build_readable_messages_with_details( Returns: 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容) """ - return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate) + return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate) async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: diff --git a/src/plugins/built_in/emoji_plugin/emoji.py b/src/plugins/built_in/emoji_plugin/emoji.py index bfb60bde..2a439d27 100644 --- a/src/plugins/built_in/emoji_plugin/emoji.py +++ b/src/plugins/built_in/emoji_plugin/emoji.py @@ -2,13 +2,14 @@ import random from typing import Tuple # 导入新插件系统 -from src.plugin_system import BaseAction, ActionActivationType, ChatMode +from src.plugin_system import BaseAction, ActionActivationType # 导入依赖的系统组件 from src.common.logger import get_logger # 导入API模块 - 标准Python包方式 from src.plugin_system.apis import emoji_api, llm_api, message_api + # NoReplyAction已集成到heartFC_chat.py中,不再需要导入 from src.config.config import global_config @@ -84,11 +85,8 @@ class EmojiAction(BaseAction): messages_text = "" if recent_messages: # 使用message_api构建可读的消息字符串 - # TODO: 修复 - from src.common.data_models import temporarily_transform_class_to_dict - tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages] messages_text = message_api.build_readable_messages( - messages=tmp_msgs, + messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, show_actions=False, diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 826af325..eba7012d 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.4.6" +version = "6.5.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -29,6 +29,8 @@ identity = "年龄为19岁,是女孩子,身高为160cm,有黑色的短发" # 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容 reply_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要浮夸,不要夸张修辞。" +plan_style = "当你刚刚发送了消息,没有人回复时,不要选择action,如果有别的动作(非回复)满足条件,可以选择,当你一次发送了太多消息,为了避免打扰聊天节奏,不要选择动作" + compress_personality = false # 是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭 compress_identity = true # 是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭 diff --git a/template/model_config_template.toml b/template/model_config_template.toml index 92ac8881..0d756314 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -5,7 +5,7 @@ version = "1.3.0" [[api_providers]] # API服务提供商(可以配置多个) name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) -base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL +base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥) client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) diff --git a/test_del_memory.py b/test_del_memory.py deleted file mode 100644 index 523ad156..00000000 --- a/test_del_memory.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -测试del_memory函数的脚本 -""" - -import sys -import os - -# 添加src目录到Python路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from person_info.person_info import Person - -def test_del_memory(): - """测试del_memory函数""" - print("开始测试del_memory函数...") - - # 创建一个测试用的Person实例(不连接数据库) - person = Person.__new__(Person) - person.person_id = "test_person" - person.memory_points = [ - "性格:这个人很友善:5.0", - "性格:这个人很友善:4.0", - "爱好:喜欢打游戏:3.0", - "爱好:喜欢打游戏:2.0", - "工作:是一名程序员:1.0", - "性格:这个人很友善:6.0" - ] - - print(f"原始记忆点数量: {len(person.memory_points)}") - print("原始记忆点:") - for i, memory in enumerate(person.memory_points): - print(f" {i+1}. {memory}") - - # 测试删除"性格"分类中"这个人很友善"的记忆 - print("\n测试1: 删除'性格'分类中'这个人很友善'的记忆") - deleted_count = person.del_memory("性格", "这个人很友善") - print(f"删除了 {deleted_count} 个记忆点") - print("删除后的记忆点:") - for i, memory in enumerate(person.memory_points): - print(f" {i+1}. {memory}") - - # 测试删除"爱好"分类中"喜欢打游戏"的记忆 - print("\n测试2: 删除'爱好'分类中'喜欢打游戏'的记忆") - deleted_count = person.del_memory("爱好", "喜欢打游戏") - print(f"删除了 {deleted_count} 个记忆点") - print("删除后的记忆点:") - for i, memory in enumerate(person.memory_points): - print(f" {i+1}. {memory}") - - # 测试相似度匹配 - print("\n测试3: 测试相似度匹配") - person.memory_points = [ - "性格:这个人非常友善:5.0", - "性格:这个人很友善:4.0", - "性格:这个人友善:3.0" - ] - print("原始记忆点:") - for i, memory in enumerate(person.memory_points): - print(f" {i+1}. {memory}") - - # 删除"这个人很友善"(应该匹配"这个人很友善"和"这个人友善") - deleted_count = person.del_memory("性格", "这个人很友善", similarity_threshold=0.8) - print(f"删除了 {deleted_count} 个记忆点") - print("删除后的记忆点:") - for i, memory in enumerate(person.memory_points): - print(f" {i+1}. {memory}") - - print("\n测试完成!") - -if __name__ == "__main__": - test_del_memory() diff --git a/test_fix_memory_points.py b/test_fix_memory_points.py deleted file mode 100644 index bf351463..00000000 --- a/test_fix_memory_points.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -测试修复后的memory_points处理 -""" - -import sys -import os - -# 添加src目录到Python路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from person_info.person_info import Person - -def test_memory_points_with_none(): - """测试包含None值的memory_points处理""" - print("测试包含None值的memory_points处理...") - - # 创建一个测试Person实例 - person = Person(person_id="test_user_123") - - # 模拟包含None值的memory_points - person.memory_points = [ - "喜好:喜欢咖啡:1.0", - None, # 模拟None值 - "性格:开朗:1.0", - None, # 模拟另一个None值 - "兴趣:编程:1.0" - ] - - print(f"原始memory_points: {person.memory_points}") - - # 测试get_all_category方法 - try: - categories = person.get_all_category() - print(f"获取到的分类: {categories}") - print("✓ get_all_category方法正常工作") - except Exception as e: - print(f"✗ get_all_category方法出错: {e}") - return False - - # 测试get_memory_list_by_category方法 - try: - memories = person.get_memory_list_by_category("喜好") - print(f"获取到的喜好记忆: {memories}") - print("✓ get_memory_list_by_category方法正常工作") - except Exception as e: - print(f"✗ get_memory_list_by_category方法出错: {e}") - return False - - # 测试del_memory方法 - try: - deleted_count = person.del_memory("喜好", "喜欢咖啡") - print(f"删除的记忆点数量: {deleted_count}") - print(f"删除后的memory_points: {person.memory_points}") - print("✓ del_memory方法正常工作") - except Exception as e: - print(f"✗ del_memory方法出错: {e}") - return False - - return True - -def test_memory_points_empty(): - """测试空的memory_points处理""" - print("\n测试空的memory_points处理...") - - person = Person(person_id="test_user_456") - person.memory_points = [] - - try: - categories = person.get_all_category() - print(f"空列表的分类: {categories}") - print("✓ 空列表处理正常") - except Exception as e: - print(f"✗ 空列表处理出错: {e}") - return False - - try: - memories = person.get_memory_list_by_category("测试分类") - print(f"空列表的记忆: {memories}") - print("✓ 空列表分类查询正常") - except Exception as e: - print(f"✗ 空列表分类查询出错: {e}") - return False - - return True - -def test_memory_points_all_none(): - """测试全部为None的memory_points处理""" - print("\n测试全部为None的memory_points处理...") - - person = Person(person_id="test_user_789") - person.memory_points = [None, None, None] - - try: - categories = person.get_all_category() - print(f"全None列表的分类: {categories}") - print("✓ 全None列表处理正常") - except Exception as e: - print(f"✗ 全None列表处理出错: {e}") - return False - - try: - memories = person.get_memory_list_by_category("测试分类") - print(f"全None列表的记忆: {memories}") - print("✓ 全None列表分类查询正常") - except Exception as e: - print(f"✗ 全None列表分类查询出错: {e}") - return False - - return True - -if __name__ == "__main__": - print("开始测试修复后的memory_points处理...") - - success = True - success &= test_memory_points_with_none() - success &= test_memory_points_empty() - success &= test_memory_points_all_none() - - if success: - print("\n🎉 所有测试通过!memory_points的None值处理已修复。") - else: - print("\n❌ 部分测试失败,需要进一步检查。")