mirror of https://github.com/Mai-with-u/MaiBot.git
HFC对应修改(部分)
parent
12d4f236be
commit
a8e8f6b7b3
|
|
@ -1,11 +1,11 @@
|
||||||
import traceback
|
from typing import Optional, Dict
|
||||||
from typing import Any, Optional, Dict
|
|
||||||
|
import traceback
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
from src.chat.brain_chat.brain_chat import BrainChatting
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
|
|
||||||
|
|
@ -14,29 +14,26 @@ class Heartflow:
|
||||||
"""主心流协调器,负责初始化并协调聊天"""
|
"""主心流协调器,负责初始化并协调聊天"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
|
self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
||||||
|
|
||||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
|
async def get_or_create_heartflow_chat(self, session_id: str) -> Optional[HeartFChatting | BrainChatting]:
|
||||||
"""获取或创建一个新的HeartFChatting实例"""
|
"""获取或创建一个新的HeartFChatting实例"""
|
||||||
try:
|
try:
|
||||||
if chat_id in self.heartflow_chat_list:
|
if chat := self.heartflow_chat_list.get(session_id):
|
||||||
if chat := self.heartflow_chat_list.get(chat_id):
|
return chat
|
||||||
return chat
|
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||||
else:
|
if not chat_session:
|
||||||
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
|
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
|
||||||
if not chat_stream:
|
new_chat = (
|
||||||
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
|
HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
|
||||||
if chat_stream.group_info:
|
)
|
||||||
new_chat = HeartFChatting(chat_id=chat_id)
|
await new_chat.start()
|
||||||
else:
|
self.heartflow_chat_list[session_id] = new_chat
|
||||||
new_chat = BrainChatting(chat_id=chat_id)
|
return new_chat
|
||||||
await new_chat.start()
|
|
||||||
self.heartflow_chat_list[chat_id] = new_chat
|
|
||||||
return new_chat
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
|
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
heartflow = Heartflow()
|
heartflow = Heartflow()
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,16 @@
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
import traceback
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
from src.chat.heart_flow.heartflow import heartflow
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
|
||||||
from src.chat.utils.chat_message_builder import replace_user_references
|
# from src.chat.utils.chat_message_builder import replace_user_references
|
||||||
|
from src.common.utils.utils_message import MessageUtils
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from sqlmodel import select, col
|
|
||||||
from src.common.database.database import get_db_session
|
|
||||||
from src.common.database.database_model import Images, ImageType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
from src.chat.message_receive.message import SessionMessage
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
|
@ -24,10 +19,9 @@ class HeartFCMessageReceiver:
|
||||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化心流处理器,创建消息存储实例"""
|
pass
|
||||||
self.storage = MessageStorage()
|
|
||||||
|
|
||||||
async def process_message(self, message: MessageRecv) -> None:
|
async def process_message(self, message: "SessionMessage"):
|
||||||
"""处理接收到的原始消息数据
|
"""处理接收到的原始消息数据
|
||||||
|
|
||||||
主要流程:
|
主要流程:
|
||||||
|
|
@ -38,7 +32,7 @@ class HeartFCMessageReceiver:
|
||||||
5. 关系处理
|
5. 关系处理
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_data: 原始消息字符串
|
message: SessionMessage对象,包含原始消息数据和相关信息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 通知消息不处理
|
# 通知消息不处理
|
||||||
|
|
@ -48,70 +42,46 @@ class HeartFCMessageReceiver:
|
||||||
|
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
chat = message.chat_stream
|
group_info = message.message_info.group_info
|
||||||
if userinfo is None or message.message_info.platform is None:
|
if userinfo is None or message.platform is None:
|
||||||
raise ValueError("message userinfo or platform is missing")
|
raise ValueError("message userinfo or platform is missing")
|
||||||
if userinfo.user_id is None or userinfo.user_nickname is None:
|
if userinfo.user_id is None or userinfo.user_nickname is None:
|
||||||
raise ValueError("message userinfo id or nickname is missing")
|
raise ValueError("message userinfo id or nickname is missing")
|
||||||
user_id = userinfo.user_id
|
user_id = userinfo.user_id
|
||||||
nickname = userinfo.user_nickname
|
nickname = userinfo.user_nickname
|
||||||
|
|
||||||
# 2. 计算at信息
|
# 2. 计算at信息 (现在转移给Adapter完成)
|
||||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
# is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||||
# print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
# # print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
||||||
message.is_mentioned = is_mentioned
|
# message.is_mentioned = is_mentioned
|
||||||
message.is_at = is_at
|
# message.is_at = is_at
|
||||||
message.reply_probability_boost = reply_probability_boost
|
|
||||||
|
|
||||||
await self.storage.store_message(message, chat)
|
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||||
|
|
||||||
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
await heartflow.get_or_create_heartflow_chat(message.session_id)
|
||||||
|
|
||||||
# 3. 日志记录
|
# 3. 日志记录
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = group_info.group_name if group_info else "私聊"
|
||||||
|
|
||||||
# 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述
|
# TODO: 修复引用格式替换
|
||||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
# # 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||||
picid_list = re.findall(picid_pattern, message.processed_plain_text)
|
# processed_plain_text = replace_user_references(
|
||||||
|
# processed_text, message.message_info.platform, replace_bot_name=True
|
||||||
|
# )
|
||||||
|
# # if not processed_plain_text:
|
||||||
|
# # print(message)
|
||||||
|
|
||||||
# 创建替换后的文本
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||||
processed_text = message.processed_plain_text
|
|
||||||
if picid_list:
|
|
||||||
for picid in picid_list:
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(Images).where(
|
|
||||||
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
|
|
||||||
)
|
|
||||||
if picid.isdigit()
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
image = session.exec(statement).first() if statement is not None else None
|
|
||||||
if image and image.description:
|
|
||||||
# 将[picid:xxxx]替换成图片描述
|
|
||||||
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
|
|
||||||
else:
|
|
||||||
# 如果没有找到图片描述,则移除[picid:xxxx]标记
|
|
||||||
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
|
|
||||||
|
|
||||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
|
||||||
processed_plain_text = replace_user_references(
|
|
||||||
processed_text, message.message_info.platform, replace_bot_name=True
|
|
||||||
)
|
|
||||||
# if not processed_plain_text:
|
|
||||||
# print(message)
|
|
||||||
|
|
||||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
|
||||||
|
|
||||||
# 如果是群聊,获取群号和群昵称
|
# 如果是群聊,获取群号和群昵称
|
||||||
group_id = None
|
group_id = None
|
||||||
group_nick_name = None
|
group_nick_name = None
|
||||||
if chat.group_info:
|
if group_info:
|
||||||
group_id = chat.group_info.group_id
|
group_id = group_info.group_id
|
||||||
group_nick_name = userinfo.user_cardname
|
group_nick_name = userinfo.user_cardname
|
||||||
|
|
||||||
_ = Person.register_person(
|
_ = Person.register_person(
|
||||||
platform=message.message_info.platform,
|
platform=message.platform,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
nickname=nickname,
|
nickname=nickname,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
|
|
||||||
|
|
@ -1,250 +0,0 @@
|
||||||
from datetime import datetime
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from sqlmodel import col, select
|
|
||||||
from src.common.database.database import get_db_session
|
|
||||||
from src.common.database.database_model import Images, ImageType, Messages
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.data_models.message_component_model import MessageSequence, TextComponent
|
|
||||||
from src.common.utils.utils_message import MessageUtils
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from .message import MessageRecv, MessageSending
|
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage:
|
|
||||||
@staticmethod
|
|
||||||
def _coerce_str_list(value: object) -> list[str]:
|
|
||||||
if isinstance(value, list):
|
|
||||||
return [str(item) for item in value]
|
|
||||||
if isinstance(value, tuple):
|
|
||||||
return [str(item) for item in value]
|
|
||||||
if isinstance(value, set):
|
|
||||||
return [str(item) for item in value]
|
|
||||||
if isinstance(value, str):
|
|
||||||
return [value]
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
|
|
||||||
value = mapping.get(key)
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
|
|
||||||
value = mapping.get(key)
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _serialize_keywords(keywords: list[str] | None) -> str:
|
|
||||||
"""将关键词列表序列化为JSON字符串"""
|
|
||||||
if isinstance(keywords, list):
|
|
||||||
return json.dumps(keywords, ensure_ascii=False)
|
|
||||||
return "[]"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _deserialize_keywords(keywords_str: str) -> list[str]:
|
|
||||||
"""将JSON字符串反序列化为关键词列表"""
|
|
||||||
if not keywords_str:
|
|
||||||
return []
|
|
||||||
try:
|
|
||||||
parsed = cast(object, json.loads(keywords_str))
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
return []
|
|
||||||
if isinstance(parsed, list):
|
|
||||||
return [str(item) for item in parsed]
|
|
||||||
if isinstance(parsed, str):
|
|
||||||
return [parsed]
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
|
||||||
"""存储消息到数据库"""
|
|
||||||
try:
|
|
||||||
# 通知消息不存储
|
|
||||||
if isinstance(message, MessageRecv) and message.is_notify:
|
|
||||||
logger.debug("通知消息,跳过存储")
|
|
||||||
return
|
|
||||||
|
|
||||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
|
||||||
|
|
||||||
# print(message)
|
|
||||||
|
|
||||||
processed_plain_text = message.processed_plain_text
|
|
||||||
|
|
||||||
# print(processed_plain_text)
|
|
||||||
|
|
||||||
if processed_plain_text:
|
|
||||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
|
||||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
|
||||||
else:
|
|
||||||
filtered_processed_plain_text = ""
|
|
||||||
|
|
||||||
if isinstance(message, MessageSending):
|
|
||||||
display_message = message.display_message
|
|
||||||
if display_message:
|
|
||||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
|
||||||
else:
|
|
||||||
filtered_display_message = ""
|
|
||||||
interest_value = 0
|
|
||||||
is_mentioned = False
|
|
||||||
is_at = False
|
|
||||||
reply_probability_boost = 0.0
|
|
||||||
reply_to = message.reply_to
|
|
||||||
priority_mode = ""
|
|
||||||
priority_info = {}
|
|
||||||
is_emoji = False
|
|
||||||
is_picture = False
|
|
||||||
is_notify = False
|
|
||||||
is_command = False
|
|
||||||
key_words = ""
|
|
||||||
key_words_lite = ""
|
|
||||||
selected_expressions = message.selected_expressions
|
|
||||||
intercept_message_level = 0
|
|
||||||
else:
|
|
||||||
filtered_display_message = ""
|
|
||||||
interest_value = message.interest_value
|
|
||||||
is_mentioned = message.is_mentioned
|
|
||||||
is_at = message.is_at
|
|
||||||
reply_probability_boost = message.reply_probability_boost
|
|
||||||
reply_to = ""
|
|
||||||
priority_mode = message.priority_mode
|
|
||||||
priority_info = message.priority_info
|
|
||||||
is_emoji = message.is_emoji
|
|
||||||
is_picture = message.is_picid
|
|
||||||
is_notify = message.is_notify
|
|
||||||
is_command = message.is_command
|
|
||||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
|
||||||
# 序列化关键词列表为JSON字符串
|
|
||||||
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
|
|
||||||
key_words_lite = MessageStorage._serialize_keywords(
|
|
||||||
MessageStorage._coerce_str_list(message.key_words_lite)
|
|
||||||
)
|
|
||||||
selected_expressions = ""
|
|
||||||
|
|
||||||
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
|
|
||||||
if message.message_info.user_info is None:
|
|
||||||
raise ValueError("message.user_info is required")
|
|
||||||
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
|
|
||||||
|
|
||||||
# message_id 现在是 TextField,直接使用字符串值
|
|
||||||
msg_id = message.message_info.message_id or ""
|
|
||||||
|
|
||||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
|
||||||
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
|
|
||||||
|
|
||||||
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
|
|
||||||
additional_config.update(
|
|
||||||
{
|
|
||||||
"interest_value": interest_value,
|
|
||||||
"priority_mode": priority_mode,
|
|
||||||
"priority_info": priority_info,
|
|
||||||
"reply_probability_boost": reply_probability_boost,
|
|
||||||
"intercept_message_level": intercept_message_level,
|
|
||||||
"key_words": key_words,
|
|
||||||
"key_words_lite": key_words_lite,
|
|
||||||
"selected_expressions": selected_expressions,
|
|
||||||
"is_picid": is_picture,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
|
|
||||||
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
|
|
||||||
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
|
|
||||||
|
|
||||||
timestamp_value = message.message_info.time
|
|
||||||
if timestamp_value is None:
|
|
||||||
raise ValueError("message.message_info.time is required")
|
|
||||||
db_message = Messages(
|
|
||||||
message_id=str(msg_id),
|
|
||||||
timestamp=datetime.fromtimestamp(float(timestamp_value)),
|
|
||||||
platform=MessageStorage._get_str(chat_info_dict, "platform"),
|
|
||||||
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
|
|
||||||
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
|
|
||||||
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
|
|
||||||
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
|
|
||||||
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
|
|
||||||
is_mentioned=bool(is_mentioned),
|
|
||||||
is_at=bool(is_at),
|
|
||||||
session_id=chat_stream.stream_id,
|
|
||||||
reply_to=reply_to,
|
|
||||||
is_emoji=is_emoji,
|
|
||||||
is_picture=is_picture,
|
|
||||||
is_command=is_command,
|
|
||||||
is_notify=is_notify,
|
|
||||||
raw_content=raw_content,
|
|
||||||
processed_plain_text=filtered_processed_plain_text,
|
|
||||||
display_message=filtered_display_message,
|
|
||||||
additional_config=json.dumps(additional_config, ensure_ascii=False),
|
|
||||||
)
|
|
||||||
with get_db_session() as session:
|
|
||||||
session.add(db_message)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("存储消息失败")
|
|
||||||
logger.error(f"消息:{message}")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
# 如果需要其他存储相关的函数,可以在这里添加
|
|
||||||
@staticmethod
|
|
||||||
def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool:
|
|
||||||
"""实时更新数据库的自身发送消息ID"""
|
|
||||||
try:
|
|
||||||
if not qq_message_id:
|
|
||||||
logger.info("消息不存在message_id,无法更新")
|
|
||||||
return False
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(Messages)
|
|
||||||
.where(col(Messages.message_id) == mmc_message_id)
|
|
||||||
.order_by(col(Messages.timestamp).desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
matched_message = session.exec(statement).first()
|
|
||||||
if matched_message:
|
|
||||||
matched_message.message_id = qq_message_id
|
|
||||||
session.add(matched_message)
|
|
||||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
|
||||||
return True
|
|
||||||
logger.debug("未找到匹配的消息")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新消息ID失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def replace_image_descriptions(text: str) -> str:
|
|
||||||
"""将[图片:描述]替换为[picid:image_id]"""
|
|
||||||
# 先检查文本中是否有图片标记
|
|
||||||
pattern = r"\[图片:([^\]]+)\]"
|
|
||||||
matches = re.findall(pattern, text)
|
|
||||||
|
|
||||||
if not matches:
|
|
||||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def replace_match(match: re.Match[str]) -> str:
|
|
||||||
description = match.group(1).strip()
|
|
||||||
try:
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(Images)
|
|
||||||
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
|
|
||||||
.order_by(col(Images.record_time).desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
image_record = session.exec(statement).first()
|
|
||||||
return f"[picid:{image_record.id}]" if image_record else match.group(0)
|
|
||||||
except Exception:
|
|
||||||
return match.group(0)
|
|
||||||
|
|
||||||
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
"""
|
|
||||||
TOML文件工具函数 - 保留格式和注释
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import tomlkit
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
|
|
||||||
"""
|
|
||||||
保存TOML数据到文件,保留现有格式(如果文件存在)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 要保存的数据字典
|
|
||||||
file_path: 文件路径
|
|
||||||
"""
|
|
||||||
# 如果文件不存在,直接创建
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
|
||||||
tomlkit.dump(data, f)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果文件存在,尝试读取现有文件以保留格式
|
|
||||||
try:
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
existing_doc = tomlkit.load(f)
|
|
||||||
except Exception:
|
|
||||||
# 如果读取失败,直接覆盖
|
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
|
||||||
tomlkit.dump(data, f)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 递归更新,保留现有格式
|
|
||||||
_merge_toml_preserving_format(existing_doc, data)
|
|
||||||
|
|
||||||
# 保存
|
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
|
||||||
tomlkit.dump(existing_doc, f)
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
|
|
||||||
"""
|
|
||||||
递归合并source到target,保留target中的格式和注释
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target: 目标文档(保留格式)
|
|
||||||
source: 源数据(新数据)
|
|
||||||
"""
|
|
||||||
for key, value in source.items():
|
|
||||||
if key in target:
|
|
||||||
# 如果两个都是字典且都是表格,递归合并
|
|
||||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
|
||||||
if hasattr(target[key], "items"): # 确实是字典/表格
|
|
||||||
_merge_toml_preserving_format(target[key], value)
|
|
||||||
else:
|
|
||||||
target[key] = value
|
|
||||||
else:
|
|
||||||
# 其他情况直接替换
|
|
||||||
target[key] = value
|
|
||||||
else:
|
|
||||||
# 新键直接添加
|
|
||||||
target[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
|
|
||||||
"""
|
|
||||||
更新TOML文档中的字段,保留现有的格式和注释
|
|
||||||
|
|
||||||
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target: 目标表格(会被修改)
|
|
||||||
source: 源数据(新数据)
|
|
||||||
"""
|
|
||||||
for key, value in source.items():
|
|
||||||
if key in target:
|
|
||||||
# 如果两个都是字典,递归更新
|
|
||||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
|
||||||
if hasattr(target[key], "items"): # 确实是表格
|
|
||||||
_update_toml_doc(target[key], value)
|
|
||||||
else:
|
|
||||||
target[key] = value
|
|
||||||
else:
|
|
||||||
# 直接更新值,保留注释
|
|
||||||
target[key] = value
|
|
||||||
else:
|
|
||||||
# 新键直接添加
|
|
||||||
target[key] = value
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from maim_message import MessageBase, Seg
|
from maim_message import MessageBase, Seg
|
||||||
from typing import List, Tuple, Optional, Sequence
|
from typing import List, Tuple, Optional, Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
@ -19,6 +19,9 @@ from src.common.data_models.message_component_data_model import (
|
||||||
)
|
)
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.message import SessionMessage
|
||||||
|
|
||||||
|
|
||||||
class MessageUtils:
|
class MessageUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -135,3 +138,12 @@ class MessageUtils:
|
||||||
else:
|
else:
|
||||||
components = [platform, user_id, "private"]
|
components = [platform, user_id, "private"]
|
||||||
return hashlib.md5("_".join(components).encode()).hexdigest()
|
return hashlib.md5("_".join(components).encode()).hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def store_message_to_db(message: "SessionMessage"):
|
||||||
|
"""存储消息到数据库"""
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
db_message = message.to_db_instance()
|
||||||
|
session.add(db_message)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue