拆分新的utils

r-dev
UnCLAS-Prommer 2026-02-27 22:07:26 +08:00
parent e253a2ed2a
commit b9faed4924
No known key found for this signature in database
4 changed files with 31 additions and 16 deletions

View File

@ -9,7 +9,7 @@ from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession
from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
if TYPE_CHECKING:
from .message import SessionMessage
@ -95,7 +95,7 @@ class ChatManager:
Raises:
Exception: 获取或创建会话时发生错误
"""
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
if session := self.get_session_by_session_id(session_id):
session.update_active_time()
return session
@ -131,7 +131,7 @@ class ChatManager:
raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message
@ -199,7 +199,7 @@ class ChatManager:
Returns:
return (Optional[BotChatSession]): 会话对象如果不存在则返回None
"""
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:

View File

@ -128,17 +128,6 @@ class MessageUtils:
(False, None),
)
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
"""计算会话ID"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
if group_id:
components = [platform, group_id]
else:
components = [platform, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()
@staticmethod
def store_message_to_db(message: "SessionMessage"):
"""存储消息到数据库"""

View File

@ -0,0 +1,26 @@
from typing import Optional
import hashlib
class SessionUtils:
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID如果是私聊
group_id: 群ID如果是群聊
Returns:
str: 计算得到的会话ID
Raises:
ValueError: user_id group_id 都未提供时抛出
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
if group_id:
components = [platform, group_id]
else:
components = [platform, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()

View File

@ -349,7 +349,7 @@ class TargetItem(ConfigBase):
"x-icon": "hash",
},
)
"""用户ID与平台一起留空表示全局"""
"""用户/群ID与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = Field(
default="group",