mirror of https://github.com/Mai-with-u/MaiBot.git
ChatSession(原ChatStream)与ChatManager;由于功能简单,测试略
parent
04a5bf3c6d
commit
37f8c923c1
|
|
@ -142,6 +142,9 @@ version 0.3.0 - 2026-01-11
|
|||
- [ ] 代码示例
|
||||
## 消息链构建(仿Astrbot模式)
|
||||
将消息仿照Astrbot的消息链模式进行构建,消息链中的每个元素都是一个消息组件,消息链本身也是一个数据模型,包含了消息组件列表以及一些元信息(如是否为转发消息等)。
|
||||
### Accept Format检查
|
||||
- [ ] 在最后发送消息的时候进行Accept Format检查,确保消息链中的每个消息组件都符合平台的Accept Format要求
|
||||
- [ ] 如果消息链中的某个消息组件不符合Accept Format要求,应该抛弃该消息组件,并记录日志说明被抛弃的消息组件的类型和内容
|
||||
|
||||
## 表情包系统
|
||||
- [ ] 移除大量冗余代码,全部返回单一对象MaiEmoji
|
||||
|
|
|
|||
|
|
@ -0,0 +1,228 @@
|
|||
from datetime import datetime
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Optional, TYPE_CHECKING, List, Dict
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import SessionMessage
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_manager")
|
||||
|
||||
|
||||
class SessionContext:
|
||||
"""会话上下文"""
|
||||
|
||||
def __init__(self, message: "SessionMessage"):
|
||||
self.message = message
|
||||
self.template_name: Optional[str] = None
|
||||
|
||||
def update_template(self, template_name: str):
|
||||
"""更新当前使用的回复模板"""
|
||||
self.template_name = template_name
|
||||
|
||||
|
||||
class BotChatSession(MaiChatSession):
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
platform: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
created_timestamp: Optional[datetime] = None,
|
||||
last_active_timestamp: Optional[datetime] = None,
|
||||
):
|
||||
self.context: Optional[SessionContext] = None
|
||||
self.accept_format: List[str] = []
|
||||
|
||||
super().__init__(
|
||||
session_id=session_id,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
created_timestamp=created_timestamp,
|
||||
last_active_timestamp=last_active_timestamp,
|
||||
)
|
||||
|
||||
def check_types(self, types: List[str]) -> bool:
|
||||
"""检查消息是否符合可接受类型列表"""
|
||||
return all(t in self.accept_format for t in types)
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_timestamp = datetime.now()
|
||||
|
||||
def set_context(self, message: "SessionMessage"):
|
||||
"""设置会话上下文"""
|
||||
self.context = SessionContext(message=message)
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,负责管理所有聊天会话"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sessions: Dict[str, BotChatSession] = {} # session_id -> BotChatSession
|
||||
self.last_messages: Dict[str, "SessionMessage"] = {} # session_id -> SessionMessage
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化聊天管理器"""
|
||||
try:
|
||||
await self.load_all_sessions_from_db()
|
||||
logger.info(f"已加载 {len(self.sessions)} 个会话记录到内存中")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||||
|
||||
async def get_or_create_session(
|
||||
self, platform: str, user_id: str, group_id: Optional[str] = None
|
||||
) -> BotChatSession:
|
||||
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
Returns:
|
||||
return (BotChatSession) 会话对象
|
||||
Raises:
|
||||
Exception: 获取或创建会话时发生错误
|
||||
"""
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
if session := self.get_session_by_session_id(session_id):
|
||||
session.update_active_time()
|
||||
return session
|
||||
|
||||
# 内存没有就找db
|
||||
try:
|
||||
with get_db_session() as db_session:
|
||||
statement = select(ChatSession).filter_by(session_id=session_id)
|
||||
if result := db_session.exec(statement).first():
|
||||
session = BotChatSession.from_db_instance(result)
|
||||
self.sessions[session.session_id] = session
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取会话时发生错误: {e}")
|
||||
raise e
|
||||
|
||||
# 都没有就创建新的
|
||||
new_session = BotChatSession(
|
||||
session_id=session_id,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
self.sessions[new_session.session_id] = new_session
|
||||
if new_session.session_id in self.last_messages:
|
||||
new_session.set_context(self.last_messages[new_session.session_id])
|
||||
self._save_session(new_session)
|
||||
return new_session
|
||||
|
||||
def register_message(self, message: "SessionMessage"):
|
||||
platform = message.platform
|
||||
if not platform:
|
||||
raise ValueError("消息缺少平台信息")
|
||||
user_id = message.message_info.user_info.user_id
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
message.session_id = session_id # 确保消息的session_id正确设置
|
||||
self.last_messages[session_id] = message
|
||||
|
||||
async def load_all_sessions_from_db(self):
|
||||
"""从数据库加载全部会话记录到内存中"""
|
||||
self.sessions.clear()
|
||||
try:
|
||||
await asyncio.to_thread(self._load_sessions_from_db)
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载会话记录时发生错误: {e}")
|
||||
self.sessions.clear()
|
||||
raise e
|
||||
|
||||
def save_all_sessions(self):
|
||||
"""将内存中的全部会话记录保存到数据库"""
|
||||
try:
|
||||
for session in self.sessions.values():
|
||||
self._save_session(session)
|
||||
logger.info(f"已保存 {len(self.sessions)} 个会话记录到数据库中")
|
||||
except Exception as e:
|
||||
logger.error(f"保存会话记录到数据库时发生错误: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_name(self, session_id: str) -> Optional[str]:
|
||||
"""根据会话ID获取会话名称
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
Returns:
|
||||
Optional[str]: 会话名称,如果无法获取则返回None
|
||||
"""
|
||||
session = self.sessions.get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
if session.is_group_session:
|
||||
if session.context and session.context.message and session.context.message.message_info.group_info:
|
||||
return session.context.message.message_info.group_info.group_name
|
||||
elif session.context and session.context.message and session.context.message.message_info.user_info:
|
||||
nickname = session.context.message.message_info.user_info.user_nickname
|
||||
return f"{nickname}的私聊"
|
||||
return None
|
||||
|
||||
def get_session_by_info(
|
||||
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
|
||||
) -> Optional[BotChatSession]:
|
||||
"""根据平台、用户ID和群ID获取对应的会话
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
group_id: 群ID(如果是群聊)
|
||||
Returns:
|
||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||
"""
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
return self.get_session_by_session_id(session_id)
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||
"""根据会话ID获取对应的会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
Returns:
|
||||
Optional[BotChatSession]: 会话对象,如果不存在则返回None
|
||||
"""
|
||||
session = self.sessions.get(session_id)
|
||||
if session and session_id in self.last_messages:
|
||||
session.set_context(self.last_messages[session_id])
|
||||
return session
|
||||
|
||||
def _load_sessions_from_db(self):
|
||||
"""从数据库加载单个会话记录"""
|
||||
with get_db_session() as session:
|
||||
statements = select(ChatSession)
|
||||
for model_instance in session.exec(statements).all():
|
||||
bot_chat_session = BotChatSession.from_db_instance(model_instance)
|
||||
self.sessions[bot_chat_session.session_id] = bot_chat_session
|
||||
if bot_chat_session.session_id in self.last_messages:
|
||||
bot_chat_session.set_context(self.last_messages[bot_chat_session.session_id])
|
||||
|
||||
def _save_session(self, session: BotChatSession):
|
||||
"""将会话记录保存到数据库"""
|
||||
with get_db_session() as db_session:
|
||||
db_instance = session.to_db_instance()
|
||||
statement = select(ChatSession).filter_by(session_id=db_instance.session_id)
|
||||
if result := db_session.exec(statement).first():
|
||||
result.created_timestamp = db_instance.created_timestamp
|
||||
result.last_active_timestamp = db_instance.last_active_timestamp
|
||||
db_session.add(result)
|
||||
else:
|
||||
db_session.add(db_instance)
|
||||
|
||||
|
||||
chat_manager = ChatManager()
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.common.database.database_model import ChatSession
|
||||
|
|
@ -6,11 +7,22 @@ from . import BaseDatabaseDataModel
|
|||
|
||||
|
||||
class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
||||
def __init__(self, session_id: str, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
self.session_id = session_id
|
||||
self.platform = platform
|
||||
self.user_id = user_id
|
||||
self.group_id = group_id
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
platform: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
created_timestamp: Optional[datetime] = None,
|
||||
last_active_timestamp: Optional[datetime] = None,
|
||||
):
|
||||
self.session_id: str = session_id
|
||||
self.platform: str = platform
|
||||
self.user_id: Optional[str] = user_id
|
||||
self.group_id: Optional[str] = group_id
|
||||
self.created_timestamp: datetime = created_timestamp or datetime.now()
|
||||
"""会话创建时间,默认为当前时间"""
|
||||
self.last_active_timestamp: Optional[datetime] = last_active_timestamp
|
||||
|
||||
# 验证字段
|
||||
assert self.platform, "Platform must be provided"
|
||||
|
|
@ -26,6 +38,8 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
|||
platform=db_record.platform,
|
||||
user_id=db_record.user_id,
|
||||
group_id=db_record.group_id,
|
||||
created_timestamp=db_record.created_timestamp,
|
||||
last_active_timestamp=db_record.last_active_timestamp,
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> ChatSession:
|
||||
|
|
@ -34,4 +48,6 @@ class MaiChatSession(BaseDatabaseDataModel[ChatSession]):
|
|||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
group_id=self.group_id,
|
||||
created_timestamp=self.created_timestamp,
|
||||
last_active_timestamp=self.last_active_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -78,6 +78,9 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
|||
)
|
||||
|
||||
def to_db_instance(self) -> "PersonInfo":
|
||||
group_cardname = (
|
||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
||||
)
|
||||
return PersonInfo(
|
||||
is_known=self.is_known,
|
||||
person_id=self.person_id,
|
||||
|
|
@ -86,7 +89,7 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
|||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
user_nickname=self.user_nickname,
|
||||
group_cardname=json.dumps([gn.__dict__ for gn in self.group_cardname_list]) if self.group_cardname_list else None,
|
||||
group_cardname=group_cardname,
|
||||
memory_points=json.dumps(self.memory_points) if self.memory_points else None,
|
||||
know_counts=self.know_counts,
|
||||
first_known_time=self.first_known_time,
|
||||
|
|
|
|||
|
|
@ -300,7 +300,7 @@ class ChatSession(SQLModel, table=True):
|
|||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
last_active_timestamp: datetime = Field(
|
||||
last_active_timestamp: Optional[datetime] = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后活跃时间
|
||||
|
||||
|
|
|
|||
|
|
@ -124,3 +124,14 @@ class MessageUtils:
|
|||
((True, pattern) for pattern in global_config.message_receive.ban_msgs_regex if re.search(pattern, text)),
|
||||
(False, None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||
"""计算会话ID"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||
if group_id:
|
||||
components = [platform, group_id]
|
||||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
|
|
|||
Loading…
Reference in New Issue