pull/1299/head
SengokuCola 2025-10-14 12:12:09 +08:00
commit 0b9a7743fc
10 changed files with 741 additions and 835 deletions

File diff suppressed because it is too large Load Diff

View File

@ -249,6 +249,8 @@ class BrainPlanner:
# 获取必要信息
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
# 提及/被@ 的处理由心流或统一判定模块驱动Planner 不再做硬编码强制回复
# 应用激活类型过滤
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)

View File

@ -1,7 +1,7 @@
import re
import traceback
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
@ -17,31 +17,6 @@ if TYPE_CHECKING:
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
"""
if message.is_picid or message.is_emoji:
return 0.0, []
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# interested_rate = 0.0
keywords = []
message.interest_value = 1
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
return 1, keywords
class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
@ -67,12 +42,16 @@ class HeartFCMessageReceiver:
userinfo = message.message_info.user_info
chat = message.chat_stream
# 2. 兴趣度计算与更新
_, keywords = await _calculate_interest(message)
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
await self.storage.store_message(message, chat)
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"

View File

@ -221,6 +221,8 @@ class ChatBot:
# 处理消息内容,生成纯文本
await message.process()
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
# 过滤检查
if _check_ban_words(
message.processed_plain_text,

View File

@ -130,6 +130,16 @@ class MessageRecv(Message):
self.key_words = []
self.key_words_lite = []
# 兼容适配器通过 additional_config 传入的 @ 标记
try:
msg_info_dict = message_dict.get("message_info", {})
add_cfg = msg_info_dict.get("additional_config") or {}
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
# 标记为被提及,提高后续回复优先级
self.is_mentioned = True # type: ignore
except Exception:
pass
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream

View File

@ -529,7 +529,7 @@ class DefaultReplyer:
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话你们正在交流中
这是上述中你和{sender}的对话摘要内容从上面的对话中截取便于你理解
{core_dialogue_prompt_str}
--------------------------------
"""

View File

@ -43,9 +43,12 @@ def replace_user_references(
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
# 检查是否是机器人自己(支持多平台)
if replace_bot_name:
if platform == "qq" and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_id # type: ignore
@ -92,6 +95,8 @@ def replace_user_references(
new_content += content[last_end:]
content = new_content
# Telegram 文本 @username 的显示映射交由适配器或平台层处理;此处不做硬编码替换
return content
@ -432,7 +437,10 @@ def _build_readable_messages_internal(
person_name = (
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
)
if replace_bot_name and user_id == global_config.bot.qq_account:
if replace_bot_name and (
(platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
):
person_name = f"{global_config.bot.nickname}(你)"
# 使用独立函数处理用户引用格式
@ -866,7 +874,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
if user_id == global_config.bot.qq_account:
if (platform == "qq" and user_id == global_config.bot.qq_account) or (
platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
):
# print("SELF11111111111111")
return "SELF"
try:

View File

@ -30,76 +30,146 @@ def is_english_letter(char: str) -> bool:
return "a" <= char.lower() <= "z"
def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
logger.debug(f"result: {result}")
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
"""解析 platforms 列表,返回平台到账号的映射
Args:
platforms: 格式为 ["platform:account"] 的列表 ["tg:123456789", "wx:wxid123"]
Returns:
字典键为平台名值为账号
"""
result = {}
for platform_entry in platforms:
if ":" in platform_entry:
platform_name, account = platform_entry.split(":", 1)
result[platform_name.strip()] = account.strip()
return result
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
"""根据当前平台获取对应的账号
Args:
platform: 当前消息的平台
platform_accounts: platforms 列表解析的平台账号映射
qq_account: QQ 账号兼容旧配置
Returns:
当前平台对应的账号
"""
if platform == "qq":
return qq_account
elif platform == "telegram":
# 优先使用 tg其次使用 telegram
return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
else:
# 其他平台直接使用平台名作为键
return platform_accounts.get(platform, "")
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.bot.nickname] + list(global_config.bot.alias_names)
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
nickname = str(global_config.bot.nickname or "")
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
keywords = [nickname] + alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
# 这部分怎么处理啊啊啊啊
# 我觉得可以给消息加一个 reply_probability_boost字段
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
):
# 1) 直接的 additional_config 标记
add_cfg = getattr(message.message_info, "additional_config", None) or {}
if isinstance(add_cfg, dict):
if add_cfg.get("at_bot") or add_cfg.get("is_mentioned"):
is_mentioned = True
# 当提供数值型 is_mentioned 时,当作概率提升
try:
if add_cfg.get("is_mentioned") not in (None, ""):
reply_probability = float(add_cfg.get("is_mentioned")) # type: ignore
except Exception:
pass
# 2) 已经在上游设置过的 message.is_mentioned
if getattr(message, "is_mentioned", False):
is_mentioned = True
# 3) 扫描分段:是否包含 mention_bot适配器插入
def _has_mention_bot(seg) -> bool:
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, is_at, reply_probability
except Exception as e:
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
if seg is None:
return False
if getattr(seg, "type", None) == "mention_bot":
return True
if getattr(seg, "type", None) == "seglist":
for s in getattr(seg, "data", []) or []:
if _has_mention_bot(s):
return True
return False
except Exception:
return False
for keyword in keywords:
if keyword in message.processed_plain_text:
is_mentioned = True
# 判断是否被@
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
if _has_mention_bot(getattr(message, "message_segment", None)):
is_at = True
is_mentioned = True
if is_at and global_config.chat.at_bot_inevitable_reply:
# 4) 统一的 @ 检测逻辑
if current_account and not is_at and not is_mentioned:
if platform == "qq":
# QQ 格式: @<name:qq_id>
if re.search(rf"@<(.+?):{re.escape(current_account)}>", text):
is_at = True
is_mentioned = True
else:
# 其他平台格式: @username 或 @account
if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE):
is_at = True
is_mentioned = True
# 5) 统一的回复检测逻辑
if not is_mentioned:
# 通用回复格式:包含 "(你)" 或 "(你)"
if re.search(r"\[回复 .*?\(你\)", text) or re.search(r"\[回复 .*?(你):", text):
is_mentioned = True
# ID 形式的回复检测
elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text):
is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
if not is_mentioned and keywords:
msg_content = text
# 去除各种 @ 与 回复标记,避免误判
msg_content = re.sub(r"@(.+?)(\d+)", "", msg_content)
msg_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", msg_content)
msg_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id|你)\)(.+?)\],说:", "", msg_content)
msg_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", msg_content)
for kw in keywords:
if kw and kw in msg_content:
is_mentioned = True
break
# 7) 概率设置
if is_at and getattr(global_config.chat, "at_bot_inevitable_reply", 1):
reply_probability = 1.0
logger.debug("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\)(.+?)\],说:", message.processed_plain_text
) or re.match(
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>(.+?)\],说:",
message.processed_plain_text,
):
is_mentioned = True
else:
# 判断内容中是否被提及
message_content = re.sub(r"@(.+?)(\d+)", "", message.processed_plain_text)
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
for keyword in keywords:
if keyword in message_content:
is_mentioned = True
if is_mentioned and global_config.chat.mentioned_bot_reply:
reply_probability = 1.0
logger.debug("被提及回复概率设置为100%")
elif is_mentioned and getattr(global_config.chat, "mentioned_bot_reply", 1):
reply_probability = max(reply_probability, 1.0)
logger.debug("被提及回复概率设置为100%")
return is_mentioned, is_at, reply_probability
@ -115,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
who_chat_in_group = []
for db_msg in recent_messages:
# user_info = UserInfo.from_dict(
# {
# "platform": msg_db_data["user_platform"],
# "user_id": msg_db_data["user_id"],
# "user_nickname": msg_db_data["user_nickname"],
# "user_cardname": msg_db_data.get("user_cardname", ""),
# }
# )
# if (
# (user_info.platform, user_info.user_id) != sender
# and user_info.user_id != global_config.bot.qq_account
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
# and len(who_chat_in_group) < 5
# ): # 排除重复排除消息发送者排除bot限制加载的关系数目
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
if (
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append(
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
)
return who_chat_in_group
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并
@ -410,42 +441,6 @@ def calculate_typing_time(
return total_time # 加上回车时间
def cosine_similarity(v1, v2):
"""计算余弦相似度"""
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
return Counter(words)
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
"""使用简单的余弦相似度计算文本相似度"""
# 将输入文本转换为词频向量
text_vector = text_to_vector(text)
# 计算每个主题的相似度
similarities = []
for topic in topics:
topic_vector = text_to_vector(topic)
# 获取所有唯一词
all_words = set(text_vector.keys()) | set(topic_vector.keys())
# 构建向量
v1 = [text_vector.get(word, 0) for word in all_words]
v2 = [topic_vector.get(word, 0) for word in all_words]
# 计算相似度
similarity = cosine_similarity(v1, v2)
similarities.append((topic, similarity))
# 按相似度降序排序并返回前k个
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
@ -523,47 +518,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars)
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
"""计算两个时间点之间的消息数量和文本总长度
Args:
start_time (float): 起始时间戳 (不包含)
end_time (float): 结束时间戳 (包含)
stream_id (str): 聊天流ID
Returns:
tuple[int, int]: (消息数量, 文本总长度)
"""
count = 0
total_length = 0
# 参数校验 (可选但推荐)
if start_time >= end_time:
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
return 0, 0
if not stream_id:
logger.error("stream_id 不能为空")
return 0, 0
# 使用message_repository中的count_messages和find_messages函数
# 构建查询条件
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
# 先获取消息数量
count = count_messages(filter_query)
# 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
return count, total_length
except Exception as e:
logger.error(f"计算消息数量时发生意外错误: {e}")
return 0, 0
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
@ -698,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data
return result
# def assign_message_ids_flexible(
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
# ) -> list:
# """
# 为消息列表中的每个消息分配唯一的简短随机ID增强版
# Args:
# messages: 消息列表
# prefix: ID前缀默认为"msg"
# id_length: ID的总长度不包括前缀默认为6
# use_timestamp: 是否在ID中包含时间戳默认为False
# Returns:
# 包含 {'id': str, 'message': any} 格式的字典列表
# """
# result = []
# used_ids = set()
# for i, message in enumerate(messages):
# # 生成唯一的ID
# while True:
# if use_timestamp:
# # 使用时间戳的后几位 + 随机字符
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
# remaining_length = id_length - 3
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
# else:
# # 使用索引 + 随机字符
# index_str = str(i + 1)
# remaining_length = max(1, id_length - len(index_str))
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{index_str}{random_chars}"
# if message_id not in used_ids:
# used_ids.add(message_id)
# break
# result.append({"id": message_id, "message": message})
# return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""

View File

@ -27,6 +27,9 @@ class BotConfig(ConfigBase):
nickname: str
"""昵称"""
platforms: list[str] = field(default_factory=lambda: [])
"""其他平台列表"""
alias_names: list[str] = field(default_factory=lambda: [])
"""别名列表"""

View File

@ -1,5 +1,5 @@
[inner]
version = "6.18.3"
version = "6.18.4"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@ -14,6 +14,9 @@ version = "6.18.3"
[bot]
platform = "qq"
qq_account = "1145141919810" # 麦麦的QQ账号
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名