diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index febff2d5..d6c62f94 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -1,11 +1,11 @@ -import traceback -from typing import Any, Optional, Dict +from typing import Optional, Dict + +import traceback -from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger +from src.chat.message_receive.chat_manager import chat_manager from src.chat.heart_flow.heartFC_chat import HeartFChatting from src.chat.brain_chat.brain_chat import BrainChatting -from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("heartflow") @@ -14,29 +14,26 @@ class Heartflow: """主心流协调器,负责初始化并协调聊天""" def __init__(self): - self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {} + self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {} - async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]: + async def get_or_create_heartflow_chat(self, session_id: str) -> Optional[HeartFChatting | BrainChatting]: """获取或创建一个新的HeartFChatting实例""" try: - if chat_id in self.heartflow_chat_list: - if chat := self.heartflow_chat_list.get(chat_id): - return chat - else: - chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id) - if not chat_stream: - raise ValueError(f"未找到 chat_id={chat_id} 的聊天流") - if chat_stream.group_info: - new_chat = HeartFChatting(chat_id=chat_id) - else: - new_chat = BrainChatting(chat_id=chat_id) - await new_chat.start() - self.heartflow_chat_list[chat_id] = new_chat - return new_chat + if chat := self.heartflow_chat_list.get(session_id): + return chat + chat_session = chat_manager.get_session_by_session_id(session_id) + if not chat_session: + raise ValueError(f"未找到 session_id={session_id} 的聊天流") + new_chat = ( + HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id) + ) + await new_chat.start() + self.heartflow_chat_list[session_id] = new_chat + return new_chat except Exception as e: - logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True) + logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True) traceback.print_exc() - return None + return None heartflow = Heartflow() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 9301980b..753049d1 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -1,21 +1,16 @@ -import re -import traceback - from typing import TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv -from src.chat.message_receive.storage import MessageStorage +import traceback + from src.chat.heart_flow.heartflow import heartflow -from src.chat.utils.utils import is_mentioned_bot_in_message -from src.chat.utils.chat_message_builder import replace_user_references + +# from src.chat.utils.chat_message_builder import replace_user_references +from src.common.utils.utils_message import MessageUtils from src.common.logger import get_logger from src.person_info.person_info import Person -from sqlmodel import select, col -from src.common.database.database import get_db_session -from src.common.database.database_model import Images, ImageType if TYPE_CHECKING: - pass + from src.chat.message_receive.message import SessionMessage logger = get_logger("chat") @@ -24,10 +19,9 @@ class HeartFCMessageReceiver: """心流处理器,负责处理接收到的消息并计算兴趣度""" def __init__(self): - """初始化心流处理器,创建消息存储实例""" - self.storage = MessageStorage() + pass - async def process_message(self, message: MessageRecv) -> None: + async def process_message(self, message: "SessionMessage"): """处理接收到的原始消息数据 主要流程: @@ -38,7 +32,7 @@ class HeartFCMessageReceiver: 5. 关系处理 Args: - message_data: 原始消息字符串 + message: SessionMessage对象,包含原始消息数据和相关信息 """ try: # 通知消息不处理 @@ -48,70 +42,46 @@ class HeartFCMessageReceiver: # 1. 消息解析与初始化 userinfo = message.message_info.user_info - chat = message.chat_stream - if userinfo is None or message.message_info.platform is None: + group_info = message.message_info.group_info + if userinfo is None or message.platform is None: raise ValueError("message userinfo or platform is missing") if userinfo.user_id is None or userinfo.user_nickname is None: raise ValueError("message userinfo id or nickname is missing") user_id = userinfo.user_id nickname = userinfo.user_nickname - # 2. 计算at信息 - is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) - # print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}") - message.is_mentioned = is_mentioned - message.is_at = is_at - message.reply_probability_boost = reply_probability_boost + # 2. 计算at信息 (现在转移给Adapter完成) + # is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) + # # print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}") + # message.is_mentioned = is_mentioned + # message.is_at = is_at - await self.storage.store_message(message, chat) + MessageUtils.store_message_to_db(message) # 存储消息到数据库 - await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore + await heartflow.get_or_create_heartflow_chat(message.session_id) # 3. 日志记录 - mes_name = chat.group_info.group_name if chat.group_info else "私聊" + mes_name = group_info.group_name if group_info else "私聊" - # 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述 - picid_pattern = r"\[picid:([^\]]+)\]" - picid_list = re.findall(picid_pattern, message.processed_plain_text) + # TODO: 修复引用格式替换 + # # 应用用户引用格式替换,将回复和@格式转换为可读格式 + # processed_plain_text = replace_user_references( + # processed_text, message.message_info.platform, replace_bot_name=True + # ) + # # if not processed_plain_text: + # # print(message) - # 创建替换后的文本 - processed_text = message.processed_plain_text - if picid_list: - for picid in picid_list: - with get_db_session() as session: - statement = ( - select(Images).where( - (col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE) - ) - if picid.isdigit() - else None - ) - image = session.exec(statement).first() if statement is not None else None - if image and image.description: - # 将[picid:xxxx]替换成图片描述 - processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]") - else: - # 如果没有找到图片描述,则移除[picid:xxxx]标记 - processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]") - - # 应用用户引用格式替换,将回复和@格式转换为可读格式 - processed_plain_text = replace_user_references( - processed_text, message.message_info.platform, replace_bot_name=True - ) - # if not processed_plain_text: - # print(message) - - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}") # 如果是群聊,获取群号和群昵称 group_id = None group_nick_name = None - if chat.group_info: - group_id = chat.group_info.group_id + if group_info: + group_id = group_info.group_id group_nick_name = userinfo.user_cardname _ = Person.register_person( - platform=message.message_info.platform, + platform=message.platform, user_id=user_id, nickname=nickname, group_id=group_id, diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py deleted file mode 100644 index 6defe19d..00000000 --- a/src/chat/message_receive/storage.py +++ /dev/null @@ -1,250 +0,0 @@ -from datetime import datetime -from collections.abc import Mapping -from typing import cast - -import json -import re -import traceback - -from sqlmodel import col, select -from src.common.database.database import get_db_session -from src.common.database.database_model import Images, ImageType, Messages -from src.common.logger import get_logger -from src.common.data_models.message_component_model import MessageSequence, TextComponent -from src.common.utils.utils_message import MessageUtils -from .chat_stream import ChatStream -from .message import MessageRecv, MessageSending - -logger = get_logger("message_storage") - - -class MessageStorage: - @staticmethod - def _coerce_str_list(value: object) -> list[str]: - if isinstance(value, list): - return [str(item) for item in value] - if isinstance(value, tuple): - return [str(item) for item in value] - if isinstance(value, set): - return [str(item) for item in value] - if isinstance(value, str): - return [value] - return [] - - @staticmethod - def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str: - value = mapping.get(key) - if value is None: - return default - return str(value) - - @staticmethod - def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None: - value = mapping.get(key) - if value is None: - return None - return str(value) - - @staticmethod - def _serialize_keywords(keywords: list[str] | None) -> str: - """将关键词列表序列化为JSON字符串""" - if isinstance(keywords, list): - return json.dumps(keywords, ensure_ascii=False) - return "[]" - - @staticmethod - def _deserialize_keywords(keywords_str: str) -> list[str]: - """将JSON字符串反序列化为关键词列表""" - if not keywords_str: - return [] - try: - parsed = cast(object, json.loads(keywords_str)) - except (json.JSONDecodeError, TypeError): - return [] - if isinstance(parsed, list): - return [str(item) for item in parsed] - if isinstance(parsed, str): - return [parsed] - return [] - - @staticmethod - async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: - """存储消息到数据库""" - try: - # 通知消息不存储 - if isinstance(message, MessageRecv) and message.is_notify: - logger.debug("通知消息,跳过存储") - return - - pattern = r".*?|.*?|.*?" - - # print(message) - - processed_plain_text = message.processed_plain_text - - # print(processed_plain_text) - - if processed_plain_text: - processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) - filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) - else: - filtered_processed_plain_text = "" - - if isinstance(message, MessageSending): - display_message = message.display_message - if display_message: - filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) - else: - filtered_display_message = "" - interest_value = 0 - is_mentioned = False - is_at = False - reply_probability_boost = 0.0 - reply_to = message.reply_to - priority_mode = "" - priority_info = {} - is_emoji = False - is_picture = False - is_notify = False - is_command = False - key_words = "" - key_words_lite = "" - selected_expressions = message.selected_expressions - intercept_message_level = 0 - else: - filtered_display_message = "" - interest_value = message.interest_value - is_mentioned = message.is_mentioned - is_at = message.is_at - reply_probability_boost = message.reply_probability_boost - reply_to = "" - priority_mode = message.priority_mode - priority_info = message.priority_info - is_emoji = message.is_emoji - is_picture = message.is_picid - is_notify = message.is_notify - is_command = message.is_command - intercept_message_level = getattr(message, "intercept_message_level", 0) - # 序列化关键词列表为JSON字符串 - key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words)) - key_words_lite = MessageStorage._serialize_keywords( - MessageStorage._coerce_str_list(message.key_words_lite) - ) - selected_expressions = "" - - chat_info_dict = cast(dict[str, object], chat_stream.to_dict()) - if message.message_info.user_info is None: - raise ValueError("message.user_info is required") - user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict()) - - # message_id 现在是 TextField,直接使用字符串值 - msg_id = message.message_info.message_id or "" - - # 安全地获取 group_info, 如果为 None 则视为空字典 - group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {}) - - additional_config: dict[str, object] = dict(message.message_info.additional_config or {}) - additional_config.update( - { - "interest_value": interest_value, - "priority_mode": priority_mode, - "priority_info": priority_info, - "reply_probability_boost": reply_probability_boost, - "intercept_message_level": intercept_message_level, - "key_words": key_words, - "key_words_lite": key_words_lite, - "selected_expressions": selected_expressions, - "is_picid": is_picture, - } - ) - processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or "" - raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else []) - raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence) - - timestamp_value = message.message_info.time - if timestamp_value is None: - raise ValueError("message.message_info.time is required") - db_message = Messages( - message_id=str(msg_id), - timestamp=datetime.fromtimestamp(float(timestamp_value)), - platform=MessageStorage._get_str(chat_info_dict, "platform"), - user_id=MessageStorage._get_str(user_info_dict, "user_id"), - user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"), - user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"), - group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"), - group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"), - is_mentioned=bool(is_mentioned), - is_at=bool(is_at), - session_id=chat_stream.stream_id, - reply_to=reply_to, - is_emoji=is_emoji, - is_picture=is_picture, - is_command=is_command, - is_notify=is_notify, - raw_content=raw_content, - processed_plain_text=filtered_processed_plain_text, - display_message=filtered_display_message, - additional_config=json.dumps(additional_config, ensure_ascii=False), - ) - with get_db_session() as session: - session.add(db_message) - except Exception: - logger.exception("存储消息失败") - logger.error(f"消息:{message}") - traceback.print_exc() - - # 如果需要其他存储相关的函数,可以在这里添加 - @staticmethod - def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool: - """实时更新数据库的自身发送消息ID""" - try: - if not qq_message_id: - logger.info("消息不存在message_id,无法更新") - return False - with get_db_session() as session: - statement = ( - select(Messages) - .where(col(Messages.message_id) == mmc_message_id) - .order_by(col(Messages.timestamp).desc()) - .limit(1) - ) - matched_message = session.exec(statement).first() - if matched_message: - matched_message.message_id = qq_message_id - session.add(matched_message) - logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") - return True - logger.debug("未找到匹配的消息") - return False - - except Exception as e: - logger.error(f"更新消息ID失败: {e}") - return False - - @staticmethod - def replace_image_descriptions(text: str) -> str: - """将[图片:描述]替换为[picid:image_id]""" - # 先检查文本中是否有图片标记 - pattern = r"\[图片:([^\]]+)\]" - matches = re.findall(pattern, text) - - if not matches: - logger.debug("文本中没有图片标记,直接返回原文本") - return text - - def replace_match(match: re.Match[str]) -> str: - description = match.group(1).strip() - try: - with get_db_session() as session: - statement = ( - select(Images) - .where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE)) - .order_by(col(Images.record_time).desc()) - .limit(1) - ) - image_record = session.exec(statement).first() - return f"[picid:{image_record.id}]" if image_record else match.group(0) - except Exception: - return match.group(0) - - return re.sub(r"\[图片:([^\]]+)\]", replace_match, text) diff --git a/src/common/toml_utils.py b/src/common/toml_utils.py deleted file mode 100644 index 8a9ecb99..00000000 --- a/src/common/toml_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -TOML文件工具函数 - 保留格式和注释 -""" - -import os -import tomlkit -from typing import Any - - -def save_toml_with_format(data: dict[str, Any], file_path: str) -> None: - """ - 保存TOML数据到文件,保留现有格式(如果文件存在) - - Args: - data: 要保存的数据字典 - file_path: 文件路径 - """ - # 如果文件不存在,直接创建 - if not os.path.exists(file_path): - with open(file_path, "w", encoding="utf-8") as f: - tomlkit.dump(data, f) - return - - # 如果文件存在,尝试读取现有文件以保留格式 - try: - with open(file_path, "r", encoding="utf-8") as f: - existing_doc = tomlkit.load(f) - except Exception: - # 如果读取失败,直接覆盖 - with open(file_path, "w", encoding="utf-8") as f: - tomlkit.dump(data, f) - return - - # 递归更新,保留现有格式 - _merge_toml_preserving_format(existing_doc, data) - - # 保存 - with open(file_path, "w", encoding="utf-8") as f: - tomlkit.dump(existing_doc, f) - - -def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None: - """ - 递归合并source到target,保留target中的格式和注释 - - Args: - target: 目标文档(保留格式) - source: 源数据(新数据) - """ - for key, value in source.items(): - if key in target: - # 如果两个都是字典且都是表格,递归合并 - if isinstance(value, dict) and isinstance(target[key], dict): - if hasattr(target[key], "items"): # 确实是字典/表格 - _merge_toml_preserving_format(target[key], value) - else: - target[key] = value - else: - # 其他情况直接替换 - target[key] = value - else: - # 新键直接添加 - target[key] = value - - -def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None: - """ - 更新TOML文档中的字段,保留现有的格式和注释 - - 这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。 - - Args: - target: 目标表格(会被修改) - source: 源数据(新数据) - """ - for key, value in source.items(): - if key in target: - # 如果两个都是字典,递归更新 - if isinstance(value, dict) and isinstance(target[key], dict): - if hasattr(target[key], "items"): # 确实是表格 - _update_toml_doc(target[key], value) - else: - target[key] = value - else: - # 直接更新值,保留注释 - target[key] = value - else: - # 新键直接添加 - target[key] = value diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 8ac898c3..0d10bb04 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -1,5 +1,5 @@ from maim_message import MessageBase, Seg -from typing import List, Tuple, Optional, Sequence +from typing import List, Tuple, Optional, Sequence, TYPE_CHECKING import base64 import hashlib @@ -19,6 +19,9 @@ from src.common.data_models.message_component_data_model import ( ) from src.config.config import global_config +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage + class MessageUtils: @staticmethod @@ -135,3 +138,12 @@ class MessageUtils: else: components = [platform, user_id, "private"] return hashlib.md5("_".join(components).encode()).hexdigest() + + @staticmethod + def store_message_to_db(message: "SessionMessage"): + """存储消息到数据库""" + from src.common.database.database import get_db_session + + with get_db_session() as session: + db_message = message.to_db_instance() + session.add(db_message)