HFC对应修改(部分)

r-dev
UnCLAS-Prommer 2026-02-24 15:59:35 +08:00
parent 12d4f236be
commit a8e8f6b7b3
No known key found for this signature in database
5 changed files with 62 additions and 422 deletions

View File

@ -1,11 +1,11 @@
import traceback
from typing import Any, Optional, Dict
from typing import Optional, Dict
import traceback
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager
from src.chat.heart_flow.heartFC_chat import HeartFChatting
from src.chat.brain_chat.brain_chat import BrainChatting
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("heartflow")
@ -14,29 +14,26 @@ class Heartflow:
"""主心流协调器,负责初始化并协调聊天"""
def __init__(self):
self.heartflow_chat_list: Dict[Any, HeartFChatting | BrainChatting] = {}
self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting | BrainChatting]:
async def get_or_create_heartflow_chat(self, session_id: str) -> Optional[HeartFChatting | BrainChatting]:
"""获取或创建一个新的HeartFChatting实例"""
try:
if chat_id in self.heartflow_chat_list:
if chat := self.heartflow_chat_list.get(chat_id):
return chat
else:
chat_stream: ChatStream | None = get_chat_manager().get_stream(chat_id)
if not chat_stream:
raise ValueError(f"未找到 chat_id={chat_id} 的聊天流")
if chat_stream.group_info:
new_chat = HeartFChatting(chat_id=chat_id)
else:
new_chat = BrainChatting(chat_id=chat_id)
await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat
return new_chat
if chat := self.heartflow_chat_list.get(session_id):
return chat
chat_session = chat_manager.get_session_by_session_id(session_id)
if not chat_session:
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
new_chat = (
HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
)
await new_chat.start()
self.heartflow_chat_list[session_id] = new_chat
return new_chat
except Exception as e:
logger.error(f"创建心流聊天 {chat_id} 失败: {e}", exc_info=True)
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
traceback.print_exc()
return None
return None
heartflow = Heartflow()

View File

@ -1,21 +1,16 @@
import re
import traceback
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
import traceback
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.chat_message_builder import replace_user_references
# from src.chat.utils.chat_message_builder import replace_user_references
from src.common.utils.utils_message import MessageUtils
from src.common.logger import get_logger
from src.person_info.person_info import Person
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
if TYPE_CHECKING:
pass
from src.chat.message_receive.message import SessionMessage
logger = get_logger("chat")
@ -24,10 +19,9 @@ class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
def __init__(self):
"""初始化心流处理器,创建消息存储实例"""
self.storage = MessageStorage()
pass
async def process_message(self, message: MessageRecv) -> None:
async def process_message(self, message: "SessionMessage"):
"""处理接收到的原始消息数据
主要流程:
@ -38,7 +32,7 @@ class HeartFCMessageReceiver:
5. 关系处理
Args:
message_data: 原始消息字符串
message: SessionMessage对象包含原始消息数据和相关信息
"""
try:
# 通知消息不处理
@ -48,70 +42,46 @@ class HeartFCMessageReceiver:
# 1. 消息解析与初始化
userinfo = message.message_info.user_info
chat = message.chat_stream
if userinfo is None or message.message_info.platform is None:
group_info = message.message_info.group_info
if userinfo is None or message.platform is None:
raise ValueError("message userinfo or platform is missing")
if userinfo.user_id is None or userinfo.user_nickname is None:
raise ValueError("message userinfo id or nickname is missing")
user_id = userinfo.user_id
nickname = userinfo.user_nickname
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
# 2. 计算at信息 现在转移给Adapter完成
# is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# # print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
# message.is_mentioned = is_mentioned
# message.is_at = is_at
await self.storage.store_message(message, chat)
MessageUtils.store_message_to_db(message) # 存储消息到数据库
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
await heartflow.get_or_create_heartflow_chat(message.session_id)
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
mes_name = group_info.group_name if group_info else "私聊"
# 用这个pattern截取出id部分picid是一个list并替换成对应的图片描述
picid_pattern = r"\[picid:([^\]]+)\]"
picid_list = re.findall(picid_pattern, message.processed_plain_text)
# TODO: 修复引用格式替换
# # 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
# processed_plain_text = replace_user_references(
# processed_text, message.message_info.platform, replace_bot_name=True
# )
# # if not processed_plain_text:
# # print(message)
# 创建替换后的文本
processed_text = message.processed_plain_text
if picid_list:
for picid in picid_list:
with get_db_session() as session:
statement = (
select(Images).where(
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
)
if picid.isdigit()
else None
)
image = session.exec(statement).first() if statement is not None else None
if image and image.description:
# 将[picid:xxxx]替换成图片描述
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
else:
# 如果没有找到图片描述,则移除[picid:xxxx]标记
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references(
processed_text, message.message_info.platform, replace_bot_name=True
)
# if not processed_plain_text:
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
# 如果是群聊,获取群号和群昵称
group_id = None
group_nick_name = None
if chat.group_info:
group_id = chat.group_info.group_id
if group_info:
group_id = group_info.group_id
group_nick_name = userinfo.user_cardname
_ = Person.register_person(
platform=message.message_info.platform,
platform=message.platform,
user_id=user_id,
nickname=nickname,
group_id=group_id,

View File

@ -1,250 +0,0 @@
from datetime import datetime
from collections.abc import Mapping
from typing import cast
import json
import re
import traceback
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType, Messages
from src.common.logger import get_logger
from src.common.data_models.message_component_model import MessageSequence, TextComponent
from src.common.utils.utils_message import MessageUtils
from .chat_stream import ChatStream
from .message import MessageRecv, MessageSending
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _coerce_str_list(value: object) -> list[str]:
if isinstance(value, list):
return [str(item) for item in value]
if isinstance(value, tuple):
return [str(item) for item in value]
if isinstance(value, set):
return [str(item) for item in value]
if isinstance(value, str):
return [value]
return []
@staticmethod
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
value = mapping.get(key)
if value is None:
return default
return str(value)
@staticmethod
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
value = mapping.get(key)
if value is None:
return None
return str(value)
@staticmethod
def _serialize_keywords(keywords: list[str] | None) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list[str]:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
parsed = cast(object, json.loads(keywords_str))
except (json.JSONDecodeError, TypeError):
return []
if isinstance(parsed, list):
return [str(item) for item in parsed]
if isinstance(parsed, str):
return [parsed]
return []
@staticmethod
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 通知消息不存储
if isinstance(message, MessageRecv) and message.is_notify:
logger.debug("通知消息,跳过存储")
return
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# print(message)
processed_plain_text = message.processed_plain_text
# print(processed_plain_text)
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
filtered_display_message = ""
interest_value = 0
is_mentioned = False
is_at = False
reply_probability_boost = 0.0
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picture = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
selected_expressions = message.selected_expressions
intercept_message_level = 0
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
is_at = message.is_at
reply_probability_boost = message.reply_probability_boost
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picture = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
intercept_message_level = getattr(message, "intercept_message_level", 0)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
key_words_lite = MessageStorage._serialize_keywords(
MessageStorage._coerce_str_list(message.key_words_lite)
)
selected_expressions = ""
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
if message.message_info.user_info is None:
raise ValueError("message.user_info is required")
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id or ""
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
additional_config.update(
{
"interest_value": interest_value,
"priority_mode": priority_mode,
"priority_info": priority_info,
"reply_probability_boost": reply_probability_boost,
"intercept_message_level": intercept_message_level,
"key_words": key_words,
"key_words_lite": key_words_lite,
"selected_expressions": selected_expressions,
"is_picid": is_picture,
}
)
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
timestamp_value = message.message_info.time
if timestamp_value is None:
raise ValueError("message.message_info.time is required")
db_message = Messages(
message_id=str(msg_id),
timestamp=datetime.fromtimestamp(float(timestamp_value)),
platform=MessageStorage._get_str(chat_info_dict, "platform"),
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
is_mentioned=bool(is_mentioned),
is_at=bool(is_at),
session_id=chat_stream.stream_id,
reply_to=reply_to,
is_emoji=is_emoji,
is_picture=is_picture,
is_command=is_command,
is_notify=is_notify,
raw_content=raw_content,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
additional_config=json.dumps(additional_config, ensure_ascii=False),
)
with get_db_session() as session:
session.add(db_message)
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
# 如果需要其他存储相关的函数,可以在这里添加
@staticmethod
def update_message(mmc_message_id: str | None, qq_message_id: str | None) -> bool:
"""实时更新数据库的自身发送消息ID"""
try:
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return False
with get_db_session() as session:
statement = (
select(Messages)
.where(col(Messages.message_id) == mmc_message_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
matched_message = session.exec(statement).first()
if matched_message:
matched_message.message_id = qq_message_id
session.add(matched_message)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
logger.debug("未找到匹配的消息")
return False
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
return False
@staticmethod
def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match: re.Match[str]) -> str:
description = match.group(1).strip()
try:
with get_db_session() as session:
statement = (
select(Images)
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
.order_by(col(Images.record_time).desc())
.limit(1)
)
image_record = session.exec(statement).first()
return f"[picid:{image_record.id}]" if image_record else match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)

View File

@ -1,89 +0,0 @@
"""
TOML文件工具函数 - 保留格式和注释
"""
import os
import tomlkit
from typing import Any
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
"""
保存TOML数据到文件保留现有格式如果文件存在
Args:
data: 要保存的数据字典
file_path: 文件路径
"""
# 如果文件不存在,直接创建
if not os.path.exists(file_path):
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 如果文件存在,尝试读取现有文件以保留格式
try:
with open(file_path, "r", encoding="utf-8") as f:
existing_doc = tomlkit.load(f)
except Exception:
# 如果读取失败,直接覆盖
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 递归更新,保留现有格式
_merge_toml_preserving_format(existing_doc, data)
# 保存
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(existing_doc, f)
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
递归合并source到target保留target中的格式和注释
Args:
target: 目标文档保留格式
source: 源数据新数据
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典且都是表格,递归合并
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是字典/表格
_merge_toml_preserving_format(target[key], value)
else:
target[key] = value
else:
# 其他情况直接替换
target[key] = value
else:
# 新键直接添加
target[key] = value
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
更新TOML文档中的字段保留现有的格式和注释
这是一个递归函数用于在部分更新配置时保留现有的格式和注释
Args:
target: 目标表格会被修改
source: 源数据新数据
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典,递归更新
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是表格
_update_toml_doc(target[key], value)
else:
target[key] = value
else:
# 直接更新值,保留注释
target[key] = value
else:
# 新键直接添加
target[key] = value

View File

@ -1,5 +1,5 @@
from maim_message import MessageBase, Seg
from typing import List, Tuple, Optional, Sequence
from typing import List, Tuple, Optional, Sequence, TYPE_CHECKING
import base64
import hashlib
@ -19,6 +19,9 @@ from src.common.data_models.message_component_data_model import (
)
from src.config.config import global_config
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
class MessageUtils:
@staticmethod
@ -135,3 +138,12 @@ class MessageUtils:
else:
components = [platform, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()
@staticmethod
def store_message_to_db(message: "SessionMessage"):
"""存储消息到数据库"""
from src.common.database.database import get_db_session
with get_db_session() as session:
db_message = message.to_db_instance()
session.add(db_message)