mirror of https://github.com/Mai-with-u/MaiBot.git
拆分新的utils
parent
e253a2ed2a
commit
b9faed4924
|
|
@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
||||||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||||||
from src.common.database.database_model import ChatSession
|
from src.common.database.database_model import ChatSession
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.utils.utils_message import MessageUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message import SessionMessage
|
from .message import SessionMessage
|
||||||
|
|
@ -95,7 +95,7 @@ class ChatManager:
|
||||||
Raises:
|
Raises:
|
||||||
Exception: 获取或创建会话时发生错误
|
Exception: 获取或创建会话时发生错误
|
||||||
"""
|
"""
|
||||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||||
if session := self.get_session_by_session_id(session_id):
|
if session := self.get_session_by_session_id(session_id):
|
||||||
session.update_active_time()
|
session.update_active_time()
|
||||||
return session
|
return session
|
||||||
|
|
@ -131,7 +131,7 @@ class ChatManager:
|
||||||
raise ValueError("消息缺少平台信息")
|
raise ValueError("消息缺少平台信息")
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id
|
||||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||||
message.session_id = session_id # 确保消息的session_id正确设置
|
message.session_id = session_id # 确保消息的session_id正确设置
|
||||||
self.last_messages[session_id] = message
|
self.last_messages[session_id] = message
|
||||||
|
|
||||||
|
|
@ -199,7 +199,7 @@ class ChatManager:
|
||||||
Returns:
|
Returns:
|
||||||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
session_id = MessageUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
|
||||||
return self.get_session_by_session_id(session_id)
|
return self.get_session_by_session_id(session_id)
|
||||||
|
|
||||||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||||||
|
|
|
||||||
|
|
@ -128,17 +128,6 @@ class MessageUtils:
|
||||||
(False, None),
|
(False, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
|
||||||
"""计算会话ID"""
|
|
||||||
if not user_id and not group_id:
|
|
||||||
raise ValueError("UserID 或 GroupID 必须提供其一")
|
|
||||||
if group_id:
|
|
||||||
components = [platform, group_id]
|
|
||||||
else:
|
|
||||||
components = [platform, user_id, "private"]
|
|
||||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def store_message_to_db(message: "SessionMessage"):
|
def store_message_to_db(message: "SessionMessage"):
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
class SessionUtils:
|
||||||
|
@staticmethod
|
||||||
|
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
|
||||||
|
"""计算session_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台名称
|
||||||
|
user_id: 用户ID(如果是私聊)
|
||||||
|
group_id: 群ID(如果是群聊)
|
||||||
|
Returns:
|
||||||
|
str: 计算得到的会话ID
|
||||||
|
Raises:
|
||||||
|
ValueError: 当 user_id 和 group_id 都未提供时抛出
|
||||||
|
"""
|
||||||
|
if not user_id and not group_id:
|
||||||
|
raise ValueError("UserID 或 GroupID 必须提供其一")
|
||||||
|
if group_id:
|
||||||
|
components = [platform, group_id]
|
||||||
|
else:
|
||||||
|
components = [platform, user_id, "private"]
|
||||||
|
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||||
|
|
@ -349,7 +349,7 @@ class TargetItem(ConfigBase):
|
||||||
"x-icon": "hash",
|
"x-icon": "hash",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
"""用户ID,与平台一起留空表示全局"""
|
"""用户/群ID,与平台一起留空表示全局"""
|
||||||
|
|
||||||
rule_type: Literal["group", "private"] = Field(
|
rule_type: Literal["group", "private"] = Field(
|
||||||
default="group",
|
default="group",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue