mirror of https://github.com/Mai-with-u/MaiBot.git
HFC对应修改(部分)
parent
12d4f236be
commit
a8e8f6b7b3
|
|
@ -1,11 +1,11 @@
|
|||
import traceback
|
||||
from typing import Any, Optional, Dict
|
||||
from typing import Optional, Dict
|
||||
|
||||
import traceback
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
|
@ -14,29 +14,26 @@ class Heartflow:
|
|||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
|
||||
self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
|
||||
async def get_or_create_heartflow_chat(self, session_id: str) -> Optional[HeartFChatting | BrainChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
if chat_id in self.heartflow_chat_list:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
|
||||
if chat_stream.group_info:
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
else:
|
||||
new_chat = BrainChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
return chat
|
||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||
if not chat_session:
|
||||
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
|
||||
new_chat = (
|
||||
HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
|
||||
)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[session_id] = new_chat
|
||||
return new_chat
|
||||
except Exception as e:
|
||||
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
|
||||
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
|
|
|
|||
|
|
@ -1,21 +1,16 @@
|
|||
import re
|
||||
import traceback
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
import traceback
|
||||
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.chat_message_builder import replace_user_references
|
||||
|
||||
# from src.chat.utils.chat_message_builder import replace_user_references
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import Person
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
|
@ -24,10 +19,9 @@ class HeartFCMessageReceiver:
|
|||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
pass
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
async def process_message(self, message: "SessionMessage"):
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
|
|
@ -38,7 +32,7 @@ class HeartFCMessageReceiver:
|
|||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
message: SessionMessage对象,包含原始消息数据和相关信息
|
||||
"""
|
||||
try:
|
||||
# 通知消息不处理
|
||||
|
|
@ -48,70 +42,46 @@ class HeartFCMessageReceiver:
|
|||
|
||||
# 1. 消息解析与初始化
|
||||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
if userinfo is None or message.message_info.platform is None:
|
||||
group_info = message.message_info.group_info
|
||||
if userinfo is None or message.platform is None:
|
||||
raise ValueError("message userinfo or platform is missing")
|
||||
if userinfo.user_id is None or userinfo.user_nickname is None:
|
||||
raise ValueError("message userinfo id or nickname is missing")
|
||||
user_id = userinfo.user_id
|
||||
nickname = userinfo.user_nickname
|
||||
|
||||
# 2. 计算at信息
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
# print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
# 2. 计算at信息 (现在转移给Adapter完成)
|
||||
# is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
# # print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
||||
# message.is_mentioned = is_mentioned
|
||||
# message.is_at = is_at
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||
|
||||
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
await heartflow.get_or_create_heartflow_chat(message.session_id)
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
mes_name = group_info.group_name if group_info else "私聊"
|
||||
|
||||
# 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
picid_list = re.findall(picid_pattern, message.processed_plain_text)
|
||||
# TODO: 修复引用格式替换
|
||||
# # 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
# processed_plain_text = replace_user_references(
|
||||
# processed_text, message.message_info.platform, replace_bot_name=True
|
||||
# )
|
||||
# # if not processed_plain_text:
|
||||
# # print(message)
|
||||
|
||||
# 创建替换后的文本
|
||||
processed_text = message.processed_plain_text
|
||||
if picid_list:
|
||||
for picid in picid_list:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images).where(
|
||||
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
|
||||
)
|
||||
if picid.isdigit()
|
||||
else None
|
||||
)
|
||||
image = session.exec(statement).first() if statement is not None else None
|
||||
if image and image.description:
|
||||
# 将[picid:xxxx]替换成图片描述
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
|
||||
else:
|
||||
# 如果没有找到图片描述,则移除[picid:xxxx]标记
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references(
|
||||
processed_text, message.message_info.platform, replace_bot_name=True
|
||||
)
|
||||
# if not processed_plain_text:
|
||||
# print(message)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
# 如果是群聊,获取群号和群昵称
|
||||
group_id = None
|
||||
group_nick_name = None
|
||||
if chat.group_info:
|
||||
group_id = chat.group_info.group_id
|
||||
if group_info:
|
||||
group_id = group_info.group_id
|
||||
group_nick_name = userinfo.user_cardname
|
||||
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform,
|
||||
platform=message.platform,
|
||||
user_id=user_id,
|
||||
nickname=nickname,
|
||||
group_id=group_id,
|
||||
|
|
|
|||
|
|
@ -1,250 +0,0 @@
|
|||
from datetime import datetime
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from sqlmodel import col, select
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType, Messages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.message_component_model import MessageSequence, TextComponent
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageRecv, MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
def _coerce_str_list(value: object) -> list[str]:
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, set):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords: list[str] | None) -> str:
|
||||
"""将关键词列表序列化为JSON字符串"""
|
||||
if isinstance(keywords, list):
|
||||
return json.dumps(keywords, ensure_ascii=False)
|
||||
return "[]"
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list[str]:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
if not keywords_str:
|
||||
return []
|
||||
try:
|
||||
parsed = cast(object, json.loads(keywords_str))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
if isinstance(parsed, str):
|
||||
return [parsed]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 通知消息不存储
|
||||
if isinstance(message, MessageRecv) and message.is_notify:
|
||||
logger.debug("通知消息,跳过存储")
|
||||
return
|
||||
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# print(message)
|
||||
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
# print(processed_plain_text)
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
is_at = False
|
||||
reply_probability_boost = 0.0
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picture = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
selected_expressions = message.selected_expressions
|
||||
intercept_message_level = 0
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
is_mentioned = message.is_mentioned
|
||||
is_at = message.is_at
|
||||
reply_probability_boost = message.reply_probability_boost
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picture = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
|
||||
key_words_lite = MessageStorage._serialize_keywords(
|
||||
MessageStorage._coerce_str_list(message.key_words_lite)
|
||||
)
|
||||
selected_expressions = ""
|
||||
|
||||
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
|
||||
if message.message_info.user_info is None:
|
||||
raise ValueError("message.user_info is required")
|
||||
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id or ""
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
|
||||
|
||||
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
|
||||
additional_config.update(
|
||||
{
|
||||
"interest_value": interest_value,
|
||||
"priority_mode": priority_mode,
|
||||
"priority_info": priority_info,
|
||||
"reply_probability_boost": reply_probability_boost,
|
||||
"intercept_message_level": intercept_message_level,
|
||||
"key_words": key_words,
|
||||
"key_words_lite": key_words_lite,
|
||||
"selected_expressions": selected_expressions,
|
||||
"is_picid": is_picture,
|
||||
}
|
||||
)
|
||||
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
|
||||
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
|
||||
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
|
||||
|
||||
timestamp_value = message.message_info.time
|
||||
if timestamp_value is None:
|
||||
raise ValueError("message.message_info.time is required")
|
||||
db_message = Messages(
|
||||
message_id=str(msg_id),
|
||||
timestamp=datetime.fromtimestamp(float(timestamp_value)),
|
||||
platform=MessageStorage._get_str(chat_info_dict, "platform"),
|
||||
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
|
||||
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
|
||||
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
|
||||
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
|
||||
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
|
||||
is_mentioned=bool(is_mentioned),
|
||||
is_at=bool(is_at),
|
||||
session_id=chat_stream.stream_id,
|
||||
reply_to=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
is_picture=is_picture,
|
||||
is_command=is_command,
|
||||
is_notify=is_notify,
|
||||
raw_content=raw_content,
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
additional_config=json.dumps(additional_config, ensure_ascii=False),
|
||||
)
|
||||
with get_db_session() as session:
|
||||
session.add(db_message)
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
logger.error(f"消息:{message}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@staticmethod
|
||||
def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool:
|
||||
"""实时更新数据库的自身发送消息ID"""
|
||||
try:
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return False
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.message_id) == mmc_message_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
matched_message = session.exec(statement).first()
|
||||
if matched_message:
|
||||
matched_message.message_id = qq_message_id
|
||||
session.add(matched_message)
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
return True
|
||||
logger.debug("未找到匹配的消息")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def replace_image_descriptions(text: str) -> str:
|
||||
"""将[图片:描述]替换为[picid:image_id]"""
|
||||
# 先检查文本中是否有图片标记
|
||||
pattern = r"\[图片:([^\]]+)\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
if not matches:
|
||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||
return text
|
||||
|
||||
def replace_match(match: re.Match[str]) -> str:
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images)
|
||||
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
|
||||
.order_by(col(Images.record_time).desc())
|
||||
.limit(1)
|
||||
)
|
||||
image_record = session.exec(statement).first()
|
||||
return f"[picid:{image_record.id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
TOML文件工具函数 - 保留格式和注释
|
||||
"""
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from typing import Any
|
||||
|
||||
|
||||
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
|
||||
"""
|
||||
保存TOML数据到文件,保留现有格式(如果文件存在)
|
||||
|
||||
Args:
|
||||
data: 要保存的数据字典
|
||||
file_path: 文件路径
|
||||
"""
|
||||
# 如果文件不存在,直接创建
|
||||
if not os.path.exists(file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 如果文件存在,尝试读取现有文件以保留格式
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
existing_doc = tomlkit.load(f)
|
||||
except Exception:
|
||||
# 如果读取失败,直接覆盖
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 递归更新,保留现有格式
|
||||
_merge_toml_preserving_format(existing_doc, data)
|
||||
|
||||
# 保存
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(existing_doc, f)
|
||||
|
||||
|
||||
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
递归合并source到target,保留target中的格式和注释
|
||||
|
||||
Args:
|
||||
target: 目标文档(保留格式)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典且都是表格,递归合并
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是字典/表格
|
||||
_merge_toml_preserving_format(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 其他情况直接替换
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
更新TOML文档中的字段,保留现有的格式和注释
|
||||
|
||||
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
|
||||
|
||||
Args:
|
||||
target: 目标表格(会被修改)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典,递归更新
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是表格
|
||||
_update_toml_doc(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 直接更新值,保留注释
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from maim_message import MessageBase, Seg
|
||||
from typing import List, Tuple, Optional, Sequence
|
||||
from typing import List, Tuple, Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
|
@ -19,6 +19,9 @@ from src.common.data_models.message_component_data_model import (
|
|||
)
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class MessageUtils:
|
||||
@staticmethod
|
||||
|
|
@ -135,3 +138,12 @@ class MessageUtils:
|
|||
else:
|
||||
components = [platform, user_id, "private"]
|
||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def store_message_to_db(message: "SessionMessage"):
|
||||
"""存储消息到数据库"""
|
||||
from src.common.database.database import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
db_message = message.to_db_instance()
|
||||
session.add(db_message)
|
||||
|
|
|
|||
Loading…
Reference in New Issue