拆分新的utils

r-dev
UnCLAS-Prommer 2026-02-27 22:07:26 +08:00
parent e253a2ed2a
commit b9faed4924
No known key found for this signature in database
4 changed files with 31 additions and 16 deletions

View File

@ -9,7 +9,7 @@ from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession from src.common.data_models.chat_session_data_model import MaiChatSession
from src.common.database.database_model import ChatSession from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils
if TYPE_CHECKING: if TYPE_CHECKING:
from .message import SessionMessage from .message import SessionMessage
@ -95,7 +95,7 @@ class ChatManager:
Raises: Raises:
Exception: 获取或创建会话时发生错误 Exception: 获取或创建会话时发生错误
""" """
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
if session := self.get_session_by_session_id(session_id): if session := self.get_session_by_session_id(session_id):
session.update_active_time() session.update_active_time()
return session return session
@ -131,7 +131,7 @@ class ChatManager:
raise ValueError("消息缺少平台信息") raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
message.session_id = session_id # 确保消息的session_id正确设置 message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message self.last_messages[session_id] = message
@ -199,7 +199,7 @@ class ChatManager:
Returns: Returns:
return (Optional[BotChatSession]): 会话对象如果不存在则返回None return (Optional[BotChatSession]): 会话对象如果不存在则返回None
""" """
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
return self.get_session_by_session_id(session_id) return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]: def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:

View File

@ -128,17 +128,6 @@ class MessageUtils:
(False, None), (False, None),
) )
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
"""计算会话ID"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
if group_id:
components = [platform, group_id]
else:
components = [platform, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()
@staticmethod @staticmethod
def store_message_to_db(message: "SessionMessage"): def store_message_to_db(message: "SessionMessage"):
"""存储消息到数据库""" """存储消息到数据库"""

View File

@ -0,0 +1,26 @@
from typing import Optional
import hashlib
class SessionUtils:
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID如果是私聊
group_id: 群ID如果是群聊
Returns:
str: 计算得到的会话ID
Raises:
ValueError: user_id group_id 都未提供时抛出
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
if group_id:
components = [platform, group_id]
else:
components = [platform, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()

View File

@ -349,7 +349,7 @@ class TargetItem(ConfigBase):
"x-icon": "hash", "x-icon": "hash",
}, },
) )
"""用户ID与平台一起留空表示全局""" """用户/群ID与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = Field( rule_type: Literal["group", "private"] = Field(
default="group", default="group",