mirror of https://github.com/Mai-with-u/MaiBot.git
Enhance private chat handling with improved memory and state management
parent
570a9418ab
commit
45172e9dca
|
|
@ -50,17 +50,46 @@ def init_prompt():
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你正在和{sender_name}进行私聊。以下是聊天内容:
|
你正在和{sender_name}进行私聊。
|
||||||
|
|
||||||
|
### 聊天背景
|
||||||
|
- 对方身份:{sender_name}
|
||||||
|
- 关系程度:{relationship_level}
|
||||||
|
- 最近互动:{recent_interactions}
|
||||||
|
- 对方情绪:{sender_emotion}
|
||||||
|
|
||||||
|
### 当前聊天记录
|
||||||
{chat_info}
|
{chat_info}
|
||||||
|
|
||||||
你的名字是{bot_name},{prompt_personality}。对于"{target_message}",你想表达:{in_mind_reply},原因是:{reason}。
|
### 你的状态
|
||||||
|
- 你的名字是{bot_name}
|
||||||
|
- {prompt_personality}
|
||||||
|
- 当前情绪:{current_mood}
|
||||||
|
|
||||||
请以友好、直接的方式回复。你的回复应该:
|
### 对话焦点
|
||||||
1. 简洁明了,直接表达核心意思
|
- 触发消息:"{target_message}"
|
||||||
2. 语气友好但不过分亲密
|
- 你想表达:{in_mind_reply}
|
||||||
3. 保持适当的礼貌和距离感
|
- 原因:{reason}
|
||||||
4. 避免过于复杂的表达方式
|
|
||||||
5. 不要使用群聊中的梗或过于随意的表达
|
### 回复指南
|
||||||
|
1. 根据关系程度调整语气:
|
||||||
|
- 陌生:保持礼貌客气
|
||||||
|
- 熟悉:适度亲和
|
||||||
|
- 亲密:自然随和
|
||||||
|
2. 考虑对方情绪:
|
||||||
|
- 积极:可以更活泼
|
||||||
|
- 消极:需要更温和
|
||||||
|
- 中性:保持平和
|
||||||
|
3. 回复要求:
|
||||||
|
- 简洁明了,直接表达核心意思
|
||||||
|
- 语气友好但不过分亲密
|
||||||
|
- 保持适当的礼貌和距离感
|
||||||
|
- 避免过于复杂的表达方式
|
||||||
|
- 不要使用群聊中的梗或过于随意的表达
|
||||||
|
4. 特殊注意:
|
||||||
|
- 如果有上下文关联,要自然承接
|
||||||
|
- 如果是新话题,要平滑过渡
|
||||||
|
- 如果对方情绪明显,要适当回应
|
||||||
|
|
||||||
{config_expression_style}
|
{config_expression_style}
|
||||||
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包等),只输出一条回复就好。
|
不要输出多余内容(包括前后缀,冒号和引号,括号,表情包等),只输出一条回复就好。
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from src.chat.focus_chat.info.chat_info import ChattingInfo
|
from src.chat.focus_chat.info.chat_info import ChattingInfo
|
||||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||||
|
from src.chat.focus_chat.info.message_recv import MessageRecv
|
||||||
|
|
||||||
logger = get_logger("planner")
|
logger = get_logger("planner")
|
||||||
|
|
||||||
|
|
@ -99,48 +100,109 @@ class ActionPlanner:
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 包含action_type, action_data和reasoning的字典
|
Dict[str, Any]: 包含action_type, action_data和reasoning的字典
|
||||||
"""
|
"""
|
||||||
# 提取聊天信息
|
try:
|
||||||
chatting_info = None
|
# 提取聊天信息
|
||||||
working_memory_info = None
|
chatting_info = None
|
||||||
for info in all_plan_info:
|
working_memory_info = None
|
||||||
if isinstance(info, ChattingInfo):
|
for info in all_plan_info:
|
||||||
chatting_info = info
|
if isinstance(info, ChattingInfo):
|
||||||
elif isinstance(info, WorkingMemoryInfo):
|
chatting_info = info
|
||||||
working_memory_info = info
|
elif isinstance(info, WorkingMemoryInfo):
|
||||||
|
working_memory_info = info
|
||||||
|
|
||||||
if not chatting_info:
|
# 基本错误检查
|
||||||
|
if not chatting_info:
|
||||||
|
return {
|
||||||
|
"action_type": "no_reply",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "没有检测到新的聊天信息",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查是否有新消息需要回复
|
||||||
|
if not chatting_info.new_messages:
|
||||||
|
return {
|
||||||
|
"action_type": "no_reply",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "没有新消息需要回复",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取最新消息
|
||||||
|
latest_message = chatting_info.new_messages[-1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 检查消息有效性
|
||||||
|
if not latest_message.processed_plain_text.strip():
|
||||||
|
return {
|
||||||
|
"action_type": "no_reply",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "消息内容为空",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查是否是系统消息或其他特殊消息
|
||||||
|
if latest_message.message_info.is_system_message:
|
||||||
|
return {
|
||||||
|
"action_type": "no_reply",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "系统消息,无需回复",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查工作记忆中是否有相关上下文
|
||||||
|
context = ""
|
||||||
|
if working_memory_info and working_memory_info.related_memory:
|
||||||
|
context = working_memory_info.related_memory
|
||||||
|
|
||||||
|
# 检查是否需要停止专注聊天
|
||||||
|
if self._should_stop_focus_chat(latest_message):
|
||||||
|
return {
|
||||||
|
"action_type": "stop_focus_chat",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "检测到停止专注聊天的信号",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建回复动作
|
||||||
|
return {
|
||||||
|
"action_type": "reply",
|
||||||
|
"action_data": {
|
||||||
|
"message": latest_message,
|
||||||
|
"context": context,
|
||||||
|
"style": "direct", # 私聊使用直接回复风格
|
||||||
|
"target": latest_message.processed_plain_text, # 添加目标消息
|
||||||
|
"text": self._generate_reply_text(latest_message, context), # 生成回复文本
|
||||||
|
},
|
||||||
|
"reasoning": "收到新的私聊消息,直接回复",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as msg_e:
|
||||||
|
logger.error(f"处理消息时出错: {msg_e}")
|
||||||
|
return {
|
||||||
|
"action_type": "no_reply",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": f"处理消息时出错: {str(msg_e)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"简单规划器出错: {e}")
|
||||||
return {
|
return {
|
||||||
"action_type": "no_reply",
|
"action_type": "no_reply",
|
||||||
"action_data": {},
|
"action_data": {},
|
||||||
"reasoning": "没有检测到新的聊天信息",
|
"reasoning": f"规划器出错: {str(e)}",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查是否有新消息需要回复
|
def _should_stop_focus_chat(self, message: MessageRecv) -> bool:
|
||||||
if not chatting_info.new_messages:
|
"""
|
||||||
return {
|
检查是否应该停止专注聊天
|
||||||
"action_type": "no_reply",
|
"""
|
||||||
"action_data": {},
|
# 检查消息内容是否包含停止关键词
|
||||||
"reasoning": "没有新消息需要回复",
|
stop_keywords = ["再见", "拜拜", "88", "bye", "goodbye", "晚安"]
|
||||||
}
|
return any(keyword in message.processed_plain_text.lower() for keyword in stop_keywords)
|
||||||
|
|
||||||
# 获取最新消息
|
def _generate_reply_text(self, message: MessageRecv, context: str) -> str:
|
||||||
latest_message = chatting_info.new_messages[-1]
|
"""
|
||||||
|
根据消息和上下文生成回复文本
|
||||||
# 检查工作记忆中是否有相关上下文
|
"""
|
||||||
context = ""
|
# 这里可以添加更复杂的回复生成逻辑
|
||||||
if working_memory_info and working_memory_info.related_memory:
|
# 目前简单返回消息内容
|
||||||
context = working_memory_info.related_memory
|
return message.processed_plain_text
|
||||||
|
|
||||||
# 构建回复动作
|
|
||||||
return {
|
|
||||||
"action_type": "reply",
|
|
||||||
"action_data": {
|
|
||||||
"message": latest_message,
|
|
||||||
"context": context,
|
|
||||||
"style": "direct", # 私聊使用直接回复风格
|
|
||||||
},
|
|
||||||
"reasoning": "收到新的私聊消息,直接回复",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]]) -> Dict[str, Any]:
|
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ import asyncio
|
||||||
import random
|
import random
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||||
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
|
import time
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -15,41 +17,135 @@ class WorkingMemory:
|
||||||
从属于特定的流,用chat_id来标识
|
从属于特定的流,用chat_id来标识
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
def __init__(self, chat_id: str):
|
||||||
"""
|
self.chat_id = chat_id
|
||||||
初始化工作记忆管理器
|
self.is_group_chat = None # 将在initialize时设置
|
||||||
|
self._memory_items = []
|
||||||
|
self._max_items = 50 # 默认最大条目数
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
Args:
|
async def initialize(self):
|
||||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
"""初始化工作记忆,确定聊天类型"""
|
||||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
if self._initialized:
|
||||||
"""
|
return
|
||||||
self.memory_manager = MemoryManager(chat_id)
|
|
||||||
|
|
||||||
# 记忆容量上限
|
# 获取聊天类型
|
||||||
self.max_memories_per_chat = max_memories_per_chat
|
chat_stream = await asyncio.to_thread(chat_manager.get_stream, self.chat_id)
|
||||||
|
self.is_group_chat = chat_stream.group_info is not None
|
||||||
|
|
||||||
# 自动衰减间隔
|
# 根据聊天类型调整参数
|
||||||
self.auto_decay_interval = auto_decay_interval
|
if not self.is_group_chat:
|
||||||
|
self._max_items = 20 # 私聊保持较少的记忆条目
|
||||||
|
self._memory_cleanup_threshold = 0.8 # 私聊更激进的清理阈值
|
||||||
|
else:
|
||||||
|
self._max_items = 50 # 群聊保持更多记忆条目
|
||||||
|
self._memory_cleanup_threshold = 0.9 # 群聊较保守的清理阈值
|
||||||
|
|
||||||
# 衰减任务
|
self._initialized = True
|
||||||
self.decay_task = None
|
|
||||||
|
|
||||||
# 启动自动衰减任务
|
async def add_memory_item(self, item: MemoryItem):
|
||||||
self._start_auto_decay()
|
"""添加新的记忆项"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
def _start_auto_decay(self):
|
# 检查是否需要清理
|
||||||
"""启动自动衰减任务"""
|
if len(self._memory_items) >= self._max_items * self._memory_cleanup_threshold:
|
||||||
if self.decay_task is None:
|
await self._cleanup_old_memories()
|
||||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
|
||||||
|
|
||||||
async def _auto_decay_loop(self):
|
# 私聊时进行额外的重复检查
|
||||||
"""自动衰减循环"""
|
if not self.is_group_chat:
|
||||||
while True:
|
# 检查是否有相似内容,有则更新而不是添加
|
||||||
await asyncio.sleep(self.auto_decay_interval)
|
for existing_item in self._memory_items:
|
||||||
try:
|
if self._is_similar_content(existing_item.content, item.content):
|
||||||
await self.decay_all_memories()
|
existing_item.update_with(item)
|
||||||
except Exception as e:
|
return
|
||||||
print(f"自动衰减记忆时出错: {str(e)}")
|
|
||||||
|
self._memory_items.append(item)
|
||||||
|
if len(self._memory_items) > self._max_items:
|
||||||
|
self._memory_items.pop(0)
|
||||||
|
|
||||||
|
def _is_similar_content(self, content1: str, content2: str) -> bool:
|
||||||
|
"""检查两个内容是否相似"""
|
||||||
|
# 这里可以实现更复杂的相似度检查
|
||||||
|
# 目前简单实现
|
||||||
|
return content1.strip() == content2.strip()
|
||||||
|
|
||||||
|
async def _cleanup_old_memories(self):
|
||||||
|
"""清理旧的记忆项"""
|
||||||
|
if not self._memory_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.is_group_chat:
|
||||||
|
# 私聊:保留最近的对话和重要的记忆
|
||||||
|
self._memory_items = [
|
||||||
|
item for item in self._memory_items
|
||||||
|
if (time.time() - item.timestamp < 3600) or item.importance > 0.7
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# 群聊:使用现有的清理逻辑
|
||||||
|
# 按重要性和时间排序
|
||||||
|
self._memory_items.sort(key=lambda x: (x.importance, -x.timestamp), reverse=True)
|
||||||
|
# 保留前80%
|
||||||
|
keep_count = int(self._max_items * 0.8)
|
||||||
|
self._memory_items = self._memory_items[:keep_count]
|
||||||
|
|
||||||
|
async def get_related_memory(self, query: str, limit: int = 5) -> List[MemoryItem]:
|
||||||
|
"""获取与查询相关的记忆项"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
if not self._memory_items:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 根据聊天类型调整搜索策略
|
||||||
|
if not self.is_group_chat:
|
||||||
|
# 私聊:优先考虑最近的对话
|
||||||
|
recent_limit = min(limit * 2, len(self._memory_items))
|
||||||
|
recent_items = self._memory_items[-recent_limit:]
|
||||||
|
|
||||||
|
# 计算相关性分数
|
||||||
|
scored_items = []
|
||||||
|
for item in recent_items:
|
||||||
|
score = self._calculate_relevance(query, item.content)
|
||||||
|
if score > 0.3: # 提高私聊的相关性阈值
|
||||||
|
scored_items.append((score, item))
|
||||||
|
|
||||||
|
# 按分数排序并返回前limit个
|
||||||
|
scored_items.sort(reverse=True)
|
||||||
|
return [item for score, item in scored_items[:limit]]
|
||||||
|
else:
|
||||||
|
# 群聊:使用现有的搜索逻辑
|
||||||
|
scored_items = []
|
||||||
|
for item in self._memory_items:
|
||||||
|
score = self._calculate_relevance(query, item.content)
|
||||||
|
if score > 0.2:
|
||||||
|
scored_items.append((score, item))
|
||||||
|
|
||||||
|
scored_items.sort(reverse=True)
|
||||||
|
return [item for score, item in scored_items[:limit]]
|
||||||
|
|
||||||
|
def _calculate_relevance(self, query: str, content: str) -> float:
|
||||||
|
"""计算查询与内容的相关性分数"""
|
||||||
|
# 这里可以实现更复杂的相关性计算
|
||||||
|
# 目前使用简单的包含关系检查
|
||||||
|
query_words = set(query.split())
|
||||||
|
content_words = set(content.split())
|
||||||
|
common_words = query_words & content_words
|
||||||
|
if not query_words:
|
||||||
|
return 0.0
|
||||||
|
return len(common_words) / len(query_words)
|
||||||
|
|
||||||
|
async def clear_memory(self):
|
||||||
|
"""清空工作记忆"""
|
||||||
|
self._memory_items = []
|
||||||
|
|
||||||
|
def get_all_memories(self) -> List[MemoryItem]:
|
||||||
|
"""获取所有记忆项"""
|
||||||
|
return self._memory_items.copy()
|
||||||
|
|
||||||
|
def get_memory_count(self) -> int:
|
||||||
|
"""获取当前记忆项数量"""
|
||||||
|
return len(self._memory_items)
|
||||||
|
|
||||||
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -181,12 +277,3 @@ class WorkingMemory:
|
||||||
await self.decay_task
|
await self.decay_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_all_memories(self) -> List[MemoryItem]:
|
|
||||||
"""
|
|
||||||
获取所有记忆项目
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
|
||||||
"""
|
|
||||||
return self.memory_manager.get_all_items()
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||||
|
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
|
||||||
|
from src.chat.message_receive.message import MessageRecv, MessageInfo, UserInfo
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def subheartflow_manager():
|
||||||
|
manager = SubHeartflowManager()
|
||||||
|
yield manager
|
||||||
|
# 清理所有子心流
|
||||||
|
for subflow in manager.get_all_subheartflows():
|
||||||
|
await subflow.shutdown()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_chat_stream():
|
||||||
|
stream = MagicMock(spec=ChatStream)
|
||||||
|
stream.group_info = None # 私聊没有群信息
|
||||||
|
stream.platform = "test"
|
||||||
|
return stream
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_message():
|
||||||
|
message = MagicMock(spec=MessageRecv)
|
||||||
|
message.message_info = MagicMock(spec=MessageInfo)
|
||||||
|
message.message_info.user_info = MagicMock(spec=UserInfo)
|
||||||
|
message.message_info.message_id = "test_msg_1"
|
||||||
|
message.message_info.time = 1234567890
|
||||||
|
message.message_info.platform = "test"
|
||||||
|
message.processed_plain_text = "测试消息"
|
||||||
|
return message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_private_chat_absent_to_focused(subheartflow_manager, mock_chat_stream, mock_message):
|
||||||
|
"""测试私聊从ABSENT到FOCUSED的转换"""
|
||||||
|
# 创建私聊子心流
|
||||||
|
subflow_id = "private_test_1"
|
||||||
|
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||||
|
assert subflow is not None
|
||||||
|
|
||||||
|
# 确认初始状态为ABSENT
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||||
|
|
||||||
|
# 模拟新消息
|
||||||
|
with patch("src.chat.heart_flow.observation.chatting_observation.ChattingObservation") as mock_obs:
|
||||||
|
mock_obs_instance = mock_obs.return_value
|
||||||
|
mock_obs_instance.has_new_messages_since.return_value = True
|
||||||
|
|
||||||
|
# 触发状态检查
|
||||||
|
await subheartflow_manager.sbhf_absent_private_into_focus()
|
||||||
|
|
||||||
|
# 验证状态已转换为FOCUSED
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_private_chat_focused_to_absent(subheartflow_manager, mock_chat_stream, mock_message):
|
||||||
|
"""测试私聊从FOCUSED到ABSENT的转换"""
|
||||||
|
# 创建私聊子心流
|
||||||
|
subflow_id = "private_test_2"
|
||||||
|
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||||
|
assert subflow is not None
|
||||||
|
|
||||||
|
# 设置初始状态为FOCUSED
|
||||||
|
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||||
|
|
||||||
|
# 触发状态转换
|
||||||
|
await subheartflow_manager.sbhf_focus_into_normal(subflow_id)
|
||||||
|
|
||||||
|
# 验证状态已转换为ABSENT(私聊特殊处理)
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_private_chat_stop_keywords(subheartflow_manager, mock_chat_stream, mock_message):
|
||||||
|
"""测试私聊停止关键词触发状态转换"""
|
||||||
|
# 创建私聊子心流
|
||||||
|
subflow_id = "private_test_3"
|
||||||
|
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||||
|
assert subflow is not None
|
||||||
|
|
||||||
|
# 设置初始状态为FOCUSED
|
||||||
|
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||||
|
|
||||||
|
# 模拟包含停止关键词的消息
|
||||||
|
mock_message.processed_plain_text = "再见啦"
|
||||||
|
|
||||||
|
# 模拟处理消息
|
||||||
|
with patch("src.chat.focus_chat.planners.planner.ActionPlanner") as mock_planner:
|
||||||
|
mock_planner_instance = mock_planner.return_value
|
||||||
|
mock_planner_instance.plan_simple.return_value = {
|
||||||
|
"action_type": "stop_focus_chat",
|
||||||
|
"action_data": {},
|
||||||
|
"reasoning": "检测到停止专注聊天的信号"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 触发状态转换
|
||||||
|
await subheartflow_manager.sbhf_focus_into_normal(subflow_id)
|
||||||
|
|
||||||
|
# 验证状态已转换为ABSENT
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.ABSENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_private_chat_error_recovery(subheartflow_manager, mock_chat_stream, mock_message):
|
||||||
|
"""测试私聊错误恢复"""
|
||||||
|
# 创建私聊子心流
|
||||||
|
subflow_id = "private_test_4"
|
||||||
|
subflow = await subheartflow_manager.get_or_create_subheartflow(subflow_id)
|
||||||
|
assert subflow is not None
|
||||||
|
|
||||||
|
# 设置初始状态为FOCUSED
|
||||||
|
await subflow.change_chat_state(ChatState.FOCUSED)
|
||||||
|
|
||||||
|
# 模拟处理器错误
|
||||||
|
with patch("src.chat.focus_chat.heartFC_chat.HeartFChatting._process_processors") as mock_proc:
|
||||||
|
mock_proc.side_effect = Exception("处理器错误")
|
||||||
|
|
||||||
|
# 触发处理
|
||||||
|
# 注:这里需要实际调用处理逻辑
|
||||||
|
|
||||||
|
# 验证状态保持为FOCUSED(错误不应导致状态改变)
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||||
|
|
||||||
|
# 验证错误后的下一次处理仍然正常
|
||||||
|
mock_proc.side_effect = None # 清除错误
|
||||||
|
mock_proc.return_value = ([], {}) # 返回空结果
|
||||||
|
|
||||||
|
# 再次触发处理
|
||||||
|
# 注:这里需要实际调用处理逻辑
|
||||||
|
|
||||||
|
# 验证状态仍然正常
|
||||||
|
assert subflow.chat_state.chat_status == ChatState.FOCUSED
|
||||||
Loading…
Reference in New Issue