mirror of https://github.com/Mai-with-u/MaiBot.git
拆分新的utils
parent
e253a2ed2a
commit
b9faed4924
|
|
@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
|||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||
from src.common.database.database_model import ChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import SessionMessage
|
||||
|
|
@ -95,7 +95,7 @@ class ChatManager:
|
|||
Raises:
|
||||
Exception: 获取或创建会话时发生错误
|
||||
"""
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
if session := self.get_session_by_session_id(session_id):
|
||||
session.update_active_time()
|
||||
return session
|
||||
|
|
@ -131,7 +131,7 @@ class ChatManager:
|
|||
raise ValueError("消息缺少平台信息")
|
||||
user_id = message.message_info.user_info.user_id
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
message.session_id = session_id # 确保消息的session_id正确设置
|
||||
self.last_messages[session_id] = message
|
||||
|
||||
|
|
@ -199,7 +199,7 @@ class ChatManager:
|
|||
Returns:
|
||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||
"""
|
||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||
return self.get_session_by_session_id(session_id)
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||
|
|
|
|||
|
|
@ -128,17 +128,6 @@ class MessageUtils:
|
|||
(False, None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||
"""计算会话ID"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||
if group_id:
|
||||
components = [platform, group_id]
|
||||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def store_message_to_db(message: "SessionMessage"):
|
||||
"""存储消息到数据库"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from typing import Optional
|
||||
|
||||
import hashlib
|
||||
|
||||
|
||||
class SessionUtils:
|
||||
@staticmethod
|
||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||
"""计算session_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID(如果是私聊)
|
||||
group_id: 群ID(如果是群聊)
|
||||
Returns:
|
||||
str: 计算得到的会话ID
|
||||
Raises:
|
||||
ValueError: 当 user_id 和 group_id 都未提供时抛出
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||
if group_id:
|
||||
components = [platform, group_id]
|
||||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
|
@ -349,7 +349,7 @@ class TargetItem(ConfigBase):
|
|||
"x-icon": "hash",
|
||||
},
|
||||
)
|
||||
"""用户ID,与平台一起留空表示全局"""
|
||||
"""用户/群ID,与平台一起留空表示全局"""
|
||||
|
||||
rule_type: Literal["group", "private"] = Field(
|
||||
default="group",
|
||||
|
|
|
|||
Loading…
Reference in New Issue