mirror of https://github.com/Mai-with-u/MaiBot.git
Enhance private chat handling with improved memory and state management
parent
570a9418ab
commit
45172e9dca
|
|
@ -50,17 +50,46 @@ def init_prompt():
|
|||
|
||||
Prompt(
|
||||
"""
|
||||
你正在和{sender_name}进行私聊。以下是聊天内容:
|
||||
你正在和{sender_name}进行私聊。
|
||||
|
||||
### 聊天背景
|
||||
- 对方身份:{sender_name}
|
||||
- 关系程度:{relationship_level}
|
||||
- 最近互动:{recent_interactions}
|
||||
- 对方情绪:{sender_emotion}
|
||||
|
||||
### 当前聊天记录
|
||||
{chat_info}
|
||||
|
||||
你的名字是{bot_name},{prompt_personality}。对于"{target_message}",你想表达:{in_mind_reply},原因是:{reason}。
|
||||
### 你的状态
|
||||
- 你的名字是{bot_name}
|
||||
- {prompt_personality}
|
||||
- 当前情绪:{current_mood}
|
||||
|
||||
请以友好、直接的方式回复。你的回复应该:
|
||||
1. 简洁明了,直接表达核心意思
|
||||
2. 语气友好但不过分亲密
|
||||
3. 保持适当的礼貌和距离感
|
||||
4. 避免过于复杂的表达方式
|
||||
5. 不要使用群聊中的梗或过于随意的表达
|
||||
### 对话焦点
|
||||
- 触发消息:"{target_message}"
|
||||
- 你想表达:{in_mind_reply}
|
||||
- 原因:{reason}
|
||||
|
||||
### 回复指南
|
||||
1. 根据关系程度调整语气:
|
||||
- 陌生:保持礼貌客气
|
||||
- 熟悉:适度亲和
|
||||
- 亲密:自然随和
|
||||
2. 考虑对方情绪:
|
||||
- 积极:可以更活泼
|
||||
- 消极:需要更温和
|
||||
- 中性:保持平和
|
||||
3. 回复要求:
|
||||
- 简洁明了,直接表达核心意思
|
||||
- 语气友好但不过分亲密
|
||||
- 保持适当的礼貌和距离感
|
||||
- 避免过于复杂的表达方式
|
||||
- 不要使用群聊中的梗或过于随意的表达
|
||||
4. 特殊注意:
|
||||
- 如果有上下文关联,要自然承接
|
||||
- 如果是新话题,要平滑过渡
|
||||
- 如果对方情绪明显,要适当回应
|
||||
|
||||
{config_expression_style}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包等),只输出一条回复就好。
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from src.chat.focus_chat.planners.action_manager import ActionManager
|
|||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.info.chat_info import ChattingInfo
|
||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||
from src.chat.focus_chat.info.message_recv import MessageRecv
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
|
|
@ -99,48 +100,109 @@ class ActionPlanner:
|
|||
Returns:
|
||||
Dict[str, Any]: 包含action_type, action_data和reasoning的字典
|
||||
"""
|
||||
# 提取聊天信息
|
||||
chatting_info = None
|
||||
working_memory_info = None
|
||||
for info in all_plan_info:
|
||||
if isinstance(info, ChattingInfo):
|
||||
chatting_info = info
|
||||
elif isinstance(info, WorkingMemoryInfo):
|
||||
working_memory_info = info
|
||||
try:
|
||||
# 提取聊天信息
|
||||
chatting_info = None
|
||||
working_memory_info = None
|
||||
for info in all_plan_info:
|
||||
if isinstance(info, ChattingInfo):
|
||||
chatting_info = info
|
||||
elif isinstance(info, WorkingMemoryInfo):
|
||||
working_memory_info = info
|
||||
|
||||
if not chatting_info:
|
||||
# 基本错误检查
|
||||
if not chatting_info:
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "没有检测到新的聊天信息",
|
||||
}
|
||||
|
||||
# 检查是否有新消息需要回复
|
||||
if not chatting_info.new_messages:
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "没有新消息需要回复",
|
||||
}
|
||||
|
||||
# 获取最新消息
|
||||
latest_message = chatting_info.new_messages[-1]
|
||||
|
||||
try:
|
||||
# 检查消息有效性
|
||||
if not latest_message.processed_plain_text.strip():
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "消息内容为空",
|
||||
}
|
||||
|
||||
# 检查是否是系统消息或其他特殊消息
|
||||
if latest_message.message_info.is_system_message:
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "系统消息,无需回复",
|
||||
}
|
||||
|
||||
# 检查工作记忆中是否有相关上下文
|
||||
context = ""
|
||||
if working_memory_info and working_memory_info.related_memory:
|
||||
context = working_memory_info.related_memory
|
||||
|
||||
# 检查是否需要停止专注聊天
|
||||
if self._should_stop_focus_chat(latest_message):
|
||||
return {
|
||||
"action_type": "stop_focus_chat",
|
||||
"action_data": {},
|
||||
"reasoning": "检测到停止专注聊天的信号",
|
||||
}
|
||||
|
||||
# 构建回复动作
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"action_data": {
|
||||
"message": latest_message,
|
||||
"context": context,
|
||||
"style": "direct", # 私聊使用直接回复风格
|
||||
"target": latest_message.processed_plain_text, # 添加目标消息
|
||||
"text": self._generate_reply_text(latest_message, context), # 生成回复文本
|
||||
},
|
||||
"reasoning": "收到新的私聊消息,直接回复",
|
||||
}
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.error(f"处理消息时出错: {msg_e}")
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": f"处理消息时出错: {str(msg_e)}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单规划器出错: {e}")
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "没有检测到新的聊天信息",
|
||||
"reasoning": f"规划器出错: {str(e)}",
|
||||
}
|
||||
|
||||
# 检查是否有新消息需要回复
|
||||
if not chatting_info.new_messages:
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"action_data": {},
|
||||
"reasoning": "没有新消息需要回复",
|
||||
}
|
||||
def _should_stop_focus_chat(self, message: MessageRecv) -> bool:
|
||||
"""
|
||||
检查是否应该停止专注聊天
|
||||
"""
|
||||
# 检查消息内容是否包含停止关键词
|
||||
stop_keywords = ["再见", "拜拜", "88", "bye", "goodbye", "晚安"]
|
||||
return any(keyword in message.processed_plain_text.lower() for keyword in stop_keywords)
|
||||
|
||||
# 获取最新消息
|
||||
latest_message = chatting_info.new_messages[-1]
|
||||
|
||||
# 检查工作记忆中是否有相关上下文
|
||||
context = ""
|
||||
if working_memory_info and working_memory_info.related_memory:
|
||||
context = working_memory_info.related_memory
|
||||
|
||||
# 构建回复动作
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"action_data": {
|
||||
"message": latest_message,
|
||||
"context": context,
|
||||
"style": "direct", # 私聊使用直接回复风格
|
||||
},
|
||||
"reasoning": "收到新的私聊消息,直接回复",
|
||||
}
|
||||
def _generate_reply_text(self, message: MessageRecv, context: str) -> str:
|
||||
"""
|
||||
根据消息和上下文生成回复文本
|
||||
"""
|
||||
# 这里可以添加更复杂的回复生成逻辑
|
||||
# 目前简单返回消息内容
|
||||
return message.processed_plain_text
|
||||
|
||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import asyncio
|
|||
import random
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
import time
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -15,41 +17,135 @@ class WorkingMemory:
|
|||
从属于特定的流,用chat_id来标识
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
"""
|
||||
初始化工作记忆管理器
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.is_group_chat = None # 将在initialize时设置
|
||||
self._memory_items = []
|
||||
self._max_items = 50 # 默认最大条目数
|
||||
self._initialized = False
|
||||
|
||||
Args:
|
||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
||||
"""
|
||||
self.memory_manager = MemoryManager(chat_id)
|
||||
async def initialize(self):
|
||||
"""初始化工作记忆,确定聊天类型"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 记忆容量上限
|
||||
self.max_memories_per_chat = max_memories_per_chat
|
||||
# 获取聊天类型
|
||||
chat_stream = await asyncio.to_thread(chat_manager.get_stream, self.chat_id)
|
||||
self.is_group_chat = chat_stream.group_info is not None
|
||||
|
||||
# 自动衰减间隔
|
||||
self.auto_decay_interval = auto_decay_interval
|
||||
# 根据聊天类型调整参数
|
||||
if not self.is_group_chat:
|
||||
self._max_items = 20 # 私聊保持较少的记忆条目
|
||||
self._memory_cleanup_threshold = 0.8 # 私聊更激进的清理阈值
|
||||
else:
|
||||
self._max_items = 50 # 群聊保持更多记忆条目
|
||||
self._memory_cleanup_threshold = 0.9 # 群聊较保守的清理阈值
|
||||
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
self._initialized = True
|
||||
|
||||
# 启动自动衰减任务
|
||||
self._start_auto_decay()
|
||||
async def add_memory_item(self, item: MemoryItem):
|
||||
"""添加新的记忆项"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
if self.decay_task is None:
|
||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
||||
# 检查是否需要清理
|
||||
if len(self._memory_items) >= self._max_items * self._memory_cleanup_threshold:
|
||||
await self._cleanup_old_memories()
|
||||
|
||||
async def _auto_decay_loop(self):
|
||||
"""自动衰减循环"""
|
||||
while True:
|
||||
await asyncio.sleep(self.auto_decay_interval)
|
||||
try:
|
||||
await self.decay_all_memories()
|
||||
except Exception as e:
|
||||
print(f"自动衰减记忆时出错: {str(e)}")
|
||||
# 私聊时进行额外的重复检查
|
||||
if not self.is_group_chat:
|
||||
# 检查是否有相似内容,有则更新而不是添加
|
||||
for existing_item in self._memory_items:
|
||||
if self._is_similar_content(existing_item.content, item.content):
|
||||
existing_item.update_with(item)
|
||||
return
|
||||
|
||||
self._memory_items.append(item)
|
||||
if len(self._memory_items) > self._max_items:
|
||||
self._memory_items.pop(0)
|
||||
|
||||
def _is_similar_content(self, content1: str, content2: str) -> bool:
|
||||
"""检查两个内容是否相似"""
|
||||
# 这里可以实现更复杂的相似度检查
|
||||
# 目前简单实现
|
||||
return content1.strip() == content2.strip()
|
||||
|
||||
async def _cleanup_old_memories(self):
|
||||
"""清理旧的记忆项"""
|
||||
if not self._memory_items:
|
||||
return
|
||||
|
||||
if not self.is_group_chat:
|
||||
# 私聊:保留最近的对话和重要的记忆
|
||||
self._memory_items = [
|
||||
item for item in self._memory_items
|
||||
if (time.time() - item.timestamp < 3600) or item.importance > 0.7
|
||||
]
|
||||
else:
|
||||
# 群聊:使用现有的清理逻辑
|
||||
# 按重要性和时间排序
|
||||
self._memory_items.sort(key=lambda x: (x.importance, -x.timestamp), reverse=True)
|
||||
# 保留前80%
|
||||
keep_count = int(self._max_items * 0.8)
|
||||
self._memory_items = self._memory_items[:keep_count]
|
||||
|
||||
async def get_related_memory(self, query: str, limit: int = 5) -> List[MemoryItem]:
|
||||
"""获取与查询相关的记忆项"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if not self._memory_items:
|
||||
return []
|
||||
|
||||
# 根据聊天类型调整搜索策略
|
||||
if not self.is_group_chat:
|
||||
# 私聊:优先考虑最近的对话
|
||||
recent_limit = min(limit * 2, len(self._memory_items))
|
||||
recent_items = self._memory_items[-recent_limit:]
|
||||
|
||||
# 计算相关性分数
|
||||
scored_items = []
|
||||
for item in recent_items:
|
||||
score = self._calculate_relevance(query, item.content)
|
||||
if score > 0.3: # 提高私聊的相关性阈值
|
||||
scored_items.append((score, item))
|
||||
|
||||
# 按分数排序并返回前limit个
|
||||
scored_items.sort(reverse=True)
|
||||
return [item for score, item in scored_items[:limit]]
|
||||
else:
|
||||
# 群聊:使用现有的搜索逻辑
|
||||
scored_items = []
|
||||
for item in self._memory_items:
|
||||
score = self._calculate_relevance(query, item.content)
|
||||
if score > 0.2:
|
||||
scored_items.append((score, item))
|
||||
|
||||
scored_items.sort(reverse=True)
|
||||
return [item for score, item in scored_items[:limit]]
|
||||
|
||||
def _calculate_relevance(self, query: str, content: str) -> float:
|
||||
"""计算查询与内容的相关性分数"""
|
||||
# 这里可以实现更复杂的相关性计算
|
||||
# 目前使用简单的包含关系检查
|
||||
query_words = set(query.split())
|
||||
content_words = set(content.split())
|
||||
common_words = query_words & content_words
|
||||
if not query_words:
|
||||
return 0.0
|
||||
return len(common_words) / len(query_words)
|
||||
|
||||
async def clear_memory(self):
|
||||
"""清空工作记忆"""
|
||||
self._memory_items = []
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""获取所有记忆项"""
|
||||
return self._memory_items.copy()
|
||||
|
||||
def get_memory_count(self) -> int:
|
||||
"""获取当前记忆项数量"""
|
||||
return len(self._memory_items)
|
||||
|
||||
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
||||
"""
|
||||
|
|
@ -181,12 +277,3 @@ class WorkingMemory:
|
|||
await self.decay_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""
|
||||
获取所有记忆项目
|
||||
|
||||
Returns:
|
||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
||||
"""
|
||||
return self.memory_manager.get_all_items()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageInfo, UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
@pytest.fixture
|
||||
async def subheartflow_manager():
|
||||
manager = SubHeartflowManager()
|
||||
yield manager
|
||||
# 清理所有子心流
|
||||
for subflow in manager.get_all_subheartflows():
|
||||
await subflow.shutdown()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_stream():
|
||||
stream = MagicMock(spec=ChatStream)
|
||||
stream.group_info = None # 私聊没有群信息
|
||||
stream.platform = "test"
|
||||
return stream
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message():
|
||||
message = MagicMock(spec=MessageRecv)
|
||||
message.message_info = MagicMock(spec=MessageInfo)
|
||||
message.message_info.user_info = MagicMock(spec=UserInfo)
|
||||
message.message_info.message_id = "test_msg_1"
|
||||
message.message_info.time = 1234567890
|
||||
message.message_info.platform = "test"
|
||||
message.processed_plain_text = "测试消息"
|
||||
return message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_chat_absent_to_focused(subheartflow_manager, mock_chat_stream, mock_message):
|
||||
"""测试私聊从ABSENT到FOCUSED的转换"""
|
||||
# 创建私聊子心流
|
||||
subflow_id = "private_test_1"
|
||||
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||
assert subflow is not None
|
||||
|
||||
# 确认初始状态为ABSENT
|
||||
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||
|
||||
# 模拟新消息
|
||||
with patch("src.chat.heart_flow.observation.chatting_observation.ChattingObservation") as mock_obs:
|
||||
mock_obs_instance = mock_obs.return_value
|
||||
mock_obs_instance.has_new_messages_since.return_value = True
|
||||
|
||||
# 触发状态检查
|
||||
await subheartflow_manager.sbhf_absent_private_into_focus()
|
||||
|
||||
# 验证状态已转换为FOCUSED
|
||||
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_chat_focused_to_absent(subheartflow_manager, mock_chat_stream, mock_message):
|
||||
"""测试私聊从FOCUSED到ABSENT的转换"""
|
||||
# 创建私聊子心流
|
||||
subflow_id = "private_test_2"
|
||||
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||
assert subflow is not None
|
||||
|
||||
# 设置初始状态为FOCUSED
|
||||
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||
|
||||
# 触发状态转换
|
||||
await subheartflow_manager.sbhf_focus_into_normal(subflow_id)
|
||||
|
||||
# 验证状态已转换为ABSENT(私聊特殊处理)
|
||||
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_chat_stop_keywords(subheartflow_manager, mock_chat_stream, mock_message):
|
||||
"""测试私聊停止关键词触发状态转换"""
|
||||
# 创建私聊子心流
|
||||
subflow_id = "private_test_3"
|
||||
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||
assert subflow is not None
|
||||
|
||||
# 设置初始状态为FOCUSED
|
||||
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||
|
||||
# 模拟包含停止关键词的消息
|
||||
mock_message.processed_plain_text = "再见啦"
|
||||
|
||||
# 模拟处理消息
|
||||
with patch("src.chat.focus_chat.planners.planner.ActionPlanner") as mock_planner:
|
||||
mock_planner_instance = mock_planner.return_value
|
||||
mock_planner_instance.plan_simple.return_value = {
|
||||
"action_type": "stop_focus_chat",
|
||||
"action_data": {},
|
||||
"reasoning": "检测到停止专注聊天的信号"
|
||||
}
|
||||
|
||||
# 触发状态转换
|
||||
await subheartflow_manager.sbhf_focus_into_normal(subflow_id)
|
||||
|
||||
# 验证状态已转换为ABSENT
|
||||
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_chat_error_recovery(subheartflow_manager, mock_chat_stream, mock_message):
|
||||
"""测试私聊错误恢复"""
|
||||
# 创建私聊子心流
|
||||
subflow_id = "private_test_4"
|
||||
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||
assert subflow is not None
|
||||
|
||||
# 设置初始状态为FOCUSED
|
||||
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||
|
||||
# 模拟处理器错误
|
||||
with patch("src.chat.focus_chat.heartFC_chat.HeartFChatting._process_processors") as mock_proc:
|
||||
mock_proc.side_effect = Exception("处理器错误")
|
||||
|
||||
# 触发处理
|
||||
# 注:这里需要实际调用处理逻辑
|
||||
|
||||
# 验证状态保持为FOCUSED(错误不应导致状态改变)
|
||||
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||
|
||||
# 验证错误后的下一次处理仍然正常
|
||||
mock_proc.side_effect = None # 清除错误
|
||||
mock_proc.return_value = ([], {}) # 返回空结果
|
||||
|
||||
# 再次触发处理
|
||||
# 注:这里需要实际调用处理逻辑
|
||||
|
||||
# 验证状态仍然正常
|
||||
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||
Loading…
Reference in New Issue