mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/Windpicker-owo/MaiBot into dev
commit d93091ff1d
@@ -25,8 +25,8 @@
**🍔 MaiCore is an interactive agent built on large language models**

- 💭 **Intelligent dialogue system**: LLM-based natural-language interaction, with unified handling of normal and focus modes.
- 🔌 **Powerful plugin system**: fully refactored plugin architecture with a complete management API and permission control.
- 💭 **Intelligent dialogue system**: LLM-based natural-language interaction, with chat-timing control.
- 🔌 **Powerful plugin system**: fully refactored plugin architecture with more APIs.
- 🤔 **Real-time thinking system**: simulates a human thought process.
- 🧠 **Expression learning**: learns group members' speaking styles and expressions.
- 💝 **Emotional expression system**: mood system and sticker (emoji-pack) system.

@@ -46,7 +46,7 @@
## 🔥 Updates and Installation

**Latest version: v0.9.1** ([changelog](changelogs/changelog.md))
**Latest version: v0.10.0** ([changelog](changelogs/changelog.md))

The latest release can be downloaded from the [Release](https://github.com/MaiM-with-u/MaiBot/releases/) page.
The latest launcher can be downloaded from the [launcher release page](https://github.com/MaiM-with-u/mailauncher/releases/).

@@ -56,7 +56,6 @@
- `classical`: legacy version (no longer maintained)

### Deployment guide for the latest version
- [Upgrade notes for 0.6/0.7](https://docs.mai-mai.org/faq/maibot/update_to_07.html)
- [🚀 Deployment guide for the latest version](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - the new MaiCore-based deployment (incompatible with older versions)

> [!WARNING]

@@ -64,7 +63,6 @@
> - The project is under active development; features and APIs may change at any time.
> - The documentation is incomplete; please open an Issue or a Discussion if you have questions.
> - QQ bots carry a risk of being restricted; understand the risk and use with caution.
> - Due to rapid iteration, known and unknown bugs may exist.
> - As the program is still in development, it may consume a considerable number of tokens.

## 💬 Discussion
@@ -1,16 +1,74 @@
# Changelog

## [0.10.0] - 2025-07-01
### Major changes
### 🌟 Major changes
- Improved reply generation: replies now track context much more precisely.
- New reply-logic control: normal and focus modes are now merged and unified.
- Improved expression system: learning and usage are now more precise.
- New relationship system: relationship building is more precise and more restrained.
- Tool system refactored and merged into the plugin system.
- The entire LLM Request layer was rewritten and now supports model rotation and more flexible parameters.
- The model configuration system was also rewritten; upgrading requires reconfiguring the LLM config file.
- With the LLM Request rewrite, the plugin system refactor is complete; the plugin system is now stable and will only gain new APIs.
- For the detailed changes, see [changes.md](./changes.md).
- **Warning to all plugin developers: the plugin system is about to enter an unstable period and may change at any time.**

#### 🔧 Tool system refactor
- **Tool system integration**: the tool system is now fully merged into the plugin system, providing a unified extension mechanism.
- **Tool enable control**: whether a specific tool is enabled can now be configured, with a friendlier direct-call interface.
- **Config file access**: tools can now read configuration files, increasing configuration flexibility.

#### 🚀 Full LLM system rewrite
- **LLM Request rewrite**: the entire LLM Request system was rewritten and now supports model rotation and more flexible parameters.
- **Model configuration upgrade**: the model configuration system was rewritten; upgrading requires reconfiguring the LLM config file.
- **Task type support**: task-type and capability fields were added to the model configuration, improving model initialization logic.
- **Stronger exception handling**: the LLMRequest class gained better exception handling, including a unified model-exception handler.

#### 🔌 Plugin system stabilization
- **Plugin refactor complete**: with the LLM Request rewrite, the plugin system refactor is finished and the system is now stable.
- **API extension**: only new APIs will be added, preserving backward compatibility.
- **Plugin management**: plugin-management configuration is now genuinely useful, improving the management experience.

#### 💾 Memory system optimization
- **Timely construction**: the memory system now builds memories promptly and no longer builds them twice.
- **Precise extraction**: memory extraction is more accurate, improving memory quality.

#### 🎭 Expression system
- **Expression tracking**: the expressions actually used are now recorded, for better learning traceability.
- **Learning fixes**: expression extraction was improved and a bug in expression learning was fixed.
- **Configuration**: expression configuration and logic were refined for better stability.

#### 🔄 Chat system unification
- **normal/focus merge**: normal and focus are fully merged; the target message is decided entirely by the planner.
- **Built-in no_reply**: the no_reply behavior was moved into the main loop, simplifying the architecture.
- **Reply improvements**: replies were improved, missing values are filled in, and MaiMai (麦麦) can now reply to its own messages.
- **Frequency-control API**: new chat-frequency-control APIs allow finer-grained control.

#### Logging improvements
- **Log colors**: log colors were adjusted to be easier on the eyes.
- **Log cleanup**: fixed log cleanup waiting 24 hours before its first run, improving performance.
- **Timing instrumentation**: timers now pinpoint abnormal LLM latency, speeding up troubleshooting.

### 🐛 Bug fixes

#### Code quality
- **Lint fixes**: fixed the flood of lint errors; the code is cleaner now.
- **Import cleanup**: fixed import sprawl and documentation errors, improving code structure.

#### Stability
- **Circular imports**: fixed circular imports at import time.
- **Parallel actions**: fixed crashes in parallel action execution, improving concurrency.
- **Empty responses**: an empty response now raises immediately, avoiding downstream errors.

#### Functional fixes
- **API issues**: fixed API problems, improving availability.
- **notice issues**: added a new parameter to component methods as a temporary fix for the notice problem.
- **Relationship building**: fixed relationship building for unknown users.
- **Streaming parser**: fixed an out-of-bounds error in streaming parsing on SSE frames with empty choices.

#### Configuration and compatibility
- **Defaults**: added default values for more flexible configuration.
- **Type issues**: fixed type problems, improving robustness.
- **Config loading**: improved configuration-loading logic for a more stable startup.

### Minor improvements
- Fixed the flood of lint errors; the code is cleaner now.
- Adjusted log colors to be easier on the eyes.

## [0.9.1] - 2025-07-26
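One of the fixes above guards streaming parsing against SSE frames whose `choices` list is empty. A minimal sketch of such a guard, assuming a generic OpenAI-style delta frame (the frame shape and field names here are illustrative, not MaiBot's actual parser):

```python
import json

def extract_delta_text(sse_data: str) -> str:
    """Return the text delta from one SSE data payload, tolerating empty frames."""
    if sse_data.strip() == "[DONE]":
        return ""
    frame = json.loads(sse_data)
    choices = frame.get("choices") or []
    if not choices:  # e.g. usage-only frames carry an empty choices list
        return ""
    return choices[0].get("delta", {}).get("content") or ""
```

Indexing `choices[0]` only after the emptiness check is the whole fix: the out-of-bounds error came from assuming every frame carries at least one choice.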
@@ -84,6 +84,7 @@ services:
      # - ./data/MaiMBot:/data/MaiMBot
      # networks:
      #   - maim_bot

volumes:
  site-packages:
networks:
@@ -47,3 +47,4 @@ reportportal-client
scikit-learn
seaborn
structlog
google.genai
@@ -6,6 +6,7 @@
import sys
import os
import asyncio
from time import sleep

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

@@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
    return True


def main():  # sourcery skip: dict-comprehension
async def main_async():  # sourcery skip: dict-comprehension
    # 新增确认提示
    print("=== 重要操作确认 ===")
    print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")

@@ -239,6 +240,29 @@ def main():  # sourcery skip: dict-comprehension
        return None


def main():
    """主函数 - 设置新的事件循环并运行异步主函数"""
    # 检查是否有现有的事件循环
    try:
        loop = asyncio.get_running_loop()
        if loop.is_closed():
            # 如果事件循环已关闭,创建新的
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        # 没有运行的事件循环,创建新的
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        # 在新的事件循环中运行异步主函数
        loop.run_until_complete(main_async())
    finally:
        # 确保事件循环被正确关闭
        if not loop.is_closed():
            loop.close()


if __name__ == "__main__":
    # logger.info(f"111111111111111111111111{ROOT_PATH}")
    main()
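For a CLI script where no event loop is running yet (the normal case for this importer), the loop bookkeeping in the `main()` above is equivalent to a one-liner; a minimal sketch, with a stub coroutine standing in for the script's real `main_async`:

```python
import asyncio

async def main_async() -> int:
    # stand-in for the real async entry point
    await asyncio.sleep(0)
    return 0

def main() -> int:
    # asyncio.run creates a fresh loop, runs the coroutine, and closes the loop
    return asyncio.run(main_async())
```

The explicit `new_event_loop`/`run_until_complete`/`close` dance is only needed when the caller must reuse or inspect the loop object; `asyncio.run` performs the same setup and teardown internally.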
@@ -1,6 +1,7 @@
import asyncio
import time
import traceback
import math
import random
from typing import List, Optional, Dict, Any, Tuple
from rich.traceback import install

@@ -8,6 +9,7 @@ from collections import deque

from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer

@@ -15,21 +17,19 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.chat_loop.hfc_utils import CycleDetail
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
import math
from src.mais4u.s4u_config import s4u_config
# no_action逻辑已集成到heartFC_chat.py中,不再需要导入
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
# 导入记忆系统
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control


ERROR_LOOP_INFO = {
    "loop_plan_info": {
@@ -62,10 +62,7 @@ class HeartFChatting:
    其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
    """

    def __init__(
        self,
        chat_id: str,
    ):
    def __init__(self, chat_id: str):
        """
        HeartFChatting 初始化函数
@@ -83,7 +80,7 @@ class HeartFChatting:

        self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
        self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)

        self.talk_frequency_control = talk_frequency_control.get_talk_frequency_control(self.stream_id)
        self.focus_value_control = focus_value_control.get_focus_value_control(self.stream_id)
@@ -104,7 +101,7 @@ class HeartFChatting:
        self.plan_timeout_count = 0

        self.last_read_time = time.time() - 1

        self.focus_energy = 1
        self.no_action_consecutive = 0
        # 最近三次no_action的新消息兴趣度记录
@@ -166,27 +163,26 @@ class HeartFChatting:

        # 获取动作类型,兼容新旧格式
        action_type = "未知动作"
        if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
        if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
            loop_plan_info = self._current_cycle_detail.loop_plan_info
            if isinstance(loop_plan_info, dict):
                action_result = loop_plan_info.get('action_result', {})
                action_result = loop_plan_info.get("action_result", {})
                if isinstance(action_result, dict):
                    # 旧格式:action_result是字典
                    action_type = action_result.get('action_type', '未知动作')
                    action_type = action_result.get("action_type", "未知动作")
                elif isinstance(action_result, list) and action_result:
                    # 新格式:action_result是actions列表
                    action_type = action_result[0].get('action_type', '未知动作')
                    action_type = action_result[0].get("action_type", "未知动作")
            elif isinstance(loop_plan_info, list) and loop_plan_info:
                # 直接是actions列表的情况
                action_type = loop_plan_info[0].get('action_type', '未知动作')
                action_type = loop_plan_info[0].get("action_type", "未知动作")

        logger.info(
            f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
            f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "  # type: ignore
            f"选择动作: {action_type}"
            + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
            f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
        )

    def _determine_form_type(self) -> None:
        """判断使用哪种形式的no_action"""
        # 如果连续no_action次数少于3次,使用waiting形式
@@ -195,42 +191,44 @@ class HeartFChatting:
        else:
            # 计算最近三次记录的兴趣度总和
            total_recent_interest = sum(self.recent_interest_records)

            # 计算调整后的阈值
            adjusted_threshold = 1 / self.talk_frequency_control.get_current_talk_frequency()

            logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
            logger.info(
                f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
            )

            # 如果兴趣度总和小于阈值,进入breaking形式
            if total_recent_interest < adjusted_threshold:
                logger.info(f"{self.log_prefix} 兴趣度不足,进入休息")
                self.focus_energy = random.randint(3, 6)
            else:
                logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
                self.focus_energy = 1

    async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
                self.focus_energy = 1

    async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
        """
        判断是否应该处理消息

        Args:
            new_message: 新消息列表
            mode: 当前聊天模式

        Returns:
            bool: 是否应该处理消息
        """
        new_message_count = len(new_message)
        talk_frequency = self.talk_frequency_control.get_current_talk_frequency()

        modified_exit_count_threshold = self.focus_energy * 0.5 / talk_frequency
        modified_exit_interest_threshold = 1.5 / talk_frequency
        total_interest = 0.0
        for msg_dict in new_message:
            interest_value = msg_dict.get("interest_value")
            if interest_value is not None and msg_dict.get("processed_plain_text", ""):
        for msg in new_message:
            interest_value = msg.interest_value
            if interest_value is not None and msg.processed_plain_text:
                total_interest += float(interest_value)

        if new_message_count >= modified_exit_count_threshold:
            self.recent_interest_records.append(total_interest)
            logger.info(
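The two exit thresholds in `_should_process_messages` scale inversely with the current talk frequency, so a chattier stream exits its waiting state sooner. The arithmetic in isolation, using the formulas from the diff:

```python
def exit_thresholds(focus_energy: float, talk_frequency: float) -> tuple[float, float]:
    """Message-count and interest thresholds, as computed in the diff."""
    count_threshold = focus_energy * 0.5 / talk_frequency
    interest_threshold = 1.5 / talk_frequency
    return count_threshold, interest_threshold

# doubling talk_frequency halves both thresholds
assert exit_thresholds(1, 1.0) == (0.5, 1.5)
assert exit_thresholds(1, 2.0) == (0.25, 0.75)
```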
@@ -244,9 +242,11 @@ class HeartFChatting:
        if new_message_count > 0:
            # 只在兴趣值变化时输出log
            if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
                logger.info(f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}")
                logger.info(
                    f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}"
                )
                self._last_accumulated_interest = total_interest

        if total_interest >= modified_exit_interest_threshold:
            # 记录兴趣度到列表
            self.recent_interest_records.append(total_interest)
@@ -261,29 +261,25 @@ class HeartFChatting:
                f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累计兴趣{total_interest:.1f},继续等待..."
            )
        await asyncio.sleep(0.5)

        return False,0.0

        return False, 0.0

    async def _loopbody(self):
        recent_messages_dict = message_api.get_messages_by_time_in_chat(
            chat_id=self.stream_id,
            start_time=self.last_read_time,
            end_time=time.time(),
            limit = 10,
            limit=10,
            limit_mode="latest",
            filter_mai=True,
            filter_command=True,
        )
        # TODO: 修复!
        from src.common.data_models import temporarily_transform_class_to_dict
        temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict]
        )
        # 统一的消息处理逻辑
        should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
        should_process, interest_value = await self._should_process_messages(recent_messages_dict)

        if should_process:
            self.last_read_time = time.time()
            await self._observe(interest_value = interest_value)
            await self._observe(interest_value=interest_value)

        else:
            # Normal模式:消息数量不足,等待
@@ -298,22 +294,21 @@ class HeartFChatting:
        cycle_timers: Dict[str, float],
        thinking_id,
        actions,
        selected_expressions:List[int] = None,
        selected_expressions: List[int] = None,
    ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:

        with Timer("回复发送", cycle_timers):
            reply_text = await self._send_response(
                reply_set=response_set,
                message_data=action_message,
                selected_expressions=selected_expressions,
            )

        # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
        platform = action_message.get("chat_info_platform")
        if platform is None:
            platform = getattr(self.chat_stream, "platform", "unknown")

        person = Person(platform = platform ,user_id = action_message.get("user_id", ""))

        person = Person(platform=platform, user_id=action_message.get("user_id", ""))
        person_name = person.person_name
        action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@@ -342,12 +337,10 @@ class HeartFChatting:

        return loop_info, reply_text, cycle_timers

    async def _observe(self,interest_value:float = 0.0) -> bool:

    async def _observe(self, interest_value: float = 0.0) -> bool:
        action_type = "no_action"
        reply_text = ""  # 初始化reply_text变量,避免UnboundLocalError

        # 使用sigmoid函数将interest_value转换为概率
        # 当interest_value为0时,概率接近0(使用Focus模式)
        # 当interest_value很高时,概率接近1(使用Normal模式)
@@ -361,12 +354,14 @@ class HeartFChatting:
            x0 = 1.0  # 控制曲线中心点
            return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))

        normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / self.talk_frequency_control.get_current_talk_frequency()
        normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency()

        # 根据概率决定使用哪种模式
        if random.random() < normal_mode_probability:
            mode = ChatMode.NORMAL
            logger.info(f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability*100:.0f}%概率下选择回复")
            logger.info(
                f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability * 100:.0f}%概率下选择回复"
            )
        else:
            mode = ChatMode.FOCUS
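The mode selection above maps interest to a probability through a logistic curve, then rescales by talk frequency (the new code multiplies by `2 * talk_frequency`; the old code divided instead, inverting the effect). A standalone sketch of the same computation: `x0 = 1.0` matches the diff, but the steepness `k` is not visible in this hunk, so the value here is an assumption.

```python
import math
import random

def calculate_normal_mode_probability(interest_val: float, k: float = 2.0, x0: float = 1.0) -> float:
    """Sigmoid mapping interest to (0, 1); k is assumed, x0 is from the diff."""
    return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))

def choose_mode(interest_val: float, talk_frequency: float = 1.0) -> str:
    # new formula from the diff: scale by 2 * talk_frequency
    p = calculate_normal_mode_probability(interest_val) * 2 * talk_frequency
    return "NORMAL" if random.random() < p else "FOCUS"
```

At `interest_val == x0` the raw sigmoid is exactly 0.5 regardless of `k`, so with `talk_frequency = 1.0` the scaled probability there is 1.0 and replying is near-certain.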
@@ -387,10 +382,9 @@ class HeartFChatting:
                await hippocampus_manager.build_memory_for_chat(self.stream_id)
            except Exception as e:
                logger.error(f"{self.log_prefix} 记忆构建失败: {e}")

        if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
            #如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
            # 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
            actions = [
                {
                    "action_type": "no_action",
@@ -420,23 +414,21 @@ class HeartFChatting:
        ):
            return False
        with Timer("规划器", cycle_timers):
            actions, _= await self.action_planner.plan(
            actions, _ = await self.action_planner.plan(
                mode=mode,
                loop_start_time=self.last_read_time,
                available_actions=available_actions,
            )

        # 3. 并行执行所有动作
        async def execute_action(action_info,actions):
        async def execute_action(action_info, actions):
            """执行单个动作的通用函数"""
            try:
                if action_info["action_type"] == "no_action":
                    # 直接处理no_action逻辑,不再通过动作系统
                    reason = action_info.get("reasoning", "选择不回复")
                    logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")

                    # 存储no_action信息到数据库
                    await database_api.store_action_info(
                        chat_stream=self.chat_stream,
@@ -447,13 +439,8 @@ class HeartFChatting:
                        action_data={"reason": reason},
                        action_name="no_action",
                    )

                    return {
                        "action_type": "no_action",
                        "success": True,
                        "reply_text": "",
                        "command": ""
                    }

                    return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
                elif action_info["action_type"] != "reply":
                    # 执行普通动作
                    with Timer("动作执行", cycle_timers):
@@ -463,20 +450,19 @@ class HeartFChatting:
                            action_info["action_data"],
                            cycle_timers,
                            thinking_id,
                            action_info["action_message"]
                            action_info["action_message"],
                        )
                    return {
                        "action_type": action_info["action_type"],
                        "success": success,
                        "reply_text": reply_text,
                        "command": command
                        "command": command,
                    }
                else:

                    try:
                        success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
                            chat_stream=self.chat_stream,
                            reply_message = action_info["action_message"],
                            reply_message=action_info["action_message"],
                            available_actions=available_actions,
                            choosen_actions=actions,
                            reply_reason=action_info.get("reasoning", ""),
@@ -485,29 +471,21 @@ class HeartFChatting:
                            from_plugin=False,
                            return_expressions=True,
                        )

                        if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
                            _,selected_expressions = prompt_selected_expressions
                            _, selected_expressions = prompt_selected_expressions
                        else:
                            selected_expressions = []

                        if not success or not response_set:
                            logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
                            return {
                                "action_type": "reply",
                                "success": False,
                                "reply_text": "",
                                "loop_info": None
                            }

                            logger.info(
                                f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
                            )
                            return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}

                    except asyncio.CancelledError:
                        logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
                        return {
                            "action_type": "reply",
                            "success": False,
                            "reply_text": "",
                            "loop_info": None
                        }
                        return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}

                    loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
                        response_set=response_set,
@@ -521,7 +499,7 @@ class HeartFChatting:
                        "action_type": "reply",
                        "success": True,
                        "reply_text": reply_text,
                        "loop_info": loop_info
                        "loop_info": loop_info,
                    }
            except Exception as e:
                logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
@@ -531,26 +509,26 @@ class HeartFChatting:
                    "success": False,
                    "reply_text": "",
                    "loop_info": None,
                    "error": str(e)
                    "error": str(e),
                }

        action_tasks = [asyncio.create_task(execute_action(action,actions)) for action in actions]

        action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]

        # 并行执行所有任务
        results = await asyncio.gather(*action_tasks, return_exceptions=True)

        # 处理执行结果
        reply_loop_info = None
        reply_text_from_reply = ""
        action_success = False
        action_reply_text = ""
        action_command = ""

        for i, result in enumerate(results):
            if isinstance(result, BaseException):
                logger.error(f"{self.log_prefix} 动作执行异常: {result}")
                continue

            _cur_action = actions[i]
            if result["action_type"] != "reply":
                action_success = result["success"]
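The parallel-action pattern above (one task per action, `asyncio.gather` with `return_exceptions=True`, then skipping `BaseException` results) can be shown in isolation; the action payloads here are illustrative, not MaiBot's real ones:

```python
import asyncio

async def execute_action(action_info: dict) -> dict:
    # stand-in for the real per-action handler
    if action_info["action_type"] == "boom":
        raise RuntimeError("action failed")
    return {"action_type": action_info["action_type"], "success": True}

async def run_actions(actions: list[dict]) -> list[dict]:
    tasks = [asyncio.create_task(execute_action(a)) for a in actions]
    # return_exceptions=True turns raised exceptions into result values,
    # so one failed action cannot cancel its siblings
    results = await asyncio.gather(*tasks, return_exceptions=True)
    ok = []
    for result in results:
        if isinstance(result, BaseException):
            continue  # the real code logs the failure and moves on
        ok.append(result)
    return ok

results = asyncio.run(run_actions([{"action_type": "reply"}, {"action_type": "boom"}]))
```

Because `gather` preserves input order, the surviving results can still be matched back to `actions[i]` by index, which is what the loop in the diff relies on.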
@@ -590,7 +568,6 @@ class HeartFChatting:
            },
        }
        reply_text = action_reply_text

        if s4u_config.enable_s4u:
            await stop_typing()
@@ -602,7 +579,7 @@ class HeartFChatting:
        # await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))

        action_type = actions[0]["action_type"] if actions else "no_action"

        # 管理no_action计数器:当执行了非no_action动作时,重置计数器
        if action_type != "no_action":
            # no_action逻辑已集成到heartFC_chat.py中,直接重置计数器
@@ -610,7 +587,7 @@ class HeartFChatting:
            self.no_action_consecutive = 0
            logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_action计数器")
        return True

        if action_type == "no_action":
            self.no_action_consecutive += 1
            self._determine_form_type()
@@ -692,11 +669,12 @@ class HeartFChatting:
            traceback.print_exc()
            return False, "", ""

    async def _send_response(self,
        reply_set,
        message_data,
        selected_expressions:List[int] = None,
    ) -> str:
    async def _send_response(
        self,
        reply_set,
        message_data,
        selected_expressions: List[int] = None,
    ) -> str:
        new_message_count = message_api.count_new_messages(
            chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
        )
@@ -714,7 +692,7 @@ class HeartFChatting:
                    await send_api.text_to_stream(
                        text=data,
                        stream_id=self.chat_stream.stream_id,
                        reply_message = message_data,
                        reply_message=message_data,
                        set_reply=need_reply,
                        typing=False,
                        selected_expressions=selected_expressions,
@@ -724,7 +702,7 @@ class HeartFChatting:
                    await send_api.text_to_stream(
                        text=data,
                        stream_id=self.chat_stream.stream_id,
                        reply_message = message_data,
                        reply_message=message_data,
                        set_reply=False,
                        typing=True,
                        selected_expressions=selected_expressions,
@@ -709,36 +709,36 @@ class EmojiManager:
            return emoji
        return None  # 如果循环结束还没找到,则返回 None

    async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
        """根据哈希值获取已注册表情包的描述
    async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
        """根据哈希值获取已注册表情包的情感标签列表

        Args:
            emoji_hash: 表情包的哈希值

        Returns:
            Optional[str]: 表情包描述,如果未找到则返回None
            Optional[List[str]]: 情感标签列表,如果未找到则返回None
        """
        try:
            # 先从内存中查找
            emoji = await self.get_emoji_from_manager(emoji_hash)
            if emoji and emoji.emotion:
                logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...")
                return ",".join(emoji.emotion)
                logger.info(f"[缓存命中] 从内存获取表情包情感标签: {emoji.emotion}...")
                return emoji.emotion

            # 如果内存中没有,从数据库查找
            self._ensure_db()
            try:
                emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
                if emoji_record and emoji_record.emotion:
                    logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
                    return emoji_record.emotion
                    logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
                    return emoji_record.emotion.split(',')
            except Exception as e:
                logger.error(f"从数据库查询表情包描述时出错: {e}")
                logger.error(f"从数据库查询表情包情感标签时出错: {e}")

            return None

        except Exception as e:
            logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
            logger.error(f"获取表情包情感标签失败 (Hash: {emoji_hash}): {str(e)}")
            return None

    async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
@@ -8,6 +8,7 @@ from typing import List, Dict, Optional, Any, Tuple

from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages

@@ -346,21 +347,17 @@ class ExpressionLearner:
        current_time = time.time()

        # 获取上次学习时间
        random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
        random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_learning_time,
            timestamp_end=current_time,
            limit=num,
        )
        # TODO: 修复!
        from src.common.data_models import temporarily_transform_class_to_dict
        random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None

        # print(random_msg)
        if not random_msg or random_msg == []:
            return None
        # 转化成str
        chat_id: str = random_msg[0]["chat_id"]
        chat_id: str = random_msg[0].chat_id
        # random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
        random_msg_str: str = await build_anonymous_messages(random_msg)
        # print(f"random_msg_str:{random_msg_str}")
@@ -117,30 +117,36 @@ class EmbeddingStore:
        self.idx2hash = None

    def _get_embedding(self, s: str) -> List[float]:
        """获取字符串的嵌入向量,处理异步调用"""
        """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
        # 创建新的事件循环并在完成后立即关闭
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # 尝试获取当前事件循环
            asyncio.get_running_loop()
            # 如果在事件循环中,使用线程池执行
            import concurrent.futures

            def run_in_thread():
                return asyncio.run(get_embedding(s))

            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(run_in_thread)
                result = future.result()
                if result is None:
                    logger.error(f"获取嵌入失败: {s}")
                    return []
                return result
        except RuntimeError:
            # 没有运行的事件循环,直接运行
            result = asyncio.run(get_embedding(s))
            if result is None:
            # 创建新的LLMRequest实例
            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config

            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")

            # 使用新的事件循环运行异步方法
            embedding, _ = loop.run_until_complete(llm.get_embedding(s))

            if embedding and len(embedding) > 0:
                return embedding
            else:
                logger.error(f"获取嵌入失败: {s}")
                return []
            return result

        except Exception as e:
            logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
            return []
        finally:
            # 确保事件循环被正确关闭
            try:
                loop.close()
            except Exception:
                pass

    def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
        """使用多线程批量获取嵌入向量

@@ -181,8 +187,14 @@ class EmbeddingStore:

            for i, s in enumerate(chunk_strs):
                try:
                    # 直接使用异步函数
                    embedding = asyncio.run(llm.get_embedding(s))
                    # 在线程中创建独立的事件循环
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        embedding = loop.run_until_complete(llm.get_embedding(s))
                    finally:
                        loop.close()

                    if embedding and len(embedding) > 0:
                        chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] 是实际的向量
                    else:
@@ -9,7 +9,6 @@ import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict
from collections import Counter
from itertools import combinations
import traceback

from rich.traceback import install

@@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import (
    build_readable_messages,
    get_raw_msg_by_timestamp_with_chat_inclusive,
)  # 导入 build_readable_messages


# 添加cosine_similarity函数
def cosine_similarity(v1, v2):
    """计算余弦相似度"""
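The body of the `cosine_similarity` helper added above falls outside this hunk, so only its signature and docstring are visible. A plausible pure-Python implementation consistent with that docstring (the actual body is not shown in the diff):

```python
import math

def cosine_similarity(v1, v2):
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    if norm1 == 0.0 or norm2 == 0.0:
        return 0.0  # avoid division by zero for degenerate vectors
    return dot / (norm1 * norm2)
```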
@ -51,18 +52,9 @@ def calculate_information_content(text):
|
|||
return entropy
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logger = get_logger("memory")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
|
|
@ -96,7 +88,7 @@ class MemoryGraph:
|
|||
if "memory_items" in self.G.nodes[concept]:
|
||||
# 获取现有的记忆项(已经是str格式)
|
||||
existing_memory = self.G.nodes[concept]["memory_items"]
|
||||
|
||||
|
||||
# 如果现有记忆不为空,则使用LLM整合新旧记忆
|
||||
if existing_memory and hippocampus_instance and hippocampus_instance.model_small:
|
||||
try:
|
||||
|
|
@ -170,16 +162,16 @@ class MemoryGraph:
|
|||
second_layer_items.append(memory_items)
|
||||
|
||||
return first_layer_items, second_layer_items
|
||||
|
||||
|
||||
async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str:
|
||||
"""
|
||||
使用LLM整合新旧记忆内容
|
||||
|
||||
|
||||
Args:
|
||||
existing_memory: 现有的记忆内容(字符串格式,可能包含多条记忆)
|
||||
new_memory: 新的记忆内容
|
||||
llm_model: LLM模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
str: 整合后的记忆内容
|
||||
"""
|
||||
|
|
@@ -203,8 +195,10 @@ class MemoryGraph:
整合后的记忆:"""

        # 调用LLM进行整合
        content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt)
        content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
            integration_prompt
        )

        if content and content.strip():
            integrated_content = content.strip()
            logger.debug(f"LLM记忆整合成功,模型: {model_name}")
@@ -212,7 +206,7 @@ class MemoryGraph:
            else:
                logger.warning("LLM返回的整合结果为空,使用默认连接方式")
                return f"{existing_memory} | {new_memory}"

        except Exception as e:
            logger.error(f"LLM记忆整合过程中出错: {e}")
            return f"{existing_memory} | {new_memory}"
@@ -238,7 +232,11 @@ class MemoryGraph:
        if memory_items:
            # 删除整个节点
            self.G.remove_node(topic)
            return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}"
            return (
                f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
                if len(memory_items) > 50
                else f"删除了节点 {topic} 的完整记忆: {memory_items}"
            )
        else:
            # 如果没有记忆项,删除该节点
            self.G.remove_node(topic)
@@ -263,38 +261,40 @@ class Hippocampus:
        self.parahippocampal_gyrus = ParahippocampalGyrus(self)
        # 从数据库加载记忆图
        self.entorhinal_cortex.sync_memory_from_db()
        self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
        self.model_small = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
        )

    def get_all_node_names(self) -> list:
        """获取记忆图中所有节点的名字列表"""
        return list(self.memory_graph.G.nodes())

    def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float:
        """
        计算考虑节点权重的激活值

        Args:
            current_activation: 当前激活值
            edge_strength: 边的强度
            target_node: 目标节点名称

        Returns:
            float: 计算后的激活值
        """
        # 基础激活值计算
        base_activation = current_activation - (1 / edge_strength)

        if base_activation <= 0:
            return 0.0

        # 获取目标节点的权重
        if target_node in self.memory_graph.G:
            node_data = self.memory_graph.G.nodes[target_node]
            node_weight = node_data.get("weight", 1.0)

            # 权重加成:每次整合增加10%激活值,最大加成200%
            weight_multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)

            return base_activation * weight_multiplier
        else:
            return base_activation
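The weighted-activation arithmetic in `calculate_weighted_activation` can be checked with a small standalone replica (the numbers below are hypothetical, chosen only to exercise the formula):

```python
def weighted_activation(current_activation, edge_strength, node_weight=1.0):
    # 基础激活值:激活值沿边衰减 1/edge_strength
    base = current_activation - (1 / edge_strength)
    if base <= 0:
        return 0.0
    # 权重加成:每次整合 +10%,封顶 +200%
    multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)
    return base * multiplier

print(weighted_activation(1.0, 2, node_weight=5.0))  # (1 - 0.5) * 1.4 = 0.7
```

With weight 5.0 the bonus is (5 − 1) × 0.1 = 40%; the bonus saturates at +200% once the node weight reaches 21.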
@@ -332,9 +332,7 @@ class Hippocampus:
            f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
            f"如果确定找不出主题或者没有明显主题,返回<none>。"
        )

        return prompt

    @staticmethod
@@ -403,7 +401,7 @@ class Hippocampus:
        memories.sort(key=lambda x: x[2], reverse=True)
        return memories

    async def get_keywords_from_text(self, text: str) -> list:
    async def get_keywords_from_text(self, text: str) -> Tuple[List[str], List]:
        """从文本中提取关键词。

        Args:
@@ -418,16 +416,13 @@ class Hippocampus:
        # 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
        text_length = len(text)
        topic_num: int | list[int] = 0

        words = jieba.cut(text)
        keywords_lite = [word for word in words if len(word) > 1]
        keywords_lite = list(set(keywords_lite))
        if keywords_lite:
            logger.debug(f"提取关键词极简版: {keywords_lite}")

        if text_length <= 12:
            topic_num = [1, 3]  # 6-10字符: 1个关键词 (27.18%的文本)
        elif text_length <= 20:
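The `keywords_lite` path above tokenizes with `jieba.cut`, keeps only multi-character words, and deduplicates. The same shape can be sketched with a plain whitespace tokenizer standing in for jieba (an assumption for illustration; the real code segments Chinese text):

```python
def extract_keywords_lite(text, tokenizer=None):
    """Mimic the keywords_lite extraction: tokenize, drop single-character words, deduplicate."""
    words = tokenizer(text) if tokenizer else text.split()  # real code: jieba.cut(text)
    return sorted({w for w in words if len(w) > 1})

print(extract_keywords_lite("a big cat saw a big dog"))  # ['big', 'cat', 'dog', 'saw']
```

Because the result is a set, repeated words contribute one keyword, and single-character tokens (mostly function words in Chinese) are discarded.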
@@ -455,7 +450,7 @@ class Hippocampus:
        if keywords:
            logger.debug(f"提取关键词: {keywords}")

        return keywords,keywords_lite
        return keywords, keywords_lite

    async def get_memory_from_topic(
        self,
@@ -570,20 +565,17 @@ class Hippocampus:
        for node, activation in remember_map.items():
            logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
            node_data = self.memory_graph.G.nodes[node]
            memory_items = node_data.get("memory_items", "")
            # 直接使用完整的记忆内容
            if memory_items:
            if memory_items := node_data.get("memory_items", ""):
                logger.debug("节点包含完整记忆")
                # 计算记忆与关键词的相似度
                memory_words = set(jieba.cut(memory_items))
                text_words = set(keywords)
                all_words = memory_words | text_words
                if all_words:
                if all_words := memory_words | text_words:
                    # 计算相似度(虽然这里没有使用,但保持逻辑一致性)
                    v1 = [1 if word in memory_words else 0 for word in all_words]
                    v2 = [1 if word in text_words else 0 for word in all_words]
                    _ = cosine_similarity(v1, v2)  # 计算但不使用,用_表示

                    # 添加完整记忆到结果中
                    all_memories.append((node, memory_items, activation))
            else:
@@ -613,7 +605,9 @@ class Hippocampus:

        return result

    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]:
    async def get_activate_from_text(
        self, text: str, max_depth: int = 3, fast_retrieval: bool = False
    ) -> tuple[float, list[str], list[str]]:
        """从文本中提取关键词并获取相关记忆。

        Args:
@@ -627,13 +621,13 @@ class Hippocampus:
            float: 激活节点数与总节点数的比值
            list[str]: 有效的关键词
        """
        keywords,keywords_lite = await self.get_keywords_from_text(text)
        keywords, keywords_lite = await self.get_keywords_from_text(text)

        # 过滤掉不存在于记忆图中的关键词
        valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
        if not valid_keywords:
            # logger.info("没有找到有效的关键词节点")
            return 0, keywords,keywords_lite
            return 0, keywords, keywords_lite

        logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@@ -700,7 +694,7 @@ class Hippocampus:
        activation_ratio = activation_ratio * 50
        logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")

        return activation_ratio, keywords,keywords_lite
        return activation_ratio, keywords, keywords_lite


# 负责海马体与其他部分的交互
@@ -730,7 +724,7 @@ class EntorhinalCortex:
                continue

            memory_items = data.get("memory_items", "")

            # 直接检查字符串是否为空,不需要分割成列表
            if not memory_items or memory_items.strip() == "":
                self.memory_graph.G.remove_node(concept)
@@ -865,7 +859,9 @@ class EntorhinalCortex:

        end_time = time.time()
        logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
        logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边")
        logger.info(
            f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
        )

    async def resync_memory_to_db(self):
        """清空数据库并重新同步所有记忆数据"""
@@ -888,7 +884,7 @@ class EntorhinalCortex:
        nodes_data = []
        for concept, data in memory_nodes:
            memory_items = data.get("memory_items", "")

            # 直接检查字符串是否为空,不需要分割成列表
            if not memory_items or memory_items.strip() == "":
                self.memory_graph.G.remove_node(concept)
|
@ -960,7 +956,7 @@ class EntorhinalCortex:
|
|||
|
||||
# 清空当前图
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
|
||||
# 统计加载情况
|
||||
total_nodes = 0
|
||||
loaded_nodes = 0
|
||||
|
|
@@ -969,7 +965,7 @@ class EntorhinalCortex:
        # 从数据库加载所有节点
        nodes = list(GraphNodes.select())
        total_nodes = len(nodes)

        for node in nodes:
            concept = node.concept
            try:
@@ -978,7 +974,7 @@ class EntorhinalCortex:
                    logger.warning(f"节点 {concept} 的memory_items为空,跳过")
                    skipped_nodes += 1
                    continue

                # 直接使用memory_items
                memory_items = node.memory_items.strip()
@@ -999,11 +995,15 @@ class EntorhinalCortex:
                last_modified = node.last_modified or current_time

                # 获取权重属性
                weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
                weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0

                # 添加节点到图中
                self.memory_graph.G.add_node(
                    concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified
                    concept,
                    memory_items=memory_items,
                    weight=weight,
                    created_time=created_time,
                    last_modified=last_modified,
                )
                loaded_nodes += 1
            except Exception as e:
@@ -1044,9 +1044,11 @@ class EntorhinalCortex:

        if need_update:
            logger.info("[数据库] 已为缺失的时间字段进行补充")

        # 输出加载统计信息
        logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个")
        logger.info(
            f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个"
        )


# 负责整合,遗忘,合并记忆
@@ -1054,10 +1056,12 @@ class ParahippocampalGyrus:
    def __init__(self, hippocampus: Hippocampus):
        self.hippocampus = hippocampus
        self.memory_graph = hippocampus.memory_graph

        self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")

    async def memory_compress(self, messages: list, compress_rate=0.1):
        self.memory_modify_model = LLMRequest(
            model_set=model_config.model_task_config.utils, request_type="memory.modify"
        )

    async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
        """压缩和总结消息内容,生成记忆主题和摘要。

        Args:
@@ -1083,7 +1087,6 @@ class ParahippocampalGyrus:
        # build_readable_messages 只返回一个字符串,不需要解包
        input_text = build_readable_messages(
            messages,
            merge_messages=True,  # 合并连续消息
            timestamp_mode="normal_no_YMD",  # 使用 'YYYY-MM-DD HH:MM:SS' 格式
            replace_bot_name=False,  # 保留原始用户名
        )
@@ -1163,7 +1166,7 @@ class ParahippocampalGyrus:
            similar_topics.sort(key=lambda x: x[1], reverse=True)
            similar_topics = similar_topics[:3]
            similar_topics_dict[topic] = similar_topics

        if global_config.debug.show_prompt:
            logger.info(f"prompt: {topic_what_prompt}")
            logger.info(f"压缩后的记忆: {compressed_memory}")
@@ -1259,14 +1262,14 @@ class ParahippocampalGyrus:
            # --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
            last_modified = node_data.get("last_modified", current_time)
            node_weight = node_data.get("weight", 1.0)

            # 条件1:检查是否长时间未修改 (使用配置的遗忘时间)
            time_threshold = 3600 * global_config.memory.memory_forget_time

            # 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
            # 权重为1时使用默认阈值,权重越高阈值越大(越难遗忘)
            adjusted_threshold = time_threshold * node_weight

            if current_time - last_modified > adjusted_threshold and memory_items:
                # 既然每个节点现在是完整记忆,直接删除整个节点
                try:
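The forgetting check above scales the base threshold by the node weight. A standalone replica of that arithmetic (the 24-hour `memory_forget_time` below is a hypothetical config value):

```python
def adjusted_forget_threshold(memory_forget_time: float, node_weight: float) -> float:
    """Forgetting threshold in seconds: base 3600 * forget_time hours, scaled by node weight."""
    time_threshold = 3600 * memory_forget_time
    return time_threshold * node_weight

print(adjusted_forget_threshold(24, 1.0))  # 86400.0 — default weight, forgotten after 24 h idle
print(adjusted_forget_threshold(24, 2.0))  # 172800.0 — weight 2, needs 48 h idle
```

So a node whose memories have been integrated (higher weight) survives proportionally longer without activity.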
@@ -1315,8 +1318,6 @@ class ParahippocampalGyrus:
        logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")


class HippocampusManager:
    def __init__(self):
        self._hippocampus: Hippocampus = None  # type: ignore
@@ -1361,29 +1362,32 @@ class HippocampusManager:
        """为指定chat_id构建记忆(在heartFC_chat.py中调用)"""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")

        try:
            # 检查是否需要构建记忆
            logger.info(f"为 {chat_id} 构建记忆")
            if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
                logger.info(f"为 {chat_id} 构建记忆,需要构建记忆")
                messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 50)

                build_probability = 0.3 * global_config.memory.memory_build_frequency

                if messages and random.random() < build_probability:
                    logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}")

                    # 调用记忆压缩和构建
                    compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
                    (
                        compressed_memory,
                        similar_topics_dict,
                    ) = await self._hippocampus.parahippocampal_gyrus.memory_compress(
                        messages, global_config.memory.memory_compress_rate
                    )

                    # 添加记忆节点
                    current_time = time.time()
                    for topic, memory in compressed_memory:
                        await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)

                        # 连接相似主题
                        if topic in similar_topics_dict:
                            similar_topics = similar_topics_dict[topic]
@@ -1391,23 +1395,23 @@ class HippocampusManager:
                                if topic != similar_topic:
                                    strength = int(similarity * 10)
                                    self._hippocampus.memory_graph.G.add_edge(
                                        topic, similar_topic,
                                        topic,
                                        similar_topic,
                                        strength=strength,
                                        created_time=current_time,
                                        last_modified=current_time
                                        last_modified=current_time,
                                    )

                    # 同步到数据库
                    await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
                    logger.info(f"为 {chat_id} 构建记忆完成")
                    return True

        except Exception as e:
            logger.error(f"为 {chat_id} 构建记忆失败: {e}")
            return False

        return False

        return False

    async def get_memory_from_topic(
        self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
@@ -1424,16 +1428,20 @@ class HippocampusManager:
            response = []
        return response

    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
    async def get_activate_from_text(
        self, text: str, max_depth: int = 3, fast_retrieval: bool = False
    ) -> tuple[float, list[str], list[str]]:
        """从文本中获取激活值的公共接口"""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        try:
            response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
            response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
                text, max_depth, fast_retrieval
            )
        except Exception as e:
            logger.error(f"文本产生激活值失败: {e}")
            logger.error(traceback.format_exc())
            return 0.0, [],[]
            return 0.0, [], []

    def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
        """从关键词获取相关记忆的公共接口"""
@@ -1455,81 +1463,78 @@ hippocampus_manager = HippocampusManager()
# 在Hippocampus类中添加新的记忆构建管理器
class MemoryBuilder:
    """记忆构建器

    为每个chat_id维护消息缓存和触发机制,类似ExpressionLearner
    """

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        self.last_update_time: float = time.time()
        self.last_processed_time: float = 0.0

    def should_trigger_memory_build(self) -> bool:
        """检查是否应该触发记忆构建"""
        current_time = time.time()

        # 检查时间间隔
        time_diff = current_time - self.last_update_time
        if time_diff < 600 /global_config.memory.memory_build_frequency:
        if time_diff < 600 / global_config.memory.memory_build_frequency:
            return False

        # 检查消息数量
        recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_update_time,
            timestamp_end=current_time,
        )

        logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")

        if not recent_messages or len(recent_messages) < 30/global_config.memory.memory_build_frequency :
        if not recent_messages or len(recent_messages) < 30 / global_config.memory.memory_build_frequency:
            return False

        return True
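`should_trigger_memory_build` divides both gates by `memory_build_frequency`, so a higher frequency lowers both the minimum idle interval and the minimum message count. The relationship in isolation:

```python
def trigger_thresholds(memory_build_frequency: float):
    """Higher build frequency → lower interval and message-count gates → more frequent memory builds."""
    min_interval_seconds = 600 / memory_build_frequency
    min_message_count = 30 / memory_build_frequency
    return min_interval_seconds, min_message_count

print(trigger_thresholds(1.0))  # (600.0, 30.0)
print(trigger_thresholds(2.0))  # (300.0, 15.0)
```

Doubling the configured frequency halves both thresholds, roughly doubling how often a busy chat qualifies for a build.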
    def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
    def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]:
        """获取用于记忆构建的消息"""
        current_time = time.time()

        messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_update_time,
            timestamp_end=current_time,
            limit=threshold,
        )
        tmp_msg = [msg.__dict__ for msg in messages] if messages else []
        if messages:
            # 更新最后处理时间
            self.last_processed_time = current_time
            self.last_update_time = current_time

        return tmp_msg or []
        return messages or []

class MemorySegmentManager:
    """记忆段管理器

    管理所有chat_id的MemoryBuilder实例,自动检查和触发记忆构建
    """

    def __init__(self):
        self.builders: Dict[str, MemoryBuilder] = {}

    def get_or_create_builder(self, chat_id: str) -> MemoryBuilder:
        """获取或创建指定chat_id的MemoryBuilder"""
        if chat_id not in self.builders:
            self.builders[chat_id] = MemoryBuilder(chat_id)
        return self.builders[chat_id]

    def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
        """检查指定chat_id是否需要构建记忆,如果需要则返回True"""
        builder = self.get_or_create_builder(chat_id)
        return builder.should_trigger_memory_build()

    def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
    def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]:
        """获取指定chat_id用于记忆构建的消息"""
        if chat_id not in self.builders:
            return []
@@ -1538,4 +1543,3 @@ class MemorySegmentManager:

# 创建全局实例
memory_segment_manager = MemorySegmentManager()
@@ -1,17 +1,17 @@
import json
import random

from json_repair import repair_json
from typing import List, Tuple

from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
import random
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest


logger = get_logger("memory_activator")
@@ -75,19 +75,20 @@ class MemoryActivator:
            request_type="memory.selection",
        )

    async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
    async def activate_memory_with_chat_history(
        self, target_message, chat_history: List[DatabaseMessages]
    ) -> List[Tuple[str, str]]:
        """
        激活记忆
        """
        # 如果记忆系统被禁用,直接返回空列表
        if not global_config.memory.enable_memory:
            return []

        keywords_list = set()

        for msg in chat_history_prompt:
            keywords = parse_keywords_string(msg.get("key_words", ""))
        for msg in chat_history:
            keywords = parse_keywords_string(msg.key_words)
            if keywords:
                if len(keywords_list) < 30:
                    # 最多容纳30个关键词
@@ -95,24 +96,22 @@ class MemoryActivator:
                    logger.debug(f"提取关键词: {keywords_list}")
                else:
                    break

        if not keywords_list:
            logger.debug("没有提取到关键词,返回空记忆列表")
            return []

        # 从海马体获取相关记忆
        related_memory = await hippocampus_manager.get_memory_from_topic(
            valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
        )

        # logger.info(f"当前记忆关键词: {keywords_list}")
        logger.debug(f"获取到的记忆: {related_memory}")

        if not related_memory:
            logger.debug("海马体没有返回相关记忆")
            return []

        used_ids = set()
        candidate_memories = []
@@ -120,12 +119,7 @@ class MemoryActivator:
        # 为每个记忆分配随机ID并过滤相关记忆
        for memory in related_memory:
            keyword, content = memory
            found = False
            for kw in keywords_list:
                if kw in content:
                    found = True
                    break

            found = any(kw in content for kw in keywords_list)
            if found:
                # 随机分配一个不重复的2位数id
                while True:
@@ -138,95 +132,83 @@ class MemoryActivator:
        if not candidate_memories:
            logger.info("没有找到相关的候选记忆")
            return []

        # 如果只有少量记忆,直接返回
        if len(candidate_memories) <= 2:
            logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
            # 转换为 (keyword, content) 格式
            return [(mem["keyword"], mem["content"]) for mem in candidate_memories]

        # 使用 LLM 选择合适的记忆
        selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)

        return selected_memories

    async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
        return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)

    async def _select_memories_with_llm(
        self, target_message, chat_history: List[DatabaseMessages], candidate_memories
    ) -> List[Tuple[str, str]]:
        """
        使用 LLM 选择合适的记忆

        Args:
            target_message: 目标消息
            chat_history_prompt: 聊天历史
            candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content

        Returns:
            List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
        """
        try:
            # 构建聊天历史字符串
            obs_info_text = build_readable_messages(
                chat_history_prompt,
                chat_history,
                replace_bot_name=True,
                merge_messages=False,
                timestamp_mode="relative",
                read_mark=0.0,
                show_actions=True,
            )

            # 构建记忆信息字符串
            memory_lines = []
            for memory in candidate_memories:
                memory_id = memory["memory_id"]
                keyword = memory["keyword"]
                content = memory["content"]

                # 将 content 列表转换为字符串
                if isinstance(content, list):
                    content_str = " | ".join(str(item) for item in content)
                else:
                    content_str = str(content)

                memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")

            memory_info = "\n".join(memory_lines)

            # 获取并格式化 prompt
            prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
            formatted_prompt = prompt_template.format(
                obs_info_text=obs_info_text,
                target_message=target_message,
                memory_info=memory_info
                obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
            )

            # 调用 LLM
            response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
                formatted_prompt,
                temperature=0.3,
                max_tokens=150
                formatted_prompt, temperature=0.3, max_tokens=150
            )

            if global_config.debug.show_prompt:
                logger.info(f"记忆选择 prompt: {formatted_prompt}")
                logger.info(f"LLM 记忆选择响应: {response}")
            else:
                logger.debug(f"记忆选择 prompt: {formatted_prompt}")
                logger.debug(f"LLM 记忆选择响应: {response}")

            # 解析响应获取选择的记忆编号
            try:
                fixed_json = repair_json(response)

                # 解析为 Python 对象
                result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json

                # 提取 memory_ids 字段
                memory_ids_str = result.get("memory_ids", "")

                # 解析逗号分隔的编号
                if memory_ids_str:

                # 提取 memory_ids 字段并解析逗号分隔的编号
                if memory_ids_str := result.get("memory_ids", ""):
                    memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
                    # 过滤掉空字符串和无效编号
                    valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
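The `memory_ids` parsing above can be exercised in isolation. This sketch uses stdlib `json` only; the real code first passes the raw LLM response through `json_repair.repair_json` to tolerate malformed output, which is omitted here:

```python
import json

def parse_memory_ids(response: str):
    """Extract and split the comma-separated memory_ids field from a (well-formed) LLM JSON response."""
    result = json.loads(response)
    memory_ids_str = result.get("memory_ids", "")
    memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
    # 过滤掉空字符串和过长的无效编号(id 是随机分配的 2 位数,最多 3 个字符)
    return [mid for mid in memory_ids if len(mid) <= 3]

print(parse_memory_ids('{"memory_ids": "12, 47, 9"}'))  # ['12', '47', '9']
```

An empty or missing `memory_ids` field yields an empty list, which the caller treats as "no memories selected".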
@@ -236,26 +218,24 @@ class MemoryActivator:
            except Exception as e:
                logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
                selected_memory_ids = []

            # 根据编号筛选记忆
            selected_memories = []
            memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}

            for memory_id in selected_memory_ids:
                if memory_id in memory_id_to_memory:
                    selected_memories.append(memory_id_to_memory[memory_id])

            selected_memories = [
                memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
            ]
            logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
            logger.info(f"最终选择的记忆数量: {len(selected_memories)}")

            # 转换为 (keyword, content) 格式
            return [(mem["keyword"], mem["content"]) for mem in selected_memories]

        except Exception as e:
            logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
            # 出错时返回前3个候选记忆作为备选,转换为 (keyword, content) 格式
            return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]


init_prompt()
@@ -70,13 +70,10 @@ class ActionModifier:
            timestamp=time.time(),
            limit=min(int(global_config.chat.max_context_size * 0.33), 10),
        )
        # TODO: 修复!
        from src.common.data_models import temporarily_transform_class_to_dict
        temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]

        chat_content = build_readable_messages(
            temp_msg_list_before_now_half,
            message_list_before_now_half,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
            show_actions=True,
@@ -33,6 +33,9 @@ def init_prompt():
{time_block}
{identity_block}
你现在需要根据聊天内容,选择的合适的action来参与聊天。
请你根据以下行事风格来决定action:
{plan_style}

{chat_context_description},以下是具体的聊天内容
{chat_content_block}
@@ -280,11 +283,8 @@ class ActionPlanner:
            timestamp=time.time(),
            limit=int(global_config.chat.max_context_size * 0.6),
        )
        # TODO: 修复!
        from src.common.data_models import temporarily_transform_class_to_dict
        temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
        chat_content_block, message_id_list = build_readable_messages_with_id(
            messages=temp_msg_list_before_now,
            messages=message_list_before_now,
            timestamp_mode="normal_no_YMD",
            read_mark=self.last_obs_time_mark,
            truncate=True,
@@ -388,6 +388,7 @@ class ActionPlanner:
                action_options_text=action_options_block,
                moderation_prompt=moderation_prompt_block,
                identity_block=identity_block,
                plan_style = global_config.personality.plan_style
            )
            return prompt, message_id_list
        except Exception as e:
@@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
@@ -156,7 +157,7 @@ class DefaultReplyer:
        extra_info: str = "",
        reply_reason: str = "",
        available_actions: Optional[Dict[str, ActionInfo]] = None,
        choosen_actions: Optional[List[Dict[str, Any]]] = None,
        chosen_actions: Optional[List[Dict[str, Any]]] = None,
        enable_tool: bool = True,
        from_plugin: bool = True,
        stream_id: Optional[str] = None,
@@ -171,7 +172,7 @@ class DefaultReplyer:
            extra_info: 额外信息,用于补充上下文
            reply_reason: 回复原因
            available_actions: 可用的动作信息字典
            choosen_actions: 已选动作
            chosen_actions: 已选动作
            enable_tool: 是否启用工具调用
            from_plugin: 是否来自插件
@@ -189,7 +190,7 @@ class DefaultReplyer:
        prompt, selected_expressions = await self.build_prompt_reply_context(
            extra_info=extra_info,
            available_actions=available_actions,
            choosen_actions=choosen_actions,
            chosen_actions=chosen_actions,
            enable_tool=enable_tool,
            reply_message=reply_message,
            reply_reason=reply_reason,
@@ -296,7 +297,7 @@ class DefaultReplyer:

        if not sender:
            return ""

        if sender == global_config.bot.nickname:
            return ""
@@ -352,7 +353,7 @@ class DefaultReplyer:

        return f"{expression_habits_title}\n{expression_habits_block}", selected_ids

    async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str:
    async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
        """构建记忆块

        Args:
@@ -369,7 +370,7 @@ class DefaultReplyer:
        instant_memory = None

        running_memories = await self.memory_activator.activate_memory_with_chat_history(
            target_message=target, chat_history_prompt=chat_history
            target_message=target, chat_history=chat_history
        )

        if global_config.memory.enable_instant_memory:
@@ -433,7 +434,7 @@ class DefaultReplyer:
            logger.error(f"工具信息获取失败: {e}")
            return ""

    def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
    def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]:
        """解析回复目标消息

        Args:
@@ -514,7 +515,7 @@ class DefaultReplyer:
        return name, result, duration

    def build_s4u_chat_history_prompts(
        self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
        self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
    ) -> Tuple[str, str]:
        """
        构建 s4u 风格的分离对话 prompt
@@ -530,16 +531,16 @@ class DefaultReplyer:
        bot_id = str(global_config.bot.qq_account)

        # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
        for msg_dict in message_list_before_now:
        for msg in message_list_before_now:
            try:
                msg_user_id = str(msg_dict.get("user_id"))
                reply_to = msg_dict.get("reply_to", "")
                msg_user_id = str(msg.user_info.user_id)
                reply_to = msg.reply_to
                _platform, reply_to_user_id = self._parse_reply_target(reply_to)
                if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
                    # bot 和目标用户的对话
                    core_dialogue_list.append(msg_dict)
                    core_dialogue_list.append(msg)
            except Exception as e:
                logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
                logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")

        # 构建背景对话 prompt
        all_dialogue_prompt = ""
@ -574,7 +575,6 @@ class DefaultReplyer:
|
|||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
|
|
@ -640,7 +640,7 @@ class DefaultReplyer:
|
|||
|
||||
action_descriptions = ""
|
||||
if available_actions:
|
||||
action_descriptions = "你可以做以下这些动作:\n"
|
||||
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
|
||||
for action_name, action_info in available_actions.items():
|
||||
action_description = action_info.description
|
||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||
|
|
@ -658,7 +658,7 @@ class DefaultReplyer:
|
|||
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
||||
|
||||
if choosen_action_descriptions:
|
||||
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
|
||||
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
|
||||
action_descriptions += choosen_action_descriptions
|
||||
|
||||
return action_descriptions
|
||||
|
|
@ -668,7 +668,7 @@ class DefaultReplyer:
|
|||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||
chosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[str, List[int]]:
|
||||
|
|
@ -679,7 +679,7 @@ class DefaultReplyer:
|
|||
extra_info: 额外信息,用于补充上下文
|
||||
reply_reason: 回复原因
|
||||
available_actions: 可用动作
|
||||
choosen_actions: 已选动作
|
||||
chosen_actions: 已选动作
|
||||
enable_timeout: 是否启用超时处理
|
||||
enable_tool: 是否启用工具调用
|
||||
reply_message: 回复的原始消息
|
||||
|
|
@ -712,27 +712,21 @@ class DefaultReplyer:
|
|||
|
||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
)
|
||||
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
|
||||
|
||||
# TODO: 修复!
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
|
||||
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
temp_msg_list_before_short,
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
|
|
@ -744,12 +738,12 @@ class DefaultReplyer:
|
|||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
|
||||
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"),
|
||||
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
|
|
@ -810,25 +804,9 @@ class DefaultReplyer:
|
|||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
# if is_group_chat:
|
||||
# chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
# else:
|
||||
# chat_target_name = "对方"
|
||||
# if self.chat_target_info:
|
||||
# chat_target_name = (
|
||||
# self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
|
||||
# )
|
||||
# chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
# "chat_target_private1", sender_name=chat_target_name
|
||||
# )
|
||||
# chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
# "chat_target_private2", sender_name=chat_target_name
|
||||
# )
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
temp_msg_list_before_long, user_id, sender
|
||||
message_list_before_now_long, user_id, sender
|
||||
)
|
||||
|
||||
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
|
||||
|
|
@ -879,7 +857,7 @@ class DefaultReplyer:
|
|||
reason: str,
|
||||
reply_to: str,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[str, List[int]]: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
|
@ -902,20 +880,16 @@ class DefaultReplyer:
|
|||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
# TODO: 修复!
|
||||
from src.common.data_models import temporarily_transform_class_to_dict
|
||||
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
temp_msg_list_before_now_half,
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行2个构建任务
|
||||
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather(
|
||||
(expression_habits_block, _), relation_info = await asyncio.gather(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||
self.build_relation_info(sender, target),
|
||||
)
|
||||
|
|
|
|||
|
|
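The s4u prompt builder above partitions history into core dialogue (the target speaking, or the bot replying to the target) and background chatter from everyone else. A minimal standalone sketch of that partition — the `(sender_id, reply_to_id, text)` tuple shape and the IDs are illustrative assumptions, not the repository's actual message type:

```python
from typing import List, Tuple

def split_dialogue(
    messages: List[Tuple[str, str, str]],  # (sender_id, reply_to_id, text) — assumed shape
    bot_id: str,
    target_id: str,
) -> Tuple[List[str], List[str]]:
    """Partition messages into core (bot<->target) and background dialogue."""
    core: List[str] = []
    background: List[str] = []
    for sender, reply_to, text in messages:
        # Core: the target speaking, or the bot replying to the target
        if sender == target_id or (sender == bot_id and reply_to == target_id):
            core.append(text)
        else:
            background.append(text)
    return core, background
```

The two lists are then rendered into separate prompt sections, so the model can weight the core exchange above ambient group noise.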
@@ -1,4 +1,4 @@
-import time  # 导入 time 模块以获取当前时间
+import time
 import random
 import re

@@ -6,14 +6,17 @@ from typing import List, Dict, Any, Tuple, Optional, Callable
 from rich.traceback import install

 from src.config.config import global_config
 from src.common.logger import get_logger
 from src.common.message_repository import find_messages, count_messages
 from src.common.data_models.database_data_model import DatabaseMessages
+from src.common.data_models.message_data_model import MessageAndActionModel
+from src.common.database.database_model import ActionRecords
 from src.common.database.database_model import Images
 from src.person_info.person_info import Person, get_person_id
 from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids

 install(extra_lines=3)
 logger = get_logger("chat_message_builder")


 def replace_user_references_sync(

@@ -349,7 +352,9 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
     return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)


-def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
+def get_raw_msg_before_timestamp_with_users(
+    timestamp: float, person_ids: list, limit: int = 0
+) -> List[DatabaseMessages]:
     """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
     limit: 限制返回的消息数量,0为不限制
     """

@@ -390,16 +395,16 @@ def num_new_messages_since_with_users(


 def _build_readable_messages_internal(
-    messages: List[Dict[str, Any]],
+    messages: List[MessageAndActionModel],
     replace_bot_name: bool = True,
     merge_messages: bool = False,
     timestamp_mode: str = "relative",
     truncate: bool = False,
     pic_id_mapping: Optional[Dict[str, str]] = None,
     pic_counter: int = 1,
     show_pic: bool = True,
-    message_id_list: Optional[List[Dict[str, Any]]] = None,
+    message_id_list: Optional[List[DatabaseMessages]] = None,
 ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
     # sourcery skip: use-getitem-for-re-match-groups
     """
     内部辅助函数,构建可读消息字符串和原始消息详情列表。

@@ -418,7 +423,7 @@ def _build_readable_messages_internal(
     if not messages:
         return "", [], pic_id_mapping or {}, pic_counter

-    message_details_raw: List[Tuple[float, str, str, bool]] = []
+    detailed_messages_raw: List[Tuple[float, str, str, bool]] = []

     # 使用传入的映射字典,如果没有则创建新的
     if pic_id_mapping is None:

@@ -426,25 +431,26 @@ def _build_readable_messages_internal(
     current_pic_counter = pic_counter

     # 创建时间戳到消息ID的映射,用于在消息前添加[id]标识符
-    timestamp_to_id = {}
+    timestamp_to_id_mapping: Dict[float, str] = {}
     if message_id_list:
-        for item in message_id_list:
-            message = item.get("message", {})
-            timestamp = message.get("time")
+        for msg in message_id_list:
+            timestamp = msg.time
             if timestamp is not None:
-                timestamp_to_id[timestamp] = item.get("id", "")
+                timestamp_to_id_mapping[timestamp] = msg.message_id

-    def process_pic_ids(content: str) -> str:
+    def process_pic_ids(content: Optional[str]) -> str:
         """处理内容中的图片ID,将其替换为[图片x]格式"""
         nonlocal current_pic_counter
+        if content is None:
+            logger.warning("Content is None when processing pic IDs.")
+            raise ValueError("Content is None")

         # 匹配 [picid:xxxxx] 格式
         pic_pattern = r"\[picid:([^\]]+)\]"

-        def replace_pic_id(match):
+        def replace_pic_id(match: re.Match) -> str:
-            nonlocal current_pic_counter
+            nonlocal pic_counter
             pic_id = match.group(1)

             if pic_id not in pic_id_mapping:
                 pic_id_mapping[pic_id] = f"图片{current_pic_counter}"
                 current_pic_counter += 1

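`process_pic_ids` rewrites `[picid:…]` placeholders into sequential `图片N` labels while reusing a shared mapping, so the same image ID always gets the same label across both halves of a read-mark split. A self-contained sketch of that mechanic (the function name and signature here are a simplification of the nested helper above):

```python
import re
from typing import Dict, Tuple

def process_pic_ids(content: str, mapping: Dict[str, str], counter: int) -> Tuple[str, int]:
    """Replace [picid:xxx] tokens with stable 图片N labels, reusing `mapping` across calls."""
    def _replace(match: re.Match) -> str:
        nonlocal counter
        pic_id = match.group(1)
        if pic_id not in mapping:
            # First sighting of this image ID: assign the next sequential label
            mapping[pic_id] = f"图片{counter}"
            counter += 1
        return mapping[pic_id]

    return re.sub(r"\[picid:([^\]]+)\]", _replace, content), counter

mapping: Dict[str, str] = {}
text, n = process_pic_ids("看这个[picid:abc]和[picid:abc]还有[picid:def]", mapping, 1)
# 同一图片ID始终映射到同一个标签
```

Passing the mapping and counter back out is what lets the caller thread them through consecutive message batches.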
@@ -453,42 +459,23 @@ def _build_readable_messages_internal(

         return re.sub(pic_pattern, replace_pic_id, content)

-    # 1 & 2: 获取发送者信息并提取消息组件
-    for msg in messages:
-        # 检查是否是动作记录
-        if msg.get("is_action_record", False):
-            is_action = True
-            timestamp: float = msg.get("time")  # type: ignore
-            content = msg.get("display_message", "")
+    # 1: 获取发送者信息并提取消息组件
+    for message in messages:
+        if message.is_action_record:
             # 对于动作记录,也处理图片ID
-            content = process_pic_ids(content)
-            message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
+            content = process_pic_ids(message.display_message)
+            detailed_messages_raw.append((message.time, message.user_nickname, content, True))
             continue

-        # 检查并修复缺少的user_info字段
-        if "user_info" not in msg:
-            # 创建user_info字段
-            msg["user_info"] = {
-                "platform": msg.get("user_platform", ""),
-                "user_id": msg.get("user_id", ""),
-                "user_nickname": msg.get("user_nickname", ""),
-                "user_cardname": msg.get("user_cardname", ""),
-            }
+        platform = message.user_platform
+        user_id = message.user_id
+        user_nickname = message.user_nickname
+        user_cardname = message.user_cardname

-        user_info = msg.get("user_info", {})
-        platform = user_info.get("platform")
-        user_id = user_info.get("user_id")
-        user_nickname = user_info.get("user_nickname")
-        user_cardname = user_info.get("user_cardname")

-        timestamp: float = msg.get("time")  # type: ignore
-        content: str
-        if msg.get("display_message"):
-            content = msg.get("display_message", "")
-        else:
-            content = msg.get("processed_plain_text", "")  # 默认空字符串
+        timestamp = message.time
+        content = message.display_message or message.processed_plain_text or ""

         # 向下兼容
         if "ᶠ" in content:
             content = content.replace("ᶠ", "")
         if "ⁿ" in content:

@@ -504,52 +491,32 @@ def _build_readable_messages_internal(

         person = Person(platform=platform, user_id=user_id)
         # 根据 replace_bot_name 参数决定是否替换机器人名称
-        person_name: str
+        person_name = (
+            person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
+        )
         if replace_bot_name and user_id == global_config.bot.qq_account:
             person_name = f"{global_config.bot.nickname}(你)"
-        else:
-            person_name = person.person_name or user_id  # type: ignore

-        # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
-        if not person_name:
-            if user_cardname:
-                person_name = f"昵称:{user_cardname}"
-            elif user_nickname:
-                person_name = f"{user_nickname}"
-            else:
-                person_name = "某人"

         # 使用独立函数处理用户引用格式
-        content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
+        if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name):
+            detailed_messages_raw.append((timestamp, person_name, content, False))

-        target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
-        if target_str in content and random.random() < 0.6:
-            content = content.replace(target_str, "")

-        if content != "":
-            message_details_raw.append((timestamp, person_name, content, False))

-    if not message_details_raw:
+    if not detailed_messages_raw:
         return "", [], pic_id_mapping, current_pic_counter

-    message_details_raw.sort(key=lambda x: x[0])  # 按时间戳(第一个元素)升序排序,越早的消息排在前面
+    detailed_messages_raw.sort(key=lambda x: x[0])  # 按时间戳(第一个元素)升序排序,越早的消息排在前面
+    detailed_message: List[Tuple[float, str, str, bool]] = []

-    # 为每条消息添加一个标记,指示它是否是动作记录
-    message_details_with_flags = []
-    for timestamp, name, content, is_action in message_details_raw:
-        message_details_with_flags.append((timestamp, name, content, is_action))

-    # 应用截断逻辑 (如果 truncate 为 True)
-    message_details: List[Tuple[float, str, str, bool]] = []
-    n_messages = len(message_details_with_flags)
-    if truncate and n_messages > 0:
-        for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
+    # 2. 应用消息截断逻辑
+    messages_count = len(detailed_messages_raw)
+    if truncate and messages_count > 0:
+        for i, (timestamp, name, content, is_action) in enumerate(detailed_messages_raw):
             # 对于动作记录,不进行截断
             if is_action:
-                message_details.append((timestamp, name, content, is_action))
+                detailed_message.append((timestamp, name, content, is_action))
                 continue

-            percentile = i / n_messages  # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
+            percentile = i / messages_count  # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
             original_len = len(content)
             limit = -1  # 默认不截断

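The truncation pass above limits each message's length by its position percentile: older messages get tighter limits, the newest tier keeps up to 400 characters. A standalone sketch of that positional policy — only the 200- and 400-character tiers are visible in the diff, so the earlier tier bounds here are assumptions for illustration:

```python
from typing import List, Sequence, Tuple

def truncate_by_position(
    messages: List[str],
    # (upper percentile bound, char limit); early tiers are assumed, last two match the diff
    tiers: Sequence[Tuple[float, int]] = ((0.2, 50), (0.5, 100), (0.7, 200), (1.01, 400)),
) -> List[str]:
    """Older messages (low percentile) get tighter length limits."""
    out: List[str] = []
    n = len(messages)
    for i, content in enumerate(messages):
        percentile = i / n  # 0 <= percentile < 1, oldest first
        limit = next(lim for bound, lim in tiers if percentile < bound)
        if 0 < limit < len(content):
            content = content[:limit] + "......(内容太长了)"
        out.append(content)
    return out
```

The idea is a cheap recency bias: context budget is spent where the model needs detail most, at the tail of the conversation.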
@@ -562,116 +529,42 @@ def _build_readable_messages_internal(
             elif percentile < 0.7:  # 60% 到 80% 之前的消息 (即中间的 20%)
                 limit = 200
                 replace_content = "......(内容太长了)"
-            elif percentile < 1.0:  # 80% 到 100% 之前的消息 (即较新的 20%)
+            elif percentile <= 1.0:  # 80% 到 100% 之前的消息 (即较新的 20%)
                 limit = 400
-                replace_content = "......(太长了)"
+                replace_content = "......(内容太长了)"

             truncated_content = content
             if 0 < limit < original_len:
                 truncated_content = f"{content[:limit]}{replace_content}"

-            message_details.append((timestamp, name, truncated_content, is_action))
+            detailed_message.append((timestamp, name, truncated_content, is_action))
     else:
         # 如果不截断,直接使用原始列表
-        message_details = message_details_with_flags
+        detailed_message = detailed_messages_raw

-    # 3: 合并连续消息 (如果 merge_messages 为 True)
-    merged_messages = []
-    if merge_messages and message_details:
-        # 初始化第一个合并块
-        current_merge = {
-            "name": message_details[0][1],
-            "start_time": message_details[0][0],
-            "end_time": message_details[0][0],
-            "content": [message_details[0][2]],
-            "is_action": message_details[0][3],
-        }
+    # 3: 格式化为字符串
+    output_lines: List[str] = []

-        for i in range(1, len(message_details)):
-            timestamp, name, content, is_action = message_details[i]
+    for timestamp, name, content, is_action in detailed_message:
+        readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)

-            # 对于动作记录,不进行合并
-            if is_action or current_merge["is_action"]:
-                # 保存当前的合并块
-                merged_messages.append(current_merge)
-                # 创建新的块
-                current_merge = {
-                    "name": name,
-                    "start_time": timestamp,
-                    "end_time": timestamp,
-                    "content": [content],
-                    "is_action": is_action,
-                }
-                continue
+        # 查找消息id(如果有)并构建id_prefix
+        message_id = timestamp_to_id_mapping.get(timestamp)
+        id_prefix = f"[{message_id}]" if message_id else ""

-            # 如果是同一个人发送的连续消息且时间间隔小于等于60秒
-            if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
-                current_merge["content"].append(content)
-                current_merge["end_time"] = timestamp  # 更新最后消息时间
-            else:
-                # 保存上一个合并块
-                merged_messages.append(current_merge)
-                # 开始新的合并块
-                current_merge = {
-                    "name": name,
-                    "start_time": timestamp,
-                    "end_time": timestamp,
-                    "content": [content],
-                    "is_action": is_action,
-                }
-        # 添加最后一个合并块
-        merged_messages.append(current_merge)
-    elif message_details:  # 如果不合并消息,则每个消息都是一个独立的块
-        for timestamp, name, content, is_action in message_details:
-            merged_messages.append(
-                {
-                    "name": name,
-                    "start_time": timestamp,  # 起始和结束时间相同
-                    "end_time": timestamp,
-                    "content": [content],  # 内容只有一个元素
-                    "is_action": is_action,
-                }
-            )

-    # 4 & 5: 格式化为字符串
-    output_lines = []

-    for _i, merged in enumerate(merged_messages):
-        # 使用指定的 timestamp_mode 格式化时间
-        readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)

-        # 查找对应的消息ID
-        message_id = timestamp_to_id.get(merged["start_time"], "")
-        id_prefix = f"[{message_id}] " if message_id else ""

-        # 检查是否是动作记录
-        if merged["is_action"]:
+        if is_action:
             # 对于动作记录,使用特殊格式
-            output_lines.append(f"{id_prefix}{readable_time}, {merged['content'][0]}")
+            output_lines.append(f"{id_prefix}{readable_time}, {content}")
         else:
-            header = f"{id_prefix}{readable_time}, {merged['name']} :"
-            output_lines.append(header)
-            # 将内容合并,并添加缩进
-            for line in merged["content"]:
-                stripped_line = line.strip()
-                if stripped_line:  # 过滤空行
-                    # 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留
-                    if stripped_line.endswith("。"):
-                        stripped_line = stripped_line[:-1]
-                    # 如果内容被截断,结尾已经是 ...(内容太长),不再添加分号
-                    if not stripped_line.endswith("(内容太长)"):
-                        output_lines.append(f"{stripped_line}")
-                    else:
-                        output_lines.append(stripped_line)  # 直接添加截断后的内容
+            output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
         output_lines.append("\n")  # 在每个消息块后添加换行,保持可读性

     # 移除可能的多余换行,然后合并
     formatted_string = "".join(output_lines).strip()

     # 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器
     return (
         formatted_string,
-        [(t, n, c) for t, n, c, is_action in message_details if not is_action],
+        [(t, n, c) for t, n, c, is_action in detailed_message if not is_action],
         pic_id_mapping,
         current_pic_counter,
     )

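The block removed above merged consecutive messages from the same sender when they arrived at most 60 seconds apart, tracking a running block with start/end times. A minimal sketch of that dropped rule, using plain tuples and dicts rather than the repository's types:

```python
from typing import Dict, List, Tuple

def merge_consecutive(
    messages: List[Tuple[float, str, str]],  # (timestamp, sender name, content)
    window: float = 60.0,
) -> List[Dict]:
    """Merge runs of messages from the same sender arriving within `window` seconds."""
    merged: List[Dict] = []
    for ts, name, content in messages:
        last = merged[-1] if merged else None
        if last and last["name"] == name and ts - last["end_time"] <= window:
            # Same sender, close enough in time: extend the current block
            last["content"].append(content)
            last["end_time"] = ts
        else:
            merged.append({"name": name, "start_time": ts, "end_time": ts, "content": [content]})
    return merged
```

The new code formats every message on its own line instead, which keeps per-message `[id]` prefixes unambiguous at the cost of a longer prompt.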
@@ -755,9 +648,8 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:


 async def build_readable_messages_with_list(
-    messages: List[Dict[str, Any]],
+    messages: List[DatabaseMessages],
     replace_bot_name: bool = True,
-    merge_messages: bool = False,
     timestamp_mode: str = "relative",
     truncate: bool = False,
 ) -> Tuple[str, List[Tuple[float, str, str]]]:

@@ -766,7 +658,7 @@ async def build_readable_messages_with_list(
     允许通过参数控制格式化行为。
     """
     formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
-        messages, replace_bot_name, merge_messages, timestamp_mode, truncate
+        convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
     )

     if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):

@@ -776,15 +668,14 @@ async def build_readable_messages_with_list(


 def build_readable_messages_with_id(
-    messages: List[Dict[str, Any]],
+    messages: List[DatabaseMessages],
     replace_bot_name: bool = True,
-    merge_messages: bool = False,
     timestamp_mode: str = "relative",
     read_mark: float = 0.0,
     truncate: bool = False,
     show_actions: bool = False,
     show_pic: bool = True,
-) -> Tuple[str, List[Dict[str, Any]]]:
+) -> Tuple[str, List[DatabaseMessages]]:
     """
     将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
     允许通过参数控制格式化行为。

@@ -794,7 +685,6 @@ def build_readable_messages_with_id(
     formatted_string = build_readable_messages(
         messages=messages,
         replace_bot_name=replace_bot_name,
-        merge_messages=merge_messages,
         timestamp_mode=timestamp_mode,
         truncate=truncate,
         show_actions=show_actions,

@@ -807,15 +697,14 @@ def build_readable_messages_with_id(


 def build_readable_messages(
-    messages: List[Dict[str, Any]],
+    messages: List[DatabaseMessages],
     replace_bot_name: bool = True,
-    merge_messages: bool = False,
     timestamp_mode: str = "relative",
     read_mark: float = 0.0,
     truncate: bool = False,
     show_actions: bool = False,
     show_pic: bool = True,
-    message_id_list: Optional[List[Dict[str, Any]]] = None,
+    message_id_list: Optional[List[DatabaseMessages]] = None,
 ) -> str:  # sourcery skip: extract-method
     """
     将消息列表转换为可读的文本格式。

@@ -831,19 +720,32 @@ def build_readable_messages(
         truncate: 是否截断长消息
         show_actions: 是否显示动作记录
     """
+    # WIP HERE and BELOW ----------------------------------------------
-    # 创建messages的深拷贝,避免修改原始列表
-    copy_messages = [msg.copy() for msg in messages]
+    if not messages:
+        return ""
+
+    copy_messages: List[MessageAndActionModel] = [
+        MessageAndActionModel(
+            msg.time,
+            msg.user_info.user_id,
+            msg.user_info.platform,
+            msg.user_info.user_nickname,
+            msg.user_info.user_cardname,
+            msg.processed_plain_text,
+            msg.display_message,
+            msg.chat_info.platform,
+        )
+        for msg in messages
+    ]

     if show_actions and copy_messages:
         # 获取所有消息的时间范围
-        min_time = min(msg.get("time", 0) for msg in copy_messages)
-        max_time = max(msg.get("time", 0) for msg in copy_messages)
+        min_time = min(msg.time or 0 for msg in copy_messages)
+        max_time = max(msg.time or 0 for msg in copy_messages)

         # 从第一条消息中获取chat_id
-        chat_id = copy_messages[0].get("chat_id") if copy_messages else None
+        chat_id = messages[0].chat_id if messages else None

         # 获取这个时间范围内的动作记录,并匹配chat_id
         actions_in_range = (

@@ -863,34 +765,34 @@ def build_readable_messages(
         )

         # 合并两部分动作记录
-        actions = list(actions_in_range) + list(action_after_latest)
+        actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)

         # 将动作记录转换为消息格式
         for action in actions:
             # 只有当build_into_prompt为True时才添加动作记录
             if action.action_build_into_prompt:
-                action_msg = {
-                    "time": action.time,
-                    "user_id": global_config.bot.qq_account,  # 使用机器人的QQ账号
-                    "user_nickname": global_config.bot.nickname,  # 使用机器人的昵称
-                    "user_cardname": "",  # 机器人没有群名片
-                    "processed_plain_text": f"{action.action_prompt_display}",
-                    "display_message": f"{action.action_prompt_display}",
-                    "chat_info_platform": action.chat_info_platform,
-                    "is_action_record": True,  # 添加标识字段
-                    "action_name": action.action_name,  # 保存动作名称
-                }
+                action_msg = MessageAndActionModel(
+                    time=float(action.time),  # type: ignore
+                    user_id=global_config.bot.qq_account,  # 使用机器人的QQ账号
+                    user_platform=global_config.bot.platform,  # 使用机器人的平台
+                    user_nickname=global_config.bot.nickname,  # 使用机器人的用户名
+                    user_cardname="",  # 机器人没有群名片
+                    processed_plain_text=f"{action.action_prompt_display}",
+                    display_message=f"{action.action_prompt_display}",
+                    chat_info_platform=str(action.chat_info_platform),
+                    is_action_record=True,  # 添加标识字段
+                    action_name=str(action.action_name),  # 保存动作名称
+                )
                 copy_messages.append(action_msg)

         # 重新按时间排序
-        copy_messages.sort(key=lambda x: x.get("time", 0))
+        copy_messages.sort(key=lambda x: x.time or 0)

     if read_mark <= 0:
         # 没有有效的 read_mark,直接格式化所有消息
         formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
             copy_messages,
             replace_bot_name,
-            merge_messages,
             timestamp_mode,
             truncate,
             show_pic=show_pic,

@@ -905,8 +807,8 @@ def build_readable_messages(
         return formatted_string
     else:
         # 按 read_mark 分割消息
-        messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
-        messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
+        messages_before_mark = [msg for msg in copy_messages if (msg.time or 0) <= read_mark]
+        messages_after_mark = [msg for msg in copy_messages if (msg.time or 0) > read_mark]

         # 共享的图片映射字典和计数器
         pic_id_mapping = {}

@@ -916,7 +818,6 @@ def build_readable_messages(
         formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
             messages_before_mark,
             replace_bot_name,
-            merge_messages,
             timestamp_mode,
             truncate,
             pic_id_mapping,

@@ -927,7 +828,6 @@ def build_readable_messages(
         formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
             messages_after_mark,
             replace_bot_name,
-            merge_messages,
             timestamp_mode,
             False,
             pic_id_mapping,

@@ -960,13 +860,13 @@ def build_readable_messages(
         return "".join(result_parts)


-async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
+async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
     """
     构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
     处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
     """
     if not messages:
-        print("111111111111没有消息,无法构建匿名消息")
+        logger.warning("没有消息,无法构建匿名消息")
         return ""

     person_map = {}

@@ -1017,14 +917,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:

     for msg in messages:
         try:
-            platform: str = msg.get("chat_info_platform")  # type: ignore
-            user_id = msg.get("user_id")
-            _timestamp = msg.get("time")
-            content: str = ""
-            if msg.get("display_message"):
-                content = msg.get("display_message", "")
-            else:
-                content = msg.get("processed_plain_text", "")
+            platform = msg.chat_info.platform
+            user_id = msg.user_info.user_id
+            content = msg.display_message or msg.processed_plain_text or ""

             if "ᶠ" in content:
                 content = content.replace("ᶠ", "")

@@ -1101,3 +996,22 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
             person_ids_set.add(person_id)

     return list(person_ids_set)  # 将集合转换为列表返回
+
+
+def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
+    """
+    将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。
+    """
+    return [
+        MessageAndActionModel(
+            time=msg.time,
+            user_id=msg.user_info.user_id,
+            user_platform=msg.user_info.platform,
+            user_nickname=msg.user_info.user_nickname,
+            user_cardname=msg.user_info.user_cardname,
+            processed_plain_text=msg.processed_plain_text,
+            display_message=msg.display_message,
+            chat_info_platform=msg.chat_info.platform,
+        )
+        for msg in message
+    ]

@ -12,6 +12,7 @@ from typing import Optional, Tuple, Dict, List, Any
|
|||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
|
|
@ -113,6 +114,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
|
|
@ -151,10 +153,13 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
|
|||
if (
|
||||
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
|
||||
and db_msg.user_info.user_id != global_config.bot.qq_account
|
||||
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
|
||||
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
|
||||
not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
|
||||
who_chat_in_group.append(
|
||||
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
|
||||
)
|
||||
|
||||
return who_chat_in_group
|
||||
|
||||
|
|
@ -640,9 +645,9 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||
target_info = TargetPersonInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_info.user_nickname, # type: ignore
|
||||
user_nickname=user_info.user_nickname, # type: ignore
|
||||
person_id=None,
|
||||
person_name=None
|
||||
person_name=None,
|
||||
)
|
||||
|
||||
# Try to fetch person info
|
||||
|
|
@@ -669,17 +674,17 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
     return is_group_chat, chat_target_info


-def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
+def assign_message_ids(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
     """
     为消息列表中的每个消息分配唯一的简短随机ID

     Args:
         messages: 消息列表

     Returns:
-        包含 {'id': str, 'message': any} 格式的字典列表
+        List[DatabaseMessages]: 分配了唯一ID的消息列表(写入message_id属性)
     """
-    result = []
+    result: List[DatabaseMessages] = list(messages)  # 复制原始消息列表
     used_ids = set()
     len_i = len(messages)
     if len_i > 100:
@@ -688,95 +693,86 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
     else:
         a = 1
         b = 9

-    for i, message in enumerate(messages):
+    for i, _ in enumerate(result):
         # 生成唯一的简短ID
         while True:
             # 使用索引+随机数生成简短ID
             random_suffix = random.randint(a, b)
-            message_id = f"m{i+1}{random_suffix}"
+            message_id = f"m{i + 1}{random_suffix}"

             if message_id not in used_ids:
                 used_ids.add(message_id)
                 break

-        result.append({
-            'id': message_id,
-            'message': message
-        })
+        result[i].message_id = message_id

     return result


-def assign_message_ids_flexible(
-    messages: list,
-    prefix: str = "msg",
-    id_length: int = 6,
-    use_timestamp: bool = False
-) -> list:
-    """
-    为消息列表中的每个消息分配唯一的简短随机ID(增强版)
-
-    Args:
-        messages: 消息列表
-        prefix: ID前缀,默认为"msg"
-        id_length: ID的总长度(不包括前缀),默认为6
-        use_timestamp: 是否在ID中包含时间戳,默认为False
-
-    Returns:
-        包含 {'id': str, 'message': any} 格式的字典列表
-    """
-    result = []
-    used_ids = set()
-
-    for i, message in enumerate(messages):
-        # 生成唯一的ID
-        while True:
-            if use_timestamp:
-                # 使用时间戳的后几位 + 随机字符
-                timestamp_suffix = str(int(time.time() * 1000))[-3:]
-                remaining_length = id_length - 3
-                random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
-                message_id = f"{prefix}{timestamp_suffix}{random_chars}"
-            else:
-                # 使用索引 + 随机字符
-                index_str = str(i + 1)
-                remaining_length = max(1, id_length - len(index_str))
-                random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
-                message_id = f"{prefix}{index_str}{random_chars}"
-
-            if message_id not in used_ids:
-                used_ids.add(message_id)
-                break
-
-        result.append({
-            'id': message_id,
-            'message': message
-        })
-
-    return result
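> 备注:上面这段 diff 把 `assign_message_ids` 从「返回 `{'id', 'message'}` 字典列表」改成了「就地写入 `message_id` 属性」。下面是一个独立的极简示意(与仓库实际类型无关,这里用普通字典代替 `DatabaseMessages`,`assign_ids` 及「大列表扩大随机区间」均为假设性写法,仅用于说明思路):

```python
import random

def assign_ids(items, rng=random):
    """为对象列表就地写入唯一短 ID(示意,items 用 dict 代替 DatabaseMessages)"""
    used = set()
    # 假设:列表较大时扩大随机后缀区间以降低碰撞重试次数
    a, b = (1, 9) if len(items) <= 100 else (1, 99)
    for i, item in enumerate(items):
        while True:
            # 索引 + 随机后缀,碰撞则重试
            mid = f"m{i + 1}{rng.randint(a, b)}"
            if mid not in used:
                used.add(mid)
                break
        item["message_id"] = mid  # 就地写入,而不是包一层 {'id': ..., 'message': ...}
    return items
```

与旧实现相比,就地写入避免了调用方再从包装字典里取 `message` 的一层解包。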
+# def assign_message_ids_flexible(
+#     messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
+# ) -> list:
+#     """
+#     为消息列表中的每个消息分配唯一的简短随机ID(增强版)
+
+#     Args:
+#         messages: 消息列表
+#         prefix: ID前缀,默认为"msg"
+#         id_length: ID的总长度(不包括前缀),默认为6
+#         use_timestamp: 是否在ID中包含时间戳,默认为False
+
+#     Returns:
+#         包含 {'id': str, 'message': any} 格式的字典列表
+#     """
+#     result = []
+#     used_ids = set()
+
+#     for i, message in enumerate(messages):
+#         # 生成唯一的ID
+#         while True:
+#             if use_timestamp:
+#                 # 使用时间戳的后几位 + 随机字符
+#                 timestamp_suffix = str(int(time.time() * 1000))[-3:]
+#                 remaining_length = id_length - 3
+#                 random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
+#                 message_id = f"{prefix}{timestamp_suffix}{random_chars}"
+#             else:
+#                 # 使用索引 + 随机字符
+#                 index_str = str(i + 1)
+#                 remaining_length = max(1, id_length - len(index_str))
+#                 random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
+#                 message_id = f"{prefix}{index_str}{random_chars}"
+
+#             if message_id not in used_ids:
+#                 used_ids.add(message_id)
+#                 break
+
+#         result.append({"id": message_id, "message": message})
+
+#     return result


 # 使用示例:
 # messages = ["Hello", "World", "Test message"]
 #
 #
 # # 基础版本
 # result1 = assign_message_ids(messages)
 # # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
 #
 #
 # # 增强版本 - 自定义前缀和长度
 # result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
 # # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
 #
 #
 # # 增强版本 - 使用时间戳
 # result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
 # # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]


 def parse_keywords_string(keywords_input) -> list[str]:
     # sourcery skip: use-contextlib-suppress
     """
     统一的关键词解析函数,支持多种格式的关键词字符串解析

     支持的格式:
     1. 字符串列表格式:'["utils.py", "修改", "代码", "动作"]'
     2. 斜杠分隔格式:'utils.py/修改/代码/动作'
@@ -784,25 +780,25 @@ def parse_keywords_string(keywords_input) -> list[str]:
     4. 空格分隔格式:'utils.py 修改 代码 动作'
     5. 已经是列表的情况:["utils.py", "修改", "代码", "动作"]
     6. JSON格式字符串:'{"keywords": ["utils.py", "修改", "代码", "动作"]}'

     Args:
         keywords_input: 关键词输入,可以是字符串或列表

     Returns:
         list[str]: 解析后的关键词列表,去除空白项
     """
     if not keywords_input:
         return []

     # 如果已经是列表,直接处理
     if isinstance(keywords_input, list):
         return [str(k).strip() for k in keywords_input if str(k).strip()]

     # 转换为字符串处理
     keywords_str = str(keywords_input).strip()
     if not keywords_str:
         return []

     try:
         # 尝试作为JSON对象解析(支持 {"keywords": [...]} 格式)
         json_data = json.loads(keywords_str)

@@ -815,7 +811,7 @@ def parse_keywords_string(keywords_input) -> list[str]:
             return [str(k).strip() for k in json_data if str(k).strip()]
     except (json.JSONDecodeError, ValueError):
         pass

     try:
         # 尝试使用 ast.literal_eval 解析(支持Python字面量格式)
         parsed = ast.literal_eval(keywords_str)

@@ -823,15 +819,15 @@ def parse_keywords_string(keywords_input) -> list[str]:
             return [str(k).strip() for k in parsed if str(k).strip()]
     except (ValueError, SyntaxError):
         pass

     # 尝试不同的分隔符
-    separators = ['/', ',', ' ', '|', ';']
+    separators = ["/", ",", " ", "|", ";"]

     for separator in separators:
         if separator in keywords_str:
             keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()]
             if len(keywords_list) > 1:  # 确保分割有效
                 return keywords_list

     # 如果没有分隔符,返回单个关键词
-    return [keywords_str] if keywords_str else []
+    return [keywords_str] if keywords_str else []
@@ -1,26 +1,28 @@
-from typing import Dict, Any
+import copy
+from typing import Any


-class AbstractClassFlag:
-    pass
+class BaseDataModel:
+    def deepcopy(self):
+        return copy.deepcopy(self)


 def temporarily_transform_class_to_dict(obj: Any) -> Any:
     # sourcery skip: assign-if-exp, reintroduce-else
     """
-    将对象或容器中的 AbstractClassFlag 子类(类对象)或 AbstractClassFlag 实例
+    将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例
     递归转换为普通 dict,不修改原对象。
-    - 对于类对象(isinstance(value, type) 且 issubclass(..., AbstractClassFlag)),
+    - 对于类对象(isinstance(value, type) 且 issubclass(..., BaseDataModel)),
      读取类的 __dict__ 中非 dunder 项并递归转换。
-    - 对于实例(isinstance(value, AbstractClassFlag)),读取 vars(instance) 并递归转换。
+    - 对于实例(isinstance(value, BaseDataModel)),读取 vars(instance) 并递归转换。
     """

     def _transform(value: Any) -> Any:
-        # 值是类对象且为 AbstractClassFlag 的子类
-        if isinstance(value, type) and issubclass(value, AbstractClassFlag):
+        # 值是类对象且为 BaseDataModel 的子类
+        if isinstance(value, type) and issubclass(value, BaseDataModel):
             return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)}

-        # 值是 AbstractClassFlag 的实例
-        if isinstance(value, AbstractClassFlag):
+        # 值是 BaseDataModel 的实例
+        if isinstance(value, BaseDataModel):
             return {k: _transform(v) for k, v in vars(value).items()}

         # 常见容器类型,递归处理
@@ -1,36 +1,41 @@
-from typing import Optional, Dict, Any
-from dataclasses import dataclass, field, fields, MISSING
+from typing import Optional, Any
+from dataclasses import dataclass, field

-from . import AbstractClassFlag
+from . import BaseDataModel


 @dataclass
-class DatabaseUserInfo(AbstractClassFlag):
+class DatabaseUserInfo(BaseDataModel):
     platform: str = field(default_factory=str)
     user_id: str = field(default_factory=str)
     user_nickname: str = field(default_factory=str)
     user_cardname: Optional[str] = None

-    def __post_init__(self):
-        assert isinstance(self.platform, str), "platform must be a string"
-        assert isinstance(self.user_id, str), "user_id must be a string"
-        assert isinstance(self.user_nickname, str), "user_nickname must be a string"
-        assert isinstance(self.user_cardname, str) or self.user_cardname is None, "user_cardname must be a string or None"
+    # def __post_init__(self):
+    #     assert isinstance(self.platform, str), "platform must be a string"
+    #     assert isinstance(self.user_id, str), "user_id must be a string"
+    #     assert isinstance(self.user_nickname, str), "user_nickname must be a string"
+    #     assert isinstance(self.user_cardname, str) or self.user_cardname is None, (
+    #         "user_cardname must be a string or None"
+    #     )


 @dataclass
-class DatabaseGroupInfo(AbstractClassFlag):
+class DatabaseGroupInfo(BaseDataModel):
     group_id: str = field(default_factory=str)
     group_name: str = field(default_factory=str)
     group_platform: Optional[str] = None

-    def __post_init__(self):
-        assert isinstance(self.group_id, str), "group_id must be a string"
-        assert isinstance(self.group_name, str), "group_name must be a string"
-        assert isinstance(self.group_platform, str) or self.group_platform is None, "group_platform must be a string or None"
+    # def __post_init__(self):
+    #     assert isinstance(self.group_id, str), "group_id must be a string"
+    #     assert isinstance(self.group_name, str), "group_name must be a string"
+    #     assert isinstance(self.group_platform, str) or self.group_platform is None, (
+    #         "group_platform must be a string or None"
+    #     )


 @dataclass
-class DatabaseChatInfo(AbstractClassFlag):
+class DatabaseChatInfo(BaseDataModel):
     stream_id: str = field(default_factory=str)
     platform: str = field(default_factory=str)
     create_time: float = field(default_factory=float)

@@ -38,123 +43,117 @@ class DatabaseChatInfo(AbstractClassFlag):
     user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
     group_info: Optional[DatabaseGroupInfo] = None

-    def __post_init__(self):
-        assert isinstance(self.stream_id, str), "stream_id must be a string"
-        assert isinstance(self.platform, str), "platform must be a string"
-        assert isinstance(self.create_time, float), "create_time must be a float"
-        assert isinstance(self.last_active_time, float), "last_active_time must be a float"
-        assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
-        assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, "group_info must be a DatabaseGroupInfo instance or None"
+    # def __post_init__(self):
+    #     assert isinstance(self.stream_id, str), "stream_id must be a string"
+    #     assert isinstance(self.platform, str), "platform must be a string"
+    #     assert isinstance(self.create_time, float), "create_time must be a float"
+    #     assert isinstance(self.last_active_time, float), "last_active_time must be a float"
+    #     assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance"
+    #     assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, (
+    #         "group_info must be a DatabaseGroupInfo instance or None"
+    #     )


-@dataclass(init=False)
-class DatabaseMessages(AbstractClassFlag):
-    # chat_info: DatabaseChatInfo
-    # user_info: DatabaseUserInfo
-    # group_info: Optional[DatabaseGroupInfo] = None
+class DatabaseMessages(BaseDataModel):
+    def __init__(
+        self,
+        message_id: str = "",
+        time: float = 0.0,
+        chat_id: str = "",
+        reply_to: Optional[str] = None,
+        interest_value: Optional[float] = None,
+        key_words: Optional[str] = None,
+        key_words_lite: Optional[str] = None,
+        is_mentioned: Optional[bool] = None,
+        processed_plain_text: Optional[str] = None,
+        display_message: Optional[str] = None,
+        priority_mode: Optional[str] = None,
+        priority_info: Optional[str] = None,
+        additional_config: Optional[str] = None,
+        is_emoji: bool = False,
+        is_picid: bool = False,
+        is_command: bool = False,
+        is_notify: bool = False,
+        selected_expressions: Optional[str] = None,
+        user_id: str = "",
+        user_nickname: str = "",
+        user_cardname: Optional[str] = None,
+        user_platform: str = "",
+        chat_info_group_id: Optional[str] = None,
+        chat_info_group_name: Optional[str] = None,
+        chat_info_group_platform: Optional[str] = None,
+        chat_info_user_id: str = "",
+        chat_info_user_nickname: str = "",
+        chat_info_user_cardname: Optional[str] = None,
+        chat_info_user_platform: str = "",
+        chat_info_stream_id: str = "",
+        chat_info_platform: str = "",
+        chat_info_create_time: float = 0.0,
+        chat_info_last_active_time: float = 0.0,
+        **kwargs: Any,
+    ):
+        self.message_id = message_id
+        self.time = time
+        self.chat_id = chat_id
+        self.reply_to = reply_to
+        self.interest_value = interest_value
-    message_id: str = field(default_factory=str)
-    time: float = field(default_factory=float)
-    chat_id: str = field(default_factory=str)
-    reply_to: Optional[str] = None
-    interest_value: Optional[float] = None
+        self.key_words = key_words
+        self.key_words_lite = key_words_lite
+        self.is_mentioned = is_mentioned
-    key_words: Optional[str] = None
-    key_words_lite: Optional[str] = None
-    is_mentioned: Optional[bool] = None
+        self.processed_plain_text = processed_plain_text
+        self.display_message = display_message

-    # 从 chat_info 扁平化而来的字段
-    # chat_info_stream_id: str = field(default_factory=str)
-    # chat_info_platform: str = field(default_factory=str)
-    # chat_info_user_platform: str = field(default_factory=str)
-    # chat_info_user_id: str = field(default_factory=str)
-    # chat_info_user_nickname: str = field(default_factory=str)
-    # chat_info_user_cardname: Optional[str] = None
-    # chat_info_group_platform: Optional[str] = None
-    # chat_info_group_id: Optional[str] = None
-    # chat_info_group_name: Optional[str] = None
-    # chat_info_create_time: float = field(default_factory=float)
-    # chat_info_last_active_time: float = field(default_factory=float)
+        self.priority_mode = priority_mode
+        self.priority_info = priority_info

-    # 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
-    # user_platform: str = field(default_factory=str)
-    # user_id: str = field(default_factory=str)
-    # user_nickname: str = field(default_factory=str)
-    # user_cardname: Optional[str] = None
+        self.additional_config = additional_config
+        self.is_emoji = is_emoji
+        self.is_picid = is_picid
+        self.is_command = is_command
+        self.is_notify = is_notify

-    processed_plain_text: Optional[str] = None  # 处理后的纯文本消息
-    display_message: Optional[str] = None  # 显示的消息

-    priority_mode: Optional[str] = None
-    priority_info: Optional[str] = None

-    additional_config: Optional[str] = None
-    is_emoji: bool = False
-    is_picid: bool = False
-    is_command: bool = False
-    is_notify: bool = False

-    selected_expressions: Optional[str] = None

-    # def __post_init__(self):

-    #     if self.chat_info_group_id and self.chat_info_group_name:
-    #         self.group_info = DatabaseGroupInfo(
-    #             group_id=self.chat_info_group_id,
-    #             group_name=self.chat_info_group_name,
-    #             group_platform=self.chat_info_group_platform,
-    #         )

-    #     chat_user_info = DatabaseUserInfo(
-    #         user_id=self.chat_info_user_id,
-    #         user_nickname=self.chat_info_user_nickname,
-    #         user_cardname=self.chat_info_user_cardname,
-    #         platform=self.chat_info_user_platform,
-    #     )
-    #     self.chat_info = DatabaseChatInfo(
-    #         stream_id=self.chat_info_stream_id,
-    #         platform=self.chat_info_platform,
-    #         create_time=self.chat_info_create_time,
-    #         last_active_time=self.chat_info_last_active_time,
-    #         user_info=chat_user_info,
-    #         group_info=self.group_info,
-    #     )
-    def __init__(self, **kwargs: Any):
-        defined = {f.name: f for f in fields(self.__class__)}
-        for name, f in defined.items():
-            if name in kwargs:
-                setattr(self, name, kwargs.pop(name))
-            elif f.default is not MISSING:
-                setattr(self, name, f.default)
-            else:
-                raise TypeError(f"缺失必需字段: {name}")
+        self.selected_expressions = selected_expressions

         self.group_info: Optional[DatabaseGroupInfo] = None
         self.user_info = DatabaseUserInfo(
-            user_id=kwargs.get("user_id"),  # type: ignore
-            user_nickname=kwargs.get("user_nickname"),  # type: ignore
-            user_cardname=kwargs.get("user_cardname"),  # type: ignore
-            platform=kwargs.get("user_platform"),  # type: ignore
+            user_id=user_id,
+            user_nickname=user_nickname,
+            user_cardname=user_cardname,
+            platform=user_platform,
         )
-        if kwargs.get("chat_info_group_id") and kwargs.get("chat_info_group_name"):
+        if chat_info_group_id and chat_info_group_name:
             self.group_info = DatabaseGroupInfo(
-                group_id=kwargs.get("chat_info_group_id"),  # type: ignore
-                group_name=kwargs.get("chat_info_group_name"),  # type: ignore
-                group_platform=kwargs.get("chat_info_group_platform"),  # type: ignore
+                group_id=chat_info_group_id,
+                group_name=chat_info_group_name,
+                group_platform=chat_info_group_platform,
             )

-        chat_user_info = DatabaseUserInfo(
-            user_id=kwargs.get("chat_info_user_id"),  # type: ignore
-            user_nickname=kwargs.get("chat_info_user_nickname"),  # type: ignore
-            user_cardname=kwargs.get("chat_info_user_cardname"),  # type: ignore
-            platform=kwargs.get("chat_info_user_platform"),  # type: ignore
-        )

         self.chat_info = DatabaseChatInfo(
-            stream_id=kwargs.get("chat_info_stream_id"),  # type: ignore
-            platform=kwargs.get("chat_info_platform"),  # type: ignore
-            create_time=kwargs.get("chat_info_create_time"),  # type: ignore
-            last_active_time=kwargs.get("chat_info_last_active_time"),  # type: ignore
-            user_info=chat_user_info,
+            stream_id=chat_info_stream_id,
+            platform=chat_info_platform,
+            create_time=chat_info_create_time,
+            last_active_time=chat_info_last_active_time,
+            user_info=DatabaseUserInfo(
+                user_id=chat_info_user_id,
+                user_nickname=chat_info_user_nickname,
+                user_cardname=chat_info_user_cardname,
+                platform=chat_info_user_platform,
+            ),
             group_info=self.group_info,
         )

         if kwargs:
             for key, value in kwargs.items():
                 setattr(self, key, value)

+    # def __post_init__(self):
+    #     assert isinstance(self.message_id, str), "message_id must be a string"
+    #     assert isinstance(self.time, float), "time must be a float"
+    #     assert isinstance(self.chat_id, str), "chat_id must be a string"
+    #     assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None"
+    #     assert isinstance(self.interest_value, float) or self.interest_value is None, (
+    #         "interest_value must be a float or None"
+    #     )
@@ -1,8 +1,10 @@
 from dataclasses import dataclass, field
 from typing import Optional

+from . import BaseDataModel
+

 @dataclass
-class TargetPersonInfo:
+class TargetPersonInfo(BaseDataModel):
     platform: str = field(default_factory=str)
     user_id: str = field(default_factory=str)
     user_nickname: str = field(default_factory=str)
@@ -0,0 +1,17 @@
+from typing import Optional
+from dataclasses import dataclass, field
+
+from . import BaseDataModel
+
+
+@dataclass
+class MessageAndActionModel(BaseDataModel):
+    time: float = field(default_factory=float)
+    user_id: str = field(default_factory=str)
+    user_platform: str = field(default_factory=str)
+    user_nickname: str = field(default_factory=str)
+    user_cardname: Optional[str] = None
+    processed_plain_text: Optional[str] = None
+    display_message: Optional[str] = None
+    chat_info_platform: str = field(default_factory=str)
+    is_action_record: bool = field(default=False)
+    action_name: Optional[str] = None
@@ -46,6 +46,8 @@ class PersonalityConfig(ConfigBase):

     reply_style: str = ""
     """表达风格"""

+    plan_style: str = ""
+
     compress_personality: bool = True
     """是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭"""
@@ -159,14 +159,23 @@ class ClientRegistry:

         return decorator

-    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
+    def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
         获取注册的API客户端实例
         Args:
             api_provider: APIProvider实例
+            force_new: 是否强制创建新实例(用于解决事件循环问题)
         Returns:
             BaseClient: 注册的API客户端实例
         """
+        # 如果强制创建新实例,直接创建不使用缓存
+        if force_new:
+            if client_class := self.client_registry.get(api_provider.client_type):
+                return client_class(api_provider)
+            else:
+                raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
+
+        # 正常的缓存逻辑
         if api_provider.name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
                 self.client_instance_cache[api_provider.name] = client_class(api_provider)
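> 备注:上面这个 hunk 给注册表加了 `force_new` 参数,用于在缓存之外按需创建全新客户端实例。下面是该「缓存或新建」模式的一个独立极简示意(`Registry`、`register`、`get_instance` 等名字均为假设,与仓库实际 API 无关):

```python
class Registry:
    def __init__(self):
        self._classes = {}  # client_type -> 客户端类
        self._cache = {}    # provider 名称 -> 已创建的实例

    def register(self, client_type, cls):
        self._classes[client_type] = cls

    def get_instance(self, provider_name, client_type, force_new=False):
        if client_type not in self._classes:
            raise KeyError(f"'{client_type}' 类型的 Client 未注册")
        if force_new:
            # 绕过缓存:每次创建全新实例,避免实例被跨事件循环复用
            return self._classes[client_type](provider_name)
        # 正常的缓存逻辑:同一 provider 复用同一实例
        if provider_name not in self._cache:
            self._cache[provider_name] = self._classes[client_type](provider_name)
        return self._cache[provider_name]
```

这样默认路径仍然共享连接,只有事件循环敏感的调用方(如嵌入任务)才付出重建实例的开销。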
@@ -44,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall

 logger = get_logger("Gemini客户端")

+# gemini_thinking参数(默认范围)
+# 不同模型的思考预算范围配置
+THINKING_BUDGET_LIMITS = {
+    "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
+    "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
+    "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
+}
+
 gemini_safe_settings = [
     SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
     SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
@@ -328,6 +336,48 @@ class GeminiClient(BaseClient):
             api_key=api_provider.api_key,
         )  # 这里和openai不一样,gemini会自己决定自己是否需要retry

+    # 思维预算特殊值
+    THINKING_BUDGET_AUTO = -1  # 自动调整思考预算,由模型决定
+    THINKING_BUDGET_DISABLED = 0  # 禁用思考预算(如果模型允许禁用)
+
+    @staticmethod
+    def clamp_thinking_budget(tb: int, model_id: str):
+        """
+        按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
+        """
+        limits = None
+        matched_key = None
+
+        # 优先尝试精确匹配
+        if model_id in THINKING_BUDGET_LIMITS:
+            limits = THINKING_BUDGET_LIMITS[model_id]
+            matched_key = model_id
+        else:
+            # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先
+            sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
+            for key in sorted_keys:
+                # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
+                if model_id == key or model_id.startswith(key + "-"):
+                    limits = THINKING_BUDGET_LIMITS[key]
+                    matched_key = key
+                    break
+
+        # 特殊值处理
+        if tb == GeminiClient.THINKING_BUDGET_AUTO:
+            return GeminiClient.THINKING_BUDGET_AUTO
+        if tb == GeminiClient.THINKING_BUDGET_DISABLED:
+            if limits and limits.get("can_disable", False):
+                return GeminiClient.THINKING_BUDGET_DISABLED
+            return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO
+
+        # 已知模型裁剪到范围
+        if limits:
+            return max(limits["min"], min(tb, limits["max"]))
+
+        # 未知模型,返回动态模式
+        logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
+        return GeminiClient.THINKING_BUDGET_AUTO
+
     async def get_response(
         self,
         model_info: ModelInfo,
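> 备注:上面新增的 `clamp_thinking_budget` 的核心是「精确匹配优先、否则按 key 长度倒序做带 `-` 边界的前缀匹配,再把预算裁剪进区间」。下面用一个脱离类上下文的独立函数复现这一逻辑以便验证(`clamp`、`LIMITS` 为示意名,数值取自上面 diff 中的配置):

```python
LIMITS = {
    "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
    "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
    "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
}

def clamp(tb: int, model_id: str) -> int:
    limits = LIMITS.get(model_id)  # 精确匹配优先
    if limits is None:
        # 按 key 长度倒序,保证 "-lite" 这类更具体的前缀先命中
        for key in sorted(LIMITS, key=len, reverse=True):
            if model_id == key or model_id.startswith(key + "-"):
                limits = LIMITS[key]
                break
    if tb == -1:  # AUTO:交给模型动态决定
        return -1
    if tb == 0:   # DISABLED:仅部分模型允许禁用
        if limits and limits.get("can_disable", False):
            return 0
        return limits["min"] if limits else -1
    if limits:    # 已知模型:裁剪到 [min, max]
        return max(limits["min"], min(tb, limits["max"]))
    return -1     # 未知模型:回退到动态模式
```

带 `-` 边界的前缀匹配避免了 `gemini-2.5-flash` 误吞 `gemini-2.5-flash-lite-001` 这类带后缀的新版本号。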
@@ -373,6 +423,19 @@ class GeminiClient(BaseClient):
         messages = _convert_messages(message_list)
         # 将tool_options转换为Gemini API所需的格式
         tools = _convert_tool_options(tool_options) if tool_options else None

+        tb = GeminiClient.THINKING_BUDGET_AUTO
+        # 空处理
+        if extra_params and "thinking_budget" in extra_params:
+            try:
+                tb = int(extra_params["thinking_budget"])
+            except (ValueError, TypeError):
+                logger.warning(
+                    f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}"
+                )
+        # 裁剪到模型支持的范围
+        tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
+
         # 将response_format转换为Gemini API所需的格式
         generation_config_dict = {
             "max_output_tokens": max_tokens,
@@ -380,11 +443,7 @@ class GeminiClient(BaseClient):
             "response_modalities": ["TEXT"],
             "thinking_config": ThinkingConfig(
                 include_thoughts=True,
-                thinking_budget=(
-                    extra_params["thinking_budget"]
-                    if extra_params and "thinking_budget" in extra_params
-                    else int(max_tokens / 2)  # 默认思考预算为最大token数的一半,防止空回复
-                ),
+                thinking_budget=tb,
             ),
             "safety_settings": gemini_safe_settings,  # 防止空回复问题
         }
@@ -388,6 +388,7 @@ class OpenaiClient(BaseClient):
             base_url=api_provider.base_url,
             api_key=api_provider.api_key,
             max_retries=0,
+            timeout=api_provider.timeout,
         )

     async def get_response(
@@ -520,6 +521,11 @@ class OpenaiClient(BaseClient):
                 extra_body=extra_params,
             )
         except APIConnectionError as e:
+            # 添加详细的错误信息以便调试
+            logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
+            logger.error(f"错误类型: {type(e)}")
+            if hasattr(e, '__cause__') and e.__cause__:
+                logger.error(f"底层错误: {str(e.__cause__)}")
             raise NetworkConnectionError() from e
         except APIStatusError as e:
             # 重封装APIError为RespNotOkException
@@ -195,7 +195,7 @@ class LLMRequest:

         if not content:
             if raise_when_empty:
-                logger.warning("生成的响应为空")
+                logger.warning(f"生成的响应为空, 请求类型: {self.request_type}")
                 raise RuntimeError("生成的响应为空")
             content = "生成的响应为空,请检查模型配置或输入内容是否正确"
@@ -248,7 +248,11 @@ class LLMRequest:
         )
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        client = client_registry.get_client_class_instance(api_provider)
+
+        # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
+        force_new_client = (self.request_type == "embedding")
+        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)

         logger.debug(f"选择请求模型: {model_info.name}")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
         self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # 增加使用惩罚值防止连续使用
@@ -163,13 +163,9 @@ class ChatAction:
             limit=15,
             limit_mode="last",
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
         chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
@@ -230,13 +226,9 @@ class ChatAction:
             limit=10,
             limit_mode="last",
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
        chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
@@ -166,13 +166,10 @@ class ChatMood:
             limit=10,
             limit_mode="last",
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]

         chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
@@ -248,13 +245,10 @@ class ChatMood:
             limit=5,
             limit_mode="last",
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]

         chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
@@ -17,6 +17,10 @@ from src.mais4u.mais4u_chat.screen_manager import screen_manager
 from src.chat.express.expression_selector import expression_selector
 from .s4u_mood_manager import mood_manager
 from src.mais4u.mais4u_chat.internal_manager import internal_manager
+from src.common.data_models.database_data_model import DatabaseMessages
+
+from typing import List
+
 logger = get_logger("prompt")

@@ -58,7 +62,7 @@ def init_prompt():
         """,
         "s4u_prompt",  # New template for private CHAT chat
     )

     Prompt(
         """
 你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
@@ -95,14 +99,13 @@ class PromptBuilder:
     def __init__(self):
         self.prompt_built = ""
         self.activate_messages = ""

-    async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
-
+    async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target):
         style_habits = []

         # 使用从处理器传来的选中表达方式
         # LLM模式:调用LLM选择5-10个,然后随机选5个
-        selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm(
+        selected_expressions, _ = await expression_selector.select_suitable_expressions_llm(
             chat_stream.stream_id, chat_history, max_num=12, target_message=target
         )

@@ -122,7 +125,6 @@ class PromptBuilder:
         if style_habits_str.strip():
             expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
-
         return expression_habits_block

     async def build_relation_info(self, chat_stream) -> str:
@@ -148,9 +150,7 @@ class PromptBuilder:
             person_ids.append(person_id)

         # 使用 Person 的 build_relationship 方法,设置 points_num=3 保持与原来相同的行为
-        relation_info_list = [
-            Person(person_id=person_id).build_relationship() for person_id in person_ids
-        ]
+        relation_info_list = [Person(person_id=person_id).build_relationship() for person_id in person_ids]
         if relation_info := "".join(relation_info_list):
             relation_prompt = await global_prompt_manager.format_prompt(
                 "relation_prompt", relation_info=relation_info
@@ -160,7 +160,7 @@ class PromptBuilder:
     async def build_memory_block(self, text: str) -> str:
+        # 待更新记忆系统
+        return ""

         related_memory = await hippocampus_manager.get_memory_from_text(
             text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
         )
@@ -176,38 +176,37 @@ class PromptBuilder:
         message_list_before_now = get_raw_msg_before_timestamp_with_chat(
             chat_id=chat_stream.stream_id,
             timestamp=time.time(),
-            # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
             limit=300,
         )

         talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"

-        core_dialogue_list = []
-        background_dialogue_list = []
+        core_dialogue_list: List[DatabaseMessages] = []
+        background_dialogue_list: List[DatabaseMessages] = []
         bot_id = str(global_config.bot.qq_account)
         target_user_id = str(message.chat_stream.user_info.user_id)

-        # TODO: 修复之!
         for msg in message_list_before_now:
             try:
                 msg_user_id = str(msg.user_info.user_id)
                 if msg_user_id == bot_id:
                     if msg.reply_to and talk_type == msg.reply_to:
-                        core_dialogue_list.append(msg.__dict__)
+                        core_dialogue_list.append(msg)
                     elif msg.reply_to and talk_type != msg.reply_to:
-                        background_dialogue_list.append(msg.__dict__)
+                        background_dialogue_list.append(msg)
                     # else:
                     #     background_dialogue_list.append(msg_dict)
                 elif msg_user_id == target_user_id:
-                    core_dialogue_list.append(msg.__dict__)
+                    core_dialogue_list.append(msg)
                 else:
-                    background_dialogue_list.append(msg.__dict__)
+                    background_dialogue_list.append(msg)
             except Exception as e:
                 logger.error(f"无法处理历史消息记录: {msg.__dict__}, 错误: {e}")

         background_dialogue_prompt = ""
         if background_dialogue_list:
-            context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:]
+            context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
             background_dialogue_prompt_str = build_readable_messages(
                 context_msgs,
                 timestamp_mode="normal_no_YMD",
@@ -217,10 +216,10 @@ class PromptBuilder:

         core_msg_str = ""
         if core_dialogue_list:
-            core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:]
+            core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]

             first_msg = core_dialogue_list[0]
-            start_speaking_user_id = first_msg.get("user_id")
+            start_speaking_user_id = first_msg.user_info.user_id
             if start_speaking_user_id == bot_id:
                 last_speaking_user_id = bot_id
                 msg_seg_str = "你的发言:\n"
@@ -229,13 +228,13 @@ class PromptBuilder:
                 last_speaking_user_id = start_speaking_user_id
                 msg_seg_str = "对方的发言:\n"

-            msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
+            msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"

             all_msg_seg_list = []
             for msg in core_dialogue_list[1:]:
-                speaker = msg.get("user_id")
+                speaker = msg.user_info.user_id
                 if speaker == last_speaking_user_id:
-                    msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
+                    msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
                 else:
                     msg_seg_str = f"{msg_seg_str}\n"
                     all_msg_seg_list.append(msg_seg_str)
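The hunk above alternates speaker segments: consecutive messages from the same sender are concatenated into one timestamped block, and a speaker change closes the current block. A condensed sketch of that grouping logic (the `Msg` type and function name here are illustrative, not the project's API):

```python
import time
from dataclasses import dataclass

@dataclass
class Msg:
    user_id: str
    time: float
    text: str

def group_by_speaker(msgs: list[Msg]) -> list[str]:
    # Merge consecutive messages from the same speaker into one segment,
    # prefixing each message with an HH:MM:SS timestamp as the diff does.
    segments: list[str] = []
    current = ""
    last_speaker = None
    for m in msgs:
        if m.user_id != last_speaker and current:
            segments.append(current)  # speaker changed: close the segment
            current = ""
        current += f"{time.strftime('%H:%M:%S', time.localtime(m.time))}: {m.text}\n"
        last_speaker = m.user_id
    if current:
        segments.append(current)
    return segments

msgs = [Msg("a", 0, "hi"), Msg("a", 1, "there"), Msg("b", 2, "hello")]
segs = group_by_speaker(msgs)
assert len(segs) == 2 and "hi" in segs[0] and "hello" in segs[1]
```

The real method additionally labels each segment "你的发言" or "对方的发言" depending on whether the opening speaker is the bot.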
@@ -252,46 +251,40 @@ class PromptBuilder:
             for msg in all_msg_seg_list:
                 core_msg_str += msg

-        all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
+        all_dialogue_history = get_raw_msg_before_timestamp_with_chat(
             chat_id=chat_stream.stream_id,
             timestamp=time.time(),
             limit=20,
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in all_dialogue_prompt]

         all_dialogue_prompt_str = build_readable_messages(
-            tmp_msgs,
+            all_dialogue_history,
             timestamp_mode="normal_no_YMD",
             show_pic=False,
         )

-        return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str
+        return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str

     def build_gift_info(self, message: MessageRecvS4U):
         if message.is_gift:
             return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
         else:
             if message.is_fake_gift:
                 return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"

             return ""

     def build_sc_info(self, message: MessageRecvS4U):
         super_chat_manager = get_super_chat_manager()
         return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)

     async def build_prompt_normal(
         self,
         message: MessageRecvS4U,
         message_txt: str,
     ) -> str:
         chat_stream = message.chat_stream

         person = Person(platform=message.chat_stream.user_info.platform, user_id=message.chat_stream.user_info.user_id)
         person_name = person.person_name
@@ -302,28 +295,31 @@ class PromptBuilder:
             sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
         else:
             sender_name = f"用户({message.chat_stream.user_info.user_id})"

         relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
-            self.build_relation_info(chat_stream), self.build_memory_block(message_txt), self.build_expression_habits(chat_stream, message_txt, sender_name)
+            self.build_relation_info(chat_stream),
+            self.build_memory_block(message_txt),
+            self.build_expression_habits(chat_stream, message_txt, sender_name),
         )

-        core_dialogue_prompt, background_dialogue_prompt,all_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
+        core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts(
+            chat_stream, message
+        )

         gift_info = self.build_gift_info(message)
         sc_info = self.build_sc_info(message)
         screen_info = screen_manager.get_screen_str()
         internal_state = internal_manager.get_internal_state_str()

         time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

         mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)

         template_name = "s4u_prompt"

         if not message.is_internal:
             prompt = await global_prompt_manager.format_prompt(
                 template_name,
@@ -356,7 +352,7 @@ class PromptBuilder:
                 mind=message.processed_plain_text,
                 mood_state=mood.mood_state,
             )

         # print(prompt)

         return prompt
@@ -99,13 +99,10 @@ class ChatMood:
             limit=int(global_config.chat.max_context_size / 3),
             limit_mode="last",
         )
-        # TODO: 修复!
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]

         chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
@@ -151,13 +148,10 @@ class ChatMood:
             limit=15,
             limit_mode="last",
         )
-        # TODO: 修复
-        from src.common.data_models import temporarily_transform_class_to_dict
-        tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]

         chat_talking_prompt = build_readable_messages(
-            tmp_msgs,
+            message_list_before_now,
             replace_bot_name=True,
             merge_messages=False,
             timestamp_mode="normal_no_YMD",
             read_mark=0.0,
             truncate=True,
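A pattern recurring through these hunks: call sites stop converting `DatabaseMessages` objects into dicts via the temporary `temporarily_transform_class_to_dict` shim, because the formatter now reads attributes directly. A minimal sketch of the before/after shapes; the dataclass fields and the shim body here are illustrative stand-ins, not the project's actual definitions:

```python
from dataclasses import dataclass, asdict

@dataclass
class UserInfo:
    user_id: str
    user_nickname: str

@dataclass
class DatabaseMessages:
    time: float
    processed_plain_text: str
    user_info: UserInfo

def temporarily_transform_class_to_dict(msg: DatabaseMessages) -> dict:
    # The removed shim: flatten the dataclass into the dict shape the
    # old formatter expected (msg.get("user_id"), msg.get("time"), ...).
    d = asdict(msg)
    d["user_id"] = msg.user_info.user_id
    return d

msg = DatabaseMessages(time=0.0, processed_plain_text="hi",
                       user_info=UserInfo("42", "alice"))

# Before: every call site converted first.
legacy = temporarily_transform_class_to_dict(msg)
assert legacy["user_id"] == "42"

# After: build_readable_messages reads attributes directly, so call
# sites pass the DatabaseMessages objects through unchanged.
assert msg.user_info.user_id == "42"
```

Dropping the per-call conversion removes the scattered `# TODO: 修复` imports and keeps one typed message model across the codebase.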
@@ -5,7 +5,7 @@ import time
 import random

 from json_repair import repair_json
-from typing import Union
+from typing import Union, Optional

 from src.common.logger import get_logger
 from src.common.database.database import db
@@ -253,8 +253,8 @@ class Person:

         # 初始化默认值
         self.nickname = ""
-        self.person_name = None
-        self.name_reason = None
+        self.person_name: Optional[str] = None
+        self.name_reason: Optional[str] = None
         self.know_times = 0
         self.know_since = None
         self.last_know = None
@@ -1,18 +1,21 @@
 import json
 import traceback

 from json_repair import repair_json
 from datetime import datetime
-from typing import List, Dict, Any
+from typing import List

 from src.common.logger import get_logger
 from .person_info import Person
 import random
+from src.common.data_models.database_data_model import DatabaseMessages
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config, model_config
 from src.chat.utils.chat_message_builder import build_readable_messages
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager

 logger = get_logger("relation")


 def init_prompt():
     Prompt(
         """
@@ -45,8 +48,7 @@ def init_prompt():
         """,
         "attitude_to_me_prompt",
     )

-
     Prompt(
         """
         你的名字是{bot_name},{bot_name}的别名是{alias_str}。
@@ -80,104 +82,102 @@ def init_prompt():
         "neuroticism_prompt",
     )


 class RelationshipManager:
     def __init__(self):
         self.relationship_llm = LLMRequest(
             model_set=model_config.model_task_config.utils, request_type="relationship.person"
         )

     async def get_attitude_to_me(self, readable_messages, timestamp, person: Person):
         alias_str = ", ".join(global_config.bot.alias_names)
         current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
         # 解析当前态度值
         current_attitude_score = person.attitude_to_me
         total_confidence = person.attitude_to_me_confidence

         prompt = await global_prompt_manager.format_prompt(
             "attitude_to_me_prompt",
-            bot_name = global_config.bot.nickname,
-            alias_str = alias_str,
-            person_name = person.person_name,
-            nickname = person.nickname,
-            readable_messages = readable_messages,
-            current_time = current_time,
+            bot_name=global_config.bot.nickname,
+            alias_str=alias_str,
+            person_name=person.person_name,
+            nickname=person.nickname,
+            readable_messages=readable_messages,
+            current_time=current_time,
         )

         attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)

         attitude = repair_json(attitude)
         attitude_data = json.loads(attitude)

         if not attitude_data or (isinstance(attitude_data, list) and len(attitude_data) == 0):
             return ""

         # 确保 attitude_data 是字典格式
         if not isinstance(attitude_data, dict):
             logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(attitude_data)}, 内容: {attitude_data}")
             return ""

         attitude_score = attitude_data["attitude"]
-        confidence = pow(attitude_data["confidence"],2)
+        confidence = pow(attitude_data["confidence"], 2)

         new_confidence = total_confidence + confidence
-        new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
+        new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence) / new_confidence

         person.attitude_to_me = new_attitude_score
         person.attitude_to_me_confidence = new_confidence

         return person

     async def get_neuroticism(self, readable_messages, timestamp, person: Person):
         alias_str = ", ".join(global_config.bot.alias_names)
         current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
         # 解析当前态度值
         current_neuroticism_score = person.neuroticism
         total_confidence = person.neuroticism_confidence

         prompt = await global_prompt_manager.format_prompt(
             "neuroticism_prompt",
-            bot_name = global_config.bot.nickname,
-            alias_str = alias_str,
-            person_name = person.person_name,
-            nickname = person.nickname,
-            readable_messages = readable_messages,
-            current_time = current_time,
+            bot_name=global_config.bot.nickname,
+            alias_str=alias_str,
+            person_name=person.person_name,
+            nickname=person.nickname,
+            readable_messages=readable_messages,
+            current_time=current_time,
         )

         neuroticism, _ = await self.relationship_llm.generate_response_async(prompt=prompt)

         # logger.info(f"prompt: {prompt}")
         # logger.info(f"neuroticism: {neuroticism}")

         neuroticism = repair_json(neuroticism)
         neuroticism_data = json.loads(neuroticism)

         if not neuroticism_data or (isinstance(neuroticism_data, list) and len(neuroticism_data) == 0):
             return ""

         # 确保 neuroticism_data 是字典格式
         if not isinstance(neuroticism_data, dict):
             logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(neuroticism_data)}, 内容: {neuroticism_data}")
             return ""

         neuroticism_score = neuroticism_data["neuroticism"]
-        confidence = pow(neuroticism_data["confidence"],2)
+        confidence = pow(neuroticism_data["confidence"], 2)

         new_confidence = total_confidence + confidence
-        new_neuroticism_score = (current_neuroticism_score * total_confidence + neuroticism_score * confidence)/new_confidence
+        new_neuroticism_score = (
+            current_neuroticism_score * total_confidence + neuroticism_score * confidence
+        ) / new_confidence

         person.neuroticism = new_neuroticism_score
         person.neuroticism_confidence = new_confidence

         return person

-    async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
+    async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
         """更新用户印象

         Args:
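Both `get_attitude_to_me` and `get_neuroticism` above apply the same update rule: a confidence-weighted running average, where each new LLM estimate is blended with the stored score using its squared confidence as weight. A minimal sketch of that rule in isolation (the function name is illustrative):

```python
def blend_score(current_score: float, total_confidence: float,
                new_score: float, raw_confidence: float) -> tuple[float, float]:
    # Squaring the confidence makes high-confidence observations dominate
    # low-confidence ones, mirroring pow(confidence, 2) in the diff.
    confidence = raw_confidence ** 2
    new_confidence = total_confidence + confidence
    blended = (current_score * total_confidence + new_score * confidence) / new_confidence
    return blended, new_confidence

# Starting from a neutral score with zero accumulated confidence,
# the first observation is taken at face value.
score, conf = blend_score(0.0, 0.0, 5.0, 1.0)
assert score == 5.0 and conf == 1.0

# A second, equally confident observation averages in.
score, conf = blend_score(score, conf, 3.0, 1.0)
assert score == 4.0 and conf == 2.0
```

Because confidence accumulates, scores become increasingly resistant to single noisy estimates, which matches the changelog's goal of "更克制" relationship building.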
@@ -202,12 +202,11 @@ class RelationshipManager:

         # 遍历消息,构建映射
         for msg in user_messages:
-            if msg.get("user_id") == "system":
+            if msg.user_info.user_id == "system":
                 continue
             try:
-                user_id = msg.get("user_id")
-                platform = msg.get("chat_info_platform")
+                user_id = msg.user_info.user_id
+                platform = msg.chat_info.platform
                 assert isinstance(user_id, str) and isinstance(platform, str)
                 msg_person = Person(user_id=user_id, platform=platform)
@@ -242,19 +241,16 @@ class RelationshipManager:
         # 确保 original_name 和 mapped_name 都不为 None
         if original_name is not None and mapped_name is not None:
             readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")

         # await self.get_points(
         #     readable_messages=readable_messages, name_mapping=name_mapping, timestamp=timestamp, person=person)
         await self.get_attitude_to_me(readable_messages=readable_messages, timestamp=timestamp, person=person)
         await self.get_neuroticism(readable_messages=readable_messages, timestamp=timestamp, person=person)

         person.know_times = know_times + 1
         person.last_know = timestamp

         person.sync_to_database()

     def calculate_time_weight(self, point_time: str, current_time: str) -> float:
         """计算基于时间的权重系数"""
@@ -280,6 +276,7 @@ class RelationshipManager:
             logger.error(f"计算时间权重失败: {e}")
             return 0.5  # 发生错误时返回中等权重

+
 init_prompt()

 relationship_manager = None
@@ -290,4 +287,3 @@ def get_relationship_manager():
     if relationship_manager is None:
         relationship_manager = RelationshipManager()
     return relationship_manager
-
@@ -127,7 +127,7 @@ async def generate_reply(
     success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
         extra_info=extra_info,
         available_actions=available_actions,
-        choosen_actions=choosen_actions,
+        chosen_actions=choosen_actions,
         enable_tool=enable_tool,
         reply_message=reply_message,
         reply_reason=reply_reason,
@@ -294,7 +294,9 @@ def get_messages_before_time_in_chat(
     return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)


-def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[DatabaseMessages]:
+def get_messages_before_time_for_users(
+    timestamp: float, person_ids: List[str], limit: int = 0
+) -> List[DatabaseMessages]:
     """
     获取指定用户在指定时间戳之前的消息
@@ -410,9 +412,8 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa


 def build_readable_messages_to_str(
-    messages: List[Dict[str, Any]],
+    messages: List[DatabaseMessages],
     replace_bot_name: bool = True,
     merge_messages: bool = False,
     timestamp_mode: str = "relative",
     read_mark: float = 0.0,
     truncate: bool = False,
@@ -434,14 +435,13 @@ def build_readable_messages_to_str(
         格式化后的可读字符串
     """
     return build_readable_messages(
-        messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
+        messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
     )


 async def build_readable_messages_with_details(
-    messages: List[Dict[str, Any]],
+    messages: List[DatabaseMessages],
     replace_bot_name: bool = True,
-    merge_messages: bool = False,
     timestamp_mode: str = "relative",
     truncate: bool = False,
 ) -> Tuple[str, List[Tuple[float, str, str]]]:
@@ -458,7 +458,7 @@ async def build_readable_messages_with_details(
     Returns:
         格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
     """
-    return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
+    return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)


 async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
@@ -2,13 +2,14 @@ import random
 from typing import Tuple

 # 导入新插件系统
-from src.plugin_system import BaseAction, ActionActivationType, ChatMode
+from src.plugin_system import BaseAction, ActionActivationType

 # 导入依赖的系统组件
 from src.common.logger import get_logger

 # 导入API模块 - 标准Python包方式
 from src.plugin_system.apis import emoji_api, llm_api, message_api

 # NoReplyAction已集成到heartFC_chat.py中,不再需要导入
 from src.config.config import global_config
@@ -84,11 +85,8 @@ class EmojiAction(BaseAction):
         messages_text = ""
         if recent_messages:
             # 使用message_api构建可读的消息字符串
-            # TODO: 修复
-            from src.common.data_models import temporarily_transform_class_to_dict
-            tmp_msgs = [temporarily_transform_class_to_dict(msg) for msg in recent_messages]
             messages_text = message_api.build_readable_messages(
-                messages=tmp_msgs,
+                messages=recent_messages,
                 timestamp_mode="normal_no_YMD",
                 truncate=False,
                 show_actions=False,
@@ -1,5 +1,5 @@
 [inner]
-version = "6.4.6"
+version = "6.5.0"

 #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
 #如果你想要修改配置文件,请递增version的值
@@ -29,6 +29,8 @@ identity = "年龄为19岁,是女孩子,身高为160cm,有黑色的短发"
 # 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容
 reply_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要浮夸,不要夸张修辞。"

+plan_style = "当你刚刚发送了消息,没有人回复时,不要选择action,如果有别的动作(非回复)满足条件,可以选择,当你一次发送了太多消息,为了避免打扰聊天节奏,不要选择动作"
+
 compress_personality = false # 是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭
 compress_identity = true # 是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭
@@ -5,7 +5,7 @@ version = "1.3.0"

 [[api_providers]] # API服务提供商(可以配置多个)
 name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
-base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
+base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL
 api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥)
 client_type = "openai" # 请求客户端(可选,默认值为"openai",使用Gemini等Google系模型时请配置为"gemini")
 max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
@@ -1,73 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-测试del_memory函数的脚本
-"""
-
-import sys
-import os
-
-# 添加src目录到Python路径
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
-
-from person_info.person_info import Person
-
-def test_del_memory():
-    """测试del_memory函数"""
-    print("开始测试del_memory函数...")
-
-    # 创建一个测试用的Person实例(不连接数据库)
-    person = Person.__new__(Person)
-    person.person_id = "test_person"
-    person.memory_points = [
-        "性格:这个人很友善:5.0",
-        "性格:这个人很友善:4.0",
-        "爱好:喜欢打游戏:3.0",
-        "爱好:喜欢打游戏:2.0",
-        "工作:是一名程序员:1.0",
-        "性格:这个人很友善:6.0"
-    ]
-
-    print(f"原始记忆点数量: {len(person.memory_points)}")
-    print("原始记忆点:")
-    for i, memory in enumerate(person.memory_points):
-        print(f"  {i+1}. {memory}")
-
-    # 测试删除"性格"分类中"这个人很友善"的记忆
-    print("\n测试1: 删除'性格'分类中'这个人很友善'的记忆")
-    deleted_count = person.del_memory("性格", "这个人很友善")
-    print(f"删除了 {deleted_count} 个记忆点")
-    print("删除后的记忆点:")
-    for i, memory in enumerate(person.memory_points):
-        print(f"  {i+1}. {memory}")
-
-    # 测试删除"爱好"分类中"喜欢打游戏"的记忆
-    print("\n测试2: 删除'爱好'分类中'喜欢打游戏'的记忆")
-    deleted_count = person.del_memory("爱好", "喜欢打游戏")
-    print(f"删除了 {deleted_count} 个记忆点")
-    print("删除后的记忆点:")
-    for i, memory in enumerate(person.memory_points):
-        print(f"  {i+1}. {memory}")
-
-    # 测试相似度匹配
-    print("\n测试3: 测试相似度匹配")
-    person.memory_points = [
-        "性格:这个人非常友善:5.0",
-        "性格:这个人很友善:4.0",
-        "性格:这个人友善:3.0"
-    ]
-    print("原始记忆点:")
-    for i, memory in enumerate(person.memory_points):
-        print(f"  {i+1}. {memory}")
-
-    # 删除"这个人很友善"(应该匹配"这个人很友善"和"这个人友善")
-    deleted_count = person.del_memory("性格", "这个人很友善", similarity_threshold=0.8)
-    print(f"删除了 {deleted_count} 个记忆点")
-    print("删除后的记忆点:")
-    for i, memory in enumerate(person.memory_points):
-        print(f"  {i+1}. {memory}")
-
-    print("\n测试完成!")
-
-if __name__ == "__main__":
-    test_del_memory()
@@ -1,124 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-测试修复后的memory_points处理
-"""
-
-import sys
-import os
-
-# 添加src目录到Python路径
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
-
-from person_info.person_info import Person
-
-def test_memory_points_with_none():
-    """测试包含None值的memory_points处理"""
-    print("测试包含None值的memory_points处理...")
-
-    # 创建一个测试Person实例
-    person = Person(person_id="test_user_123")
-
-    # 模拟包含None值的memory_points
-    person.memory_points = [
-        "喜好:喜欢咖啡:1.0",
-        None,  # 模拟None值
-        "性格:开朗:1.0",
-        None,  # 模拟另一个None值
-        "兴趣:编程:1.0"
-    ]
-
-    print(f"原始memory_points: {person.memory_points}")
-
-    # 测试get_all_category方法
-    try:
-        categories = person.get_all_category()
-        print(f"获取到的分类: {categories}")
-        print("✓ get_all_category方法正常工作")
-    except Exception as e:
-        print(f"✗ get_all_category方法出错: {e}")
-        return False
-
-    # 测试get_memory_list_by_category方法
-    try:
-        memories = person.get_memory_list_by_category("喜好")
-        print(f"获取到的喜好记忆: {memories}")
-        print("✓ get_memory_list_by_category方法正常工作")
-    except Exception as e:
-        print(f"✗ get_memory_list_by_category方法出错: {e}")
-        return False
-
-    # 测试del_memory方法
-    try:
-        deleted_count = person.del_memory("喜好", "喜欢咖啡")
-        print(f"删除的记忆点数量: {deleted_count}")
-        print(f"删除后的memory_points: {person.memory_points}")
-        print("✓ del_memory方法正常工作")
-    except Exception as e:
-        print(f"✗ del_memory方法出错: {e}")
-        return False
-
-    return True
-
-def test_memory_points_empty():
-    """测试空的memory_points处理"""
-    print("\n测试空的memory_points处理...")
-
-    person = Person(person_id="test_user_456")
-    person.memory_points = []
-
-    try:
-        categories = person.get_all_category()
-        print(f"空列表的分类: {categories}")
-        print("✓ 空列表处理正常")
-    except Exception as e:
-        print(f"✗ 空列表处理出错: {e}")
-        return False
-
-    try:
-        memories = person.get_memory_list_by_category("测试分类")
-        print(f"空列表的记忆: {memories}")
-        print("✓ 空列表分类查询正常")
-    except Exception as e:
-        print(f"✗ 空列表分类查询出错: {e}")
-        return False
-
-    return True
-
-def test_memory_points_all_none():
-    """测试全部为None的memory_points处理"""
-    print("\n测试全部为None的memory_points处理...")
-
-    person = Person(person_id="test_user_789")
-    person.memory_points = [None, None, None]
-
-    try:
-        categories = person.get_all_category()
-        print(f"全None列表的分类: {categories}")
-        print("✓ 全None列表处理正常")
-    except Exception as e:
-        print(f"✗ 全None列表处理出错: {e}")
-        return False
-
-    try:
-        memories = person.get_memory_list_by_category("测试分类")
-        print(f"全None列表的记忆: {memories}")
-        print("✓ 全None列表分类查询正常")
-    except Exception as e:
-        print(f"✗ 全None列表分类查询出错: {e}")
-        return False
-
-    return True
-
-if __name__ == "__main__":
-    print("开始测试修复后的memory_points处理...")
-
-    success = True
-    success &= test_memory_points_with_none()
-    success &= test_memory_points_empty()
-    success &= test_memory_points_all_none()
-
-    if success:
-        print("\n🎉 所有测试通过!memory_points的None值处理已修复。")
-    else:
-        print("\n❌ 部分测试失败,需要进一步检查。")