mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
commit
a97d8b4e3d
File diff suppressed because it is too large
Load Diff
|
|
@ -249,6 +249,8 @@ class BrainPlanner:
|
|||
# 获取必要信息
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
|
||||
# 提及/被@ 的处理由心流或统一判定模块驱动;Planner 不再做硬编码强制回复
|
||||
|
||||
# 应用激活类型过滤
|
||||
filtered_actions = self._filter_actions_by_activation_type(available_actions, chat_content_block_short)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import re
|
||||
import traceback
|
||||
|
||||
from typing import Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
|
@ -17,31 +17,6 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
    """Compute the interest value of an incoming message.

    Picture-id and emoji messages get zero interest and are left untouched.
    Any other message receives a fixed interest of 1 and is annotated in place
    with the mention information reported by ``is_mentioned_bot_in_message``.

    Args:
        message: The message to evaluate. For non picture/emoji messages its
            ``interest_value``, ``is_mentioned``, ``is_at`` and
            ``reply_probability_boost`` attributes are overwritten.

    Returns:
        Tuple[float, list[str]]: ``(interest, keywords)``. The keyword list is
        always empty in this implementation.
    """
    if message.is_picid or message.is_emoji:
        return 0.0, []

    is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)

    message.interest_value = 1
    message.is_mentioned = is_mentioned
    message.is_at = is_at
    message.reply_probability_boost = reply_probability_boost

    # Return a float on both paths so the value matches the annotation
    # (1.0 == 1, so existing callers comparing the result are unaffected).
    return 1.0, []
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
|
|
@ -67,12 +42,16 @@ class HeartFCMessageReceiver:
|
|||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
_, keywords = await _calculate_interest(message)
|
||||
# 2. 计算at信息
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
|
|
|
|||
|
|
@ -221,6 +221,8 @@ class ChatBot:
|
|||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text,
|
||||
|
|
|
|||
|
|
@ -130,6 +130,16 @@ class MessageRecv(Message):
|
|||
self.key_words = []
|
||||
self.key_words_lite = []
|
||||
|
||||
# 兼容适配器通过 additional_config 传入的 @ 标记
|
||||
try:
|
||||
msg_info_dict = message_dict.get("message_info", {})
|
||||
add_cfg = msg_info_dict.get("additional_config") or {}
|
||||
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
|
||||
# 标记为被提及,提高后续回复优先级
|
||||
self.is_mentioned = True # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
|
|
|
|||
|
|
@ -529,7 +529,7 @@ class DefaultReplyer:
|
|||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""--------------------------------
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解:
|
||||
{core_dialogue_prompt_str}
|
||||
--------------------------------
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -43,9 +43,12 @@ def replace_user_references(
|
|||
if name_resolver is None:
|
||||
|
||||
def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
# 检查是否是机器人自己(支持多平台)
|
||||
if replace_bot_name:
|
||||
if platform == "qq" and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
if platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_id # type: ignore
|
||||
|
||||
|
|
@ -92,6 +95,8 @@ def replace_user_references(
|
|||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
# Telegram 文本 @username 的显示映射交由适配器或平台层处理;此处不做硬编码替换
|
||||
|
||||
return content
|
||||
|
||||
|
||||
|
|
@ -432,7 +437,10 @@ def _build_readable_messages_internal(
|
|||
person_name = (
|
||||
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
|
||||
)
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
if replace_bot_name and (
|
||||
(platform == global_config.bot.platform and user_id == global_config.bot.qq_account)
|
||||
or (platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", ""))
|
||||
):
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
|
||||
# 使用独立函数处理用户引用格式
|
||||
|
|
@ -866,7 +874,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
||||
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
|
||||
|
||||
if user_id == global_config.bot.qq_account:
|
||||
if (platform == "qq" and user_id == global_config.bot.qq_account) or (
|
||||
platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
|
||||
):
|
||||
# print("SELF11111111111111")
|
||||
return "SELF"
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -492,10 +492,15 @@ class StatisticOutputTask(AsyncTask):
|
|||
continue
|
||||
|
||||
# Update name_mapping
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
try:
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"更新 name_mapping 时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
|
|
@ -518,7 +523,21 @@ class StatisticOutputTask(AsyncTask):
|
|||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
# 修复 name_mapping 数据类型不匹配问题
|
||||
# JSON 中存储为列表,但代码期望为元组
|
||||
raw_name_mapping = last_stat["name_mapping"]
|
||||
self.name_mapping = {}
|
||||
for chat_id, value in raw_name_mapping.items():
|
||||
if isinstance(value, list) and len(value) == 2:
|
||||
# 将列表转换为元组
|
||||
self.name_mapping[chat_id] = (value[0], value[1])
|
||||
elif isinstance(value, tuple) and len(value) == 2:
|
||||
# 已经是元组,直接使用
|
||||
self.name_mapping[chat_id] = value
|
||||
else:
|
||||
# 数据格式不正确,跳过或使用默认值
|
||||
logger.warning(f"name_mapping 中 chat_id {chat_id} 的数据格式不正确: {value}")
|
||||
continue
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
|
||||
self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
|
||||
|
|
@ -571,8 +590,14 @@ class StatisticOutputTask(AsyncTask):
|
|||
# 更新上次完整统计数据的时间戳
|
||||
# 将所有defaultdict转换为普通dict以避免类型冲突
|
||||
clean_stat_data = self._convert_defaultdict_to_dict(stat["all_time"])
|
||||
|
||||
# 将 name_mapping 中的元组转换为列表,因为JSON不支持元组
|
||||
json_safe_name_mapping = {}
|
||||
for chat_id, (chat_name, timestamp) in self.name_mapping.items():
|
||||
json_safe_name_mapping[chat_id] = [chat_name, timestamp]
|
||||
|
||||
local_storage["last_full_statistics"] = {
|
||||
"name_mapping": self.name_mapping,
|
||||
"name_mapping": json_safe_name_mapping,
|
||||
"stat_data": clean_stat_data,
|
||||
"timestamp": now.timestamp(),
|
||||
}
|
||||
|
|
@ -651,10 +676,13 @@ class StatisticOutputTask(AsyncTask):
|
|||
if stats[TOTAL_MSG_CNT] <= 0:
|
||||
return ""
|
||||
output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
|
||||
output.extend(
|
||||
f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}"
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items())
|
||||
)
|
||||
for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
output.append(f"{chat_name[:32]:<32} {count:>10}")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"格式化聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
output.append(f"{'未知聊天':<32} {count:>10}")
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
|
|
@ -770,14 +798,16 @@ class StatisticOutputTask(AsyncTask):
|
|||
)
|
||||
|
||||
# 聊天消息统计
|
||||
chat_rows = "\n".join(
|
||||
[
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
]
|
||||
if stat_data[MSG_CNT_BY_CHAT]
|
||||
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
chat_rows = []
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()):
|
||||
try:
|
||||
chat_name = self.name_mapping.get(chat_id, ("未知聊天", 0))[0]
|
||||
chat_rows.append(f"<tr><td>{chat_name}</td><td>{count}</td></tr>")
|
||||
except (IndexError, TypeError) as e:
|
||||
logger.warning(f"生成HTML聊天统计时发生错误,chat_id: {chat_id}, 错误: {e}")
|
||||
chat_rows.append(f"<tr><td>未知聊天</td><td>{count}</td></tr>")
|
||||
|
||||
chat_rows_html = "\n".join(chat_rows) if chat_rows else "<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"
|
||||
# 生成HTML
|
||||
return f"""
|
||||
<div id=\"{div_id}\" class=\"tab-content\">
|
||||
|
|
@ -824,7 +854,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
<tr><th>联系人/群组名称</th><th>消息数量</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{chat_rows}
|
||||
{chat_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
|
@ -975,7 +1005,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
}}
|
||||
|
||||
// 聊天消息分布饼图
|
||||
const chatLabels = {[self.name_mapping[chat_id][0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
|
||||
const chatLabels = {[self.name_mapping.get(chat_id, ("未知聊天", 0))[0] for chat_id in sorted(stat_data[MSG_CNT_BY_CHAT].keys())] if stat_data[MSG_CNT_BY_CHAT] else []};
|
||||
if (chatLabels.length > 0) {{
|
||||
const chatData = {{
|
||||
labels: chatLabels,
|
||||
|
|
|
|||
|
|
@ -30,76 +30,146 @@ def is_english_letter(char: str) -> bool:
|
|||
return "a" <= char.lower() <= "z"
|
||||
|
||||
|
||||
def db_message_to_str(message_dict: dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||
try:
|
||||
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
|
||||
except Exception:
|
||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||
content = message_dict.get("processed_plain_text", "")
|
||||
result = f"[{time_str}] {name}: {content}\n"
|
||||
logger.debug(f"result: {result}")
|
||||
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
    """Build a platform -> account mapping from ``"platform:account"`` entries.

    Args:
        platforms: Entries such as ``["tg:123456789", "wx:wxid123"]``. Only the
            first colon separates the platform from the account; entries that
            contain no colon are ignored.

    Returns:
        A dict keyed by the whitespace-stripped platform name, with the
        whitespace-stripped account as value. If a platform occurs more than
        once, the last entry wins.
    """
    accounts = {}
    for entry in platforms:
        name, colon, account = entry.partition(":")
        if not colon:
            # No separator present: nothing to map for this entry.
            continue
        accounts[name.strip()] = account.strip()
    return accounts
|
||||
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
    """Return the bot account that belongs to the given platform.

    Args:
        platform: Platform of the current message.
        platform_accounts: Platform -> account mapping parsed from the
            ``platforms`` configuration list.
        qq_account: The QQ account, kept as a separate argument for
            compatibility with the older configuration layout.

    Returns:
        ``qq_account`` for ``"qq"``; for ``"telegram"`` the ``"tg"`` entry if it
        is non-empty, otherwise the ``"telegram"`` entry; for any other platform
        the entry stored under the platform name. Missing entries yield ``""``.
    """
    if platform == "qq":
        return qq_account

    if platform == "telegram":
        # The short key "tg" takes precedence over the long key "telegram".
        short_form = platform_accounts.get("tg", "")
        return short_form if short_form else platform_accounts.get("telegram", "")

    # Every other platform is looked up directly by its own name.
    return platform_accounts.get(platform, "")
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.bot.nickname] + list(global_config.bot.alias_names)
|
||||
"""检查消息是否提到了机器人(统一多平台实现)"""
|
||||
text = message.processed_plain_text or ""
|
||||
platform = getattr(message.message_info, "platform", "") or ""
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
|
||||
nickname = str(global_config.bot.nickname or "")
|
||||
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
|
||||
keywords = [nickname] + alias_names
|
||||
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
# 这部分怎么处理啊啊啊啊
|
||||
# 我觉得可以给消息加一个 reply_probability_boost字段
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
# 1) 直接的 additional_config 标记
|
||||
add_cfg = getattr(message.message_info, "additional_config", None) or {}
|
||||
if isinstance(add_cfg, dict):
|
||||
if add_cfg.get("at_bot") or add_cfg.get("is_mentioned"):
|
||||
is_mentioned = True
|
||||
# 当提供数值型 is_mentioned 时,当作概率提升
|
||||
try:
|
||||
if add_cfg.get("is_mentioned") not in (None, ""):
|
||||
reply_probability = float(add_cfg.get("is_mentioned")) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) 已经在上游设置过的 message.is_mentioned
|
||||
if getattr(message, "is_mentioned", False):
|
||||
is_mentioned = True
|
||||
|
||||
# 3) 扫描分段:是否包含 mention_bot(适配器插入)
|
||||
def _has_mention_bot(seg) -> bool:
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||
is_mentioned = True
|
||||
return is_mentioned, is_at, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
)
|
||||
if seg is None:
|
||||
return False
|
||||
if getattr(seg, "type", None) == "mention_bot":
|
||||
return True
|
||||
if getattr(seg, "type", None) == "seglist":
|
||||
for s in getattr(seg, "data", []) or []:
|
||||
if _has_mention_bot(s):
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in message.processed_plain_text:
|
||||
is_mentioned = True
|
||||
|
||||
# 判断是否被@
|
||||
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
|
||||
if _has_mention_bot(getattr(message, "message_segment", None)):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
if is_at and global_config.chat.at_bot_inevitable_reply:
|
||||
# 4) 统一的 @ 检测逻辑
|
||||
if current_account and not is_at and not is_mentioned:
|
||||
if platform == "qq":
|
||||
# QQ 格式: @<name:qq_id>
|
||||
if re.search(rf"@<(.+?):{re.escape(current_account)}>", text):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 其他平台格式: @username 或 @account
|
||||
if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
# 5) 统一的回复检测逻辑
|
||||
if not is_mentioned:
|
||||
# 通用回复格式:包含 "(你)" 或 "(你)"
|
||||
if re.search(r"\[回复 .*?\(你\):", text) or re.search(r"\[回复 .*?(你):", text):
|
||||
is_mentioned = True
|
||||
# ID 形式的回复检测
|
||||
elif current_account:
|
||||
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
|
||||
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
||||
if not is_mentioned and keywords:
|
||||
msg_content = text
|
||||
# 去除各种 @ 与 回复标记,避免误判
|
||||
msg_content = re.sub(r"@(.+?)((\d+))", "", msg_content)
|
||||
msg_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", msg_content)
|
||||
msg_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id|你)\):(.+?)\],说:", "", msg_content)
|
||||
msg_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>:(.+?)\],说:", "", msg_content)
|
||||
for kw in keywords:
|
||||
if kw and kw in msg_content:
|
||||
is_mentioned = True
|
||||
break
|
||||
|
||||
# 7) 概率设置
|
||||
if is_at and getattr(global_config.chat, "at_bot_inevitable_reply", 1):
|
||||
reply_probability = 1.0
|
||||
logger.debug("被@,回复概率设置为100%")
|
||||
else:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text
|
||||
) or re.match(
|
||||
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:",
|
||||
message.processed_plain_text,
|
||||
):
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 判断内容中是否被提及
|
||||
message_content = re.sub(r"@(.+?)((\d+))", "", message.processed_plain_text)
|
||||
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
|
||||
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\):(.+?)\],说:", "", message_content)
|
||||
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>:(.+?)\],说:", "", message_content)
|
||||
for keyword in keywords:
|
||||
if keyword in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.chat.mentioned_bot_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
elif is_mentioned and getattr(global_config.chat, "mentioned_bot_reply", 1):
|
||||
reply_probability = max(reply_probability, 1.0)
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
|
||||
return is_mentioned, is_at, reply_probability
|
||||
|
||||
|
||||
|
|
@ -115,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
|
|||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
    """Collect up to five distinct recent speakers of a chat stream.

    Args:
        chat_stream_id: Value matched against the ``chat_id`` field of stored messages.
        sender: ``(platform, user_id)`` pair of the current message sender; this
            user is left out of the result.
        limit: Maximum number of stored messages to inspect.

    Returns:
        A list of ``(platform, user_id, user_nickname)`` tuples without
        duplicates, excluding ``sender`` and the bot's own QQ account, capped at
        five entries. An empty list when no messages are found.
    """
    # Sorted by "time" with direction -1 (presumably newest first - confirm against find_messages).
    recent_messages = find_messages(
        message_filter={"chat_id": chat_stream_id},
        sort=[("time", -1)],
        limit=limit,
    )
    if not recent_messages:
        return []

    speakers = []
    for db_msg in recent_messages:
        info = db_msg.user_info
        if (info.platform, info.user_id) == sender:
            continue  # skip the sender of the current message
        if info.user_id == global_config.bot.qq_account:
            continue  # skip the bot itself
        entry = (info.platform, info.user_id, info.user_nickname)
        if entry in speakers or len(speakers) >= 5:
            continue  # skip duplicates and anything beyond the five-speaker cap
        speakers.append(entry)

    return speakers
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
|
|
@ -410,42 +441,6 @@ def calculate_typing_time(
|
|||
return total_time # 加上回车时间
|
||||
|
||||
|
||||
def cosine_similarity(v1, v2):
    """Return the cosine similarity of two vectors.

    Args:
        v1: First vector (any sequence accepted by ``np.dot``).
        v2: Second vector, same length as ``v1``.

    Returns:
        ``dot(v1, v2) / (|v1| * |v2|)``, or ``0`` when either vector has zero norm.
    """
    dot = np.dot(v1, v2)
    length_a = np.linalg.norm(v1)
    length_b = np.linalg.norm(v2)
    if length_a == 0 or length_b == 0:
        # Avoid dividing by zero for an all-zero vector.
        return 0
    return dot / (length_a * length_b)
|
||||
|
||||
|
||||
def text_to_vector(text):
    """Turn text into a term-frequency vector.

    The text is tokenized with ``jieba.lcut`` and the tokens are counted.

    Returns:
        A ``Counter`` mapping each token to its number of occurrences.
    """
    return Counter(jieba.lcut(text))
|
||||
|
||||
|
||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
    """Rank topics by term-frequency cosine similarity to ``text``.

    Args:
        text: The query text.
        topics: Candidate topic strings.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(topic, similarity)`` pairs, highest similarity first.
    """
    query_counts = text_to_vector(text)

    def _score(topic):
        # Align both frequency maps on the union of their words before comparing.
        topic_counts = text_to_vector(topic)
        vocabulary = set(query_counts.keys()) | set(topic_counts.keys())
        left = [query_counts.get(word, 0) for word in vocabulary]
        right = [topic_counts.get(word, 0) for word in vocabulary]
        return cosine_similarity(left, right)

    scored = [(topic, _score(topic)) for topic in topics]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
|
|
@ -523,47 +518,6 @@ def get_western_ratio(paragraph):
|
|||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
    """Count the messages of a chat stream inside a time window and sum their text length.

    Args:
        start_time (float): Window start timestamp (exclusive).
        end_time (float): Window end timestamp (inclusive).
        stream_id (str): Chat stream id, matched against ``chat_id``.

    Returns:
        tuple[int, int]: ``(message count, total length of processed_plain_text)``.
        ``(0, 0)`` is returned when the window is empty or inverted, when
        ``stream_id`` is empty, or when the lookup raises.
    """
    nothing = (0, 0)

    # Guard clauses: reject an empty/inverted window first, then a missing stream id.
    if start_time >= end_time:
        return nothing
    if not stream_id:
        logger.error("stream_id 不能为空")
        return nothing

    time_window = {"$gt": start_time, "$lte": end_time}
    query = {"chat_id": stream_id, "time": time_window}

    try:
        message_count = count_messages(query)
        # The messages themselves are fetched as well, to measure their text.
        text_length = 0
        for msg in find_messages(message_filter=query):
            text_length += len(msg.processed_plain_text or "")
    except Exception as e:
        logger.error(f"计算消息数量时发生意外错误: {e}")
        return 0, 0

    return message_count, text_length
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
|
|
@ -698,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data
|
|||
return result
|
||||
|
||||
|
||||
# def assign_message_ids_flexible(
|
||||
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
|
||||
# ) -> list:
|
||||
# """
|
||||
# 为消息列表中的每个消息分配唯一的简短随机ID(增强版)
|
||||
|
||||
# Args:
|
||||
# messages: 消息列表
|
||||
# prefix: ID前缀,默认为"msg"
|
||||
# id_length: ID的总长度(不包括前缀),默认为6
|
||||
# use_timestamp: 是否在ID中包含时间戳,默认为False
|
||||
|
||||
# Returns:
|
||||
# 包含 {'id': str, 'message': any} 格式的字典列表
|
||||
# """
|
||||
# result = []
|
||||
# used_ids = set()
|
||||
|
||||
# for i, message in enumerate(messages):
|
||||
# # 生成唯一的ID
|
||||
# while True:
|
||||
# if use_timestamp:
|
||||
# # 使用时间戳的后几位 + 随机字符
|
||||
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
|
||||
# remaining_length = id_length - 3
|
||||
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
|
||||
# else:
|
||||
# # 使用索引 + 随机字符
|
||||
# index_str = str(i + 1)
|
||||
# remaining_length = max(1, id_length - len(index_str))
|
||||
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
# message_id = f"{prefix}{index_str}{random_chars}"
|
||||
|
||||
# if message_id not in used_ids:
|
||||
# used_ids.add(message_id)
|
||||
# break
|
||||
|
||||
# result.append({"id": message_id, "message": message})
|
||||
|
||||
# return result
|
||||
|
||||
|
||||
# 使用示例:
|
||||
# messages = ["Hello", "World", "Test message"]
|
||||
#
|
||||
# # 基础版本
|
||||
# result1 = assign_message_ids(messages)
|
||||
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 自定义前缀和长度
|
||||
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
|
||||
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 使用时间戳
|
||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
|
||||
|
||||
def parse_keywords_string(keywords_input) -> list[str]:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ class BotConfig(ConfigBase):
|
|||
|
||||
nickname: str
|
||||
"""昵称"""
|
||||
|
||||
platforms: list[str] = field(default_factory=lambda: [])
|
||||
"""其他平台列表"""
|
||||
|
||||
alias_names: list[str] = field(default_factory=lambda: [])
|
||||
"""别名列表"""
|
||||
|
|
|
|||
|
|
@ -50,15 +50,14 @@ class QuestionMaker:
|
|||
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
||||
|
||||
选择规则:
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.05)。
|
||||
- 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。
|
||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.01)。
|
||||
- 若不存在,返回 None。
|
||||
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
||||
"""
|
||||
conflicts = await self.get_un_answered_conflict()
|
||||
if not conflicts:
|
||||
return None
|
||||
|
||||
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
|
||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||
if conflicts_with_zero:
|
||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
||||
|
|
@ -71,12 +70,14 @@ class QuestionMaker:
|
|||
# 按权重随机选择
|
||||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||
|
||||
# 选中后,自增 raise_time 并保存
|
||||
# 选中后,自增 raise_time 并保存
|
||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||
chosen_conflict.save()
|
||||
|
||||
|
||||
return chosen_conflict
|
||||
return chosen_conflict
|
||||
else:
|
||||
# 如果没有 raise_time == 0 的冲突,返回 None
|
||||
return None
|
||||
|
||||
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""生成一条用于询问用户的冲突问题与上下文。
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[inner]
|
||||
version = "6.18.3"
|
||||
version = "6.18.4"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
|
|
@ -14,6 +14,9 @@ version = "6.18.3"
|
|||
[bot]
|
||||
platform = "qq"
|
||||
qq_account = "1145141919810" # 麦麦的QQ账号
|
||||
|
||||
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
|
||||
|
||||
nickname = "麦麦" # 麦麦的昵称
|
||||
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue