From 37f8c923c15fd3eec9c1a9cb5f5392f35591c5a1 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 22 Feb 2026 22:26:28 +0800 Subject: [PATCH] =?UTF-8?q?ChatSession(=E5=8E=9FChatStream)=E4=B8=8EChatMa?= =?UTF-8?q?nager=EF=BC=9B=E7=94=B1=E4=BA=8E=E5=8A=9F=E8=83=BD=E7=AE=80?= =?UTF-8?q?=E5=8D=95=EF=BC=8C=E6=B5=8B=E8=AF=95=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelogs/mai_next_todo.md | 3 + src/chat/message_receive/chat_manager.py | 228 ++++++++++++++++++ .../data_models/chat_session_data_model.py | 26 +- .../data_models/person_info_data_model.py | 5 +- src/common/database/database_model.py | 2 +- src/common/utils/utils_message.py | 11 + 6 files changed, 268 insertions(+), 7 deletions(-) create mode 100644 src/chat/message_receive/chat_manager.py diff --git a/changelogs/mai_next_todo.md b/changelogs/mai_next_todo.md index 2c68b177..58d7672f 100644 --- a/changelogs/mai_next_todo.md +++ b/changelogs/mai_next_todo.md @@ -142,6 +142,9 @@ version 0.3.0 - 2026-01-11 - [ ] 代码示例 ## 消息链构建(仿Astrbot模式) 将消息仿照Astrbot的消息链模式进行构建,消息链中的每个元素都是一个消息组件,消息链本身也是一个数据模型,包含了消息组件列表以及一些元信息(如是否为转发消息等)。 +### Accept Format检查 +- [ ] 在最后发送消息的时候进行Accept Format检查,确保消息链中的每个消息组件都符合平台的Accept Format要求 +- [ ] 如果消息链中的某个消息组件不符合Accept Format要求,应该抛弃该消息组件,并记录日志说明被抛弃的消息组件的类型和内容 ## 表情包系统 - [ ] 移除大量冗余代码,全部返回单一对象MaiEmoji diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py new file mode 100644 index 00000000..f0cc4e80 --- /dev/null +++ b/src/chat/message_receive/chat_manager.py @@ -0,0 +1,228 @@ +from datetime import datetime +from rich.traceback import install +from sqlmodel import select +from typing import Optional, TYPE_CHECKING, List, Dict + +import asyncio + +from src.common.logger import get_logger +from src.common.data_models.chat_session_data_model import MaiChatSession +from src.common.database.database_model import ChatSession +from src.common.database.database import get_db_session +from src.common.utils.utils_message import MessageUtils + +if TYPE_CHECKING: + from .message import SessionMessage + +install(extra_lines=3) + +logger = get_logger("chat_manager") + + +class SessionContext: + """会话上下文""" + + def __init__(self, message: "SessionMessage"): + self.message = message + self.template_name: Optional[str] = None + + def update_template(self, template_name: str): + """更新当前使用的回复模板""" + self.template_name = template_name + + +class BotChatSession(MaiChatSession): + def __init__( + self, + session_id: str, + platform: str, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + created_timestamp: Optional[datetime] = None, + last_active_timestamp: Optional[datetime] = None, + ): + self.context: Optional[SessionContext] = None + self.accept_format: List[str] = [] + + super().__init__( + session_id=session_id, + platform=platform, + user_id=user_id, + group_id=group_id, + created_timestamp=created_timestamp, + last_active_timestamp=last_active_timestamp, + ) + + def check_types(self, types: List[str]) -> bool: + """检查消息是否符合可接受类型列表""" + return all(t in self.accept_format for t in types) + + def update_active_time(self): + """更新最后活跃时间""" + self.last_active_timestamp = datetime.now() + + def set_context(self, message: "SessionMessage"): + """设置会话上下文""" + self.context = SessionContext(message=message) + + +class ChatManager: + """聊天管理器,负责管理所有聊天会话""" + + def __init__(self) -> None: + self.sessions: Dict[str, BotChatSession] = {} # session_id -> BotChatSession + self.last_messages: Dict[str, "SessionMessage"] = {} # session_id -> SessionMessage + + async def initialize(self): + """初始化聊天管理器""" + try: + await self.load_all_sessions_from_db() + logger.info(f"已加载 {len(self.sessions)} 个会话记录到内存中") + except Exception as e: + logger.error(f"初始化聊天管理器出现错误: {e}") + + async def get_or_create_session( + self, platform: str, user_id: str, group_id: Optional[str] = None + ) -> BotChatSession: + """获取会话,如果不存在则创建一个新会话;一个封装方法。 + + Args: + platform: 平台 + user_id: 用户ID + group_id: 群ID(如果是群聊) + Returns: + return (BotChatSession) 会话对象 + Raises: + Exception: 获取或创建会话时发生错误 + """ + session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + if session := self.get_session_by_session_id(session_id): + session.update_active_time() + return session + + # 内存没有就找db + try: + with get_db_session() as db_session: + statement = select(ChatSession).filter_by(session_id=session_id) + if result := db_session.exec(statement).first(): + session = BotChatSession.from_db_instance(result) + self.sessions[session.session_id] = session + return session + except Exception as e: + logger.error(f"从数据库获取会话时发生错误: {e}") + raise e + + # 都没有就创建新的 + new_session = BotChatSession( + session_id=session_id, + platform=platform, + user_id=user_id, + group_id=group_id, + ) + self.sessions[new_session.session_id] = new_session + if new_session.session_id in self.last_messages: + new_session.set_context(self.last_messages[new_session.session_id]) + self._save_session(new_session) + return new_session + + def register_message(self, message: "SessionMessage"): + platform = message.platform + if not platform: + raise ValueError("消息缺少平台信息") + user_id = message.message_info.user_info.user_id + group_id = message.message_info.group_info.group_id if message.message_info.group_info else None + session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + message.session_id = session_id # 确保消息的session_id正确设置 + self.last_messages[session_id] = message + + async def load_all_sessions_from_db(self): + """从数据库加载全部会话记录到内存中""" + self.sessions.clear() + try: + await asyncio.to_thread(self._load_sessions_from_db) + except Exception as e: + logger.error(f"从数据库加载会话记录时发生错误: {e}") + self.sessions.clear() + raise e + + def save_all_sessions(self): + """将内存中的全部会话记录保存到数据库""" + try: + for session in self.sessions.values(): + self._save_session(session) + logger.info(f"已保存 {len(self.sessions)} 个会话记录到数据库中") + except Exception as e: + logger.error(f"保存会话记录到数据库时发生错误: {e}") + raise e + + def get_session_name(self, session_id: str) -> Optional[str]: + """根据会话ID获取会话名称 + + Args: + session_id: 会话ID + Returns: + Optional[str]: 会话名称,如果无法获取则返回None + """ + session = self.sessions.get(session_id) + if not session: + return None + if session.is_group_session: + if session.context and session.context.message and session.context.message.message_info.group_info: + return session.context.message.message_info.group_info.group_name + elif session.context and session.context.message and session.context.message.message_info.user_info: + nickname = session.context.message.message_info.user_info.user_nickname + return f"{nickname}的私聊" + return None + + def get_session_by_info( + self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None + ) -> Optional[BotChatSession]: + """根据平台、用户ID和群ID获取对应的会话 + + Args: + platform: 平台 + user_id: 用户ID + group_id: 群ID(如果是群聊) + Returns: + return (Optional[BotChatSession]): 会话对象,如果不存在则返回None + """ + session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + return self.get_session_by_session_id(session_id) + + def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]: + """根据会话ID获取对应的会话 + + Args: + session_id: 会话ID + Returns: + Optional[BotChatSession]: 会话对象,如果不存在则返回None + """ + session = self.sessions.get(session_id) + if session and session_id in self.last_messages: + session.set_context(self.last_messages[session_id]) + return session + + def _load_sessions_from_db(self): + """从数据库加载单个会话记录""" + with get_db_session() as session: + statements = select(ChatSession) + for model_instance in session.exec(statements).all(): + bot_chat_session = BotChatSession.from_db_instance(model_instance) + self.sessions[bot_chat_session.session_id] = bot_chat_session + if bot_chat_session.session_id in self.last_messages: + bot_chat_session.set_context(self.last_messages[bot_chat_session.session_id]) + + def _save_session(self, session: BotChatSession): + """将会话记录保存到数据库""" + with get_db_session() as db_session: + db_instance = session.to_db_instance() + statement = select(ChatSession).filter_by(session_id=db_instance.session_id) + if result := db_session.exec(statement).first(): + result.created_timestamp = db_instance.created_timestamp + result.last_active_timestamp = db_instance.last_active_timestamp + db_session.add(result) + else: + db_session.add(db_instance) + + +chat_manager = ChatManager() diff --git a/src/common/data_models/chat_session_data_model.py b/src/common/data_models/chat_session_data_model.py index 33fb6cde..1648c34e 100644 --- a/src/common/data_models/chat_session_data_model.py +++ b/src/common/data_models/chat_session_data_model.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional from src.common.database.database_model import ChatSession @@ -6,11 +7,22 @@ from . import BaseDatabaseDataModel class MaiChatSession(BaseDatabaseDataModel[ChatSession]): - def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None): - self.session_id = session_id - self.platform = platform - self.user_id = user_id - self.group_id = group_id + def __init__( + self, + session_id: str, + platform: str, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + created_timestamp: Optional[datetime] = None, + last_active_timestamp: Optional[datetime] = None, + ): + self.session_id: str = session_id + self.platform: str = platform + self.user_id: Optional[str] = user_id + self.group_id: Optional[str] = group_id + self.created_timestamp: datetime = created_timestamp or datetime.now() + """会话创建时间,默认为当前时间""" + self.last_active_timestamp: Optional[datetime] = last_active_timestamp # 验证字段 assert self.platform, "Platform must be provided" @@ -26,6 +38,8 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]): platform=db_record.platform, user_id=db_record.user_id, group_id=db_record.group_id, + created_timestamp=db_record.created_timestamp, + last_active_timestamp=db_record.last_active_timestamp, ) def to_db_instance(self) -> ChatSession: @@ -34,4 +48,6 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]): platform=self.platform, user_id=self.user_id, group_id=self.group_id, + created_timestamp=self.created_timestamp, + last_active_timestamp=self.last_active_timestamp, ) diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py index ac15e9dd..4cbb62d8 100644 --- a/src/common/data_models/person_info_data_model.py +++ b/src/common/data_models/person_info_data_model.py @@ -78,6 +78,9 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): ) def to_db_instance(self) -> "PersonInfo": + group_cardname = ( + json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None + ) return PersonInfo( is_known=self.is_known, person_id=self.person_id, @@ -86,7 +89,7 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): platform=self.platform, user_id=self.user_id, user_nickname=self.user_nickname, - group_cardname=json.dumps([gn.__dict__ for gn in self.group_cardname_list]) if self.group_cardname_list else None, + group_cardname=group_cardname, memory_points=json.dumps(self.memory_points) if self.memory_points else None, know_counts=self.know_counts, first_known_time=self.first_known_time, diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 48aafd79..bc03758e 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -300,7 +300,7 @@ class ChatSession(SQLModel, table=True): created_timestamp: datetime = Field( default_factory=datetime.now, sa_column=Column(DateTime, index=True) ) # 创建时间 - last_active_timestamp: datetime = Field( + last_active_timestamp: Optional[datetime] = Field( default_factory=datetime.now, sa_column=Column(DateTime, index=True) ) # 最后活跃时间 diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 2c33a506..96481baa 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -124,3 +124,14 @@ class MessageUtils: ((True, pattern) for pattern in global_config.message_receive.ban_msgs_regex if re.search(pattern, text)), (False, None), ) + + @staticmethod + def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str: + """计算会话ID""" + if not user_id and not group_id: + raise ValueError("UserID 或 GroupID 必须提供其一") + if group_id: + components = [platform, group_id] + else: + components = [platform, user_id, "private"] + return hashlib.md5("_".join(components).encode()).hexdigest()