提取重复代码到工具里

pull/914/head
Bakadax 2025-05-02 00:43:52 +08:00
parent 635ead2b6a
commit 7c95166e0a
4 changed files with 173 additions and 209 deletions

View File

@ -1,8 +1,14 @@
# GroupNickname/nickname_utils.py
import random
from typing import List, Dict, Tuple
import time
from typing import List, Dict, Tuple, Optional
from src.common.logger_manager import get_logger
from src.config.config import global_config
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.chat_stream import ChatStream
from src.plugins.chat.message import MessageRecv
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from .nickname_processor import add_to_nickname_queue
logger = get_logger("nickname_utils")
@ -10,14 +16,6 @@ logger = get_logger("nickname_utils")
def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int]]]) -> List[Tuple[str, str, int]]:
"""
从给定的绰号信息中根据映射次数加权随机选择最多 N 个绰号
Args:
all_nicknames_info: 包含用户及其绰号信息的字典格式为
{ "用户名1": [{"绰号A": 次数}, {"绰号B": 次数}], ... }
Returns:
List[Tuple[str, str, int]]: 选中的绰号列表每个元素为 (用户名, 绰号, 次数)
按次数降序排序
"""
if not all_nicknames_info:
return []
@ -26,12 +24,9 @@ def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int
for user_name, nicknames in all_nicknames_info.items():
if nicknames:
for nickname_entry in nicknames:
# nickname_entry 应该是 {"绰号": 次数} 格式
if isinstance(nickname_entry, dict) and len(nickname_entry) == 1:
nickname, count = list(nickname_entry.items())[0]
# 确保次数是正整数
if isinstance(count, int) and count > 0:
# 添加平滑因子避免概率为0并让低频词也有机会
weight = count + global_config.NICKNAME_PROBABILITY_SMOOTHING
candidates.append((user_name, nickname, count, weight))
else:
@ -44,55 +39,39 @@ def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int
if not candidates:
return []
# 计算总权重
total_weight = sum(c[3] for c in candidates)
if total_weight <= 0:
# 如果所有权重都无效或为0则随机选择或按次数选择
candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
candidates.sort(key=lambda x: x[2], reverse=True)
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
else:
# 计算归一化概率
probabilities = [c[3] / total_weight for c in candidates]
# 使用概率分布进行加权随机选择(不重复)
num_to_select = min(global_config.MAX_NICKNAMES_IN_PROMPT, len(candidates))
try:
# random.choices 允许重复,我们需要不重复的选择
# 可以使用 numpy.random.choice 或手动实现不重复加权抽样
# 这里用一个简化的方法:多次 choices 然后去重,直到达到数量或无法再选
selected_indices = set()
selected = []
attempts = 0
max_attempts = num_to_select * 5 # 防止无限循环
max_attempts = num_to_select * 5
while len(selected) < num_to_select and attempts < max_attempts:
# 每次只选一个,避免一次选多个时概率分布变化导致的问题
chosen_index = random.choices(range(len(candidates)), weights=probabilities, k=1)[0]
if chosen_index not in selected_indices:
selected_indices.add(chosen_index)
selected.append(candidates[chosen_index])
attempts += 1
# 如果尝试多次后仍未选够,补充出现次数最多的
if len(selected) < num_to_select:
remaining_candidates = [c for i, c in enumerate(candidates) if i not in selected_indices]
remaining_candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
remaining_candidates.sort(key=lambda x: x[2], reverse=True)
needed = num_to_select - len(selected)
selected.extend(remaining_candidates[:needed])
except Exception as e:
logger.error(
f"Error during weighted random choice for nicknames: {e}. Falling back to top N.", exc_info=True
)
# 出错时回退到选择次数最多的 N 个
candidates.sort(key=lambda x: x[2], reverse=True)
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
# 格式化输出并按次数排序
result = [(user, nick, count) for user, nick, count, _weight in selected]
result.sort(key=lambda x: x[2], reverse=True) # 按次数降序
result.sort(key=lambda x: x[2], reverse=True)
logger.debug(f"Selected nicknames for prompt: {result}")
return result
@ -100,27 +79,154 @@ def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int
def format_nickname_prompt_injection(selected_nicknames: List[Tuple[str, str, int]]) -> str:
    """
    Format the selected nicknames as a text block for prompt injection.

    Args:
        selected_nicknames: Selected entries as (user_name, nickname, count)
            tuples. The count is not used for formatting; users and their
            nicknames are emitted in the order they are first seen.

    Returns:
        str: One header line followed by one bullet line per user, terminated
        by a newline. An empty string when ``selected_nicknames`` is empty.
    """
    if not selected_nicknames:
        return ""

    # Group nicknames per user; a plain dict keeps first-seen order.
    grouped_by_user: Dict[str, List[str]] = {}
    for user_name, nickname, _count in selected_nicknames:
        # Quote each nickname so neighbouring nicknames cannot run together.
        grouped_by_user.setdefault(user_name, []).append(f"“{nickname}”")

    # Exactly one header and one line per user. A non-empty input always
    # yields at least one user line, so no "header only" case can occur.
    prompt_lines = ["【群成员绰号信息】"]
    for user_name, nicknames in grouped_by_user.items():
        nicknames_str = "、".join(nicknames)
        prompt_lines.append(f"- {user_name}有时被称为：{nicknames_str}")

    return "\n".join(prompt_lines) + "\n"
async def get_nickname_injection_for_prompt(chat_stream: ChatStream, message_list_before_now: List[Dict]) -> str:
    """
    Build the nickname text block to inject into a prompt.

    User ids are taken from ``message_list_before_now`` when it is non-empty,
    otherwise from ``chat_stream.get_recent_speakers(limit=5)``. Their group
    nicknames are fetched through ``relationship_manager``, a subset is picked
    with ``select_nicknames_for_prompt`` and rendered with
    ``format_nickname_prompt_injection``.

    Args:
        chat_stream: Chat stream of the conversation; must carry ``group_info``
            for anything to be produced.
        message_list_before_now: Message dicts; each is expected to hold a
            ``"user_info"`` mapping with a ``"user_id"`` entry. Messages
            without it are skipped.

    Returns:
        str: The formatted nickname block, or ``""`` when the feature is
        disabled, the chat is not a group chat, no users or nicknames are
        found, or any error occurs (errors are logged, never raised).
    """
    if not (global_config.ENABLE_NICKNAME_MAPPING and chat_stream and chat_stream.group_info):
        return ""

    try:
        group_id = str(chat_stream.group_info.group_id)
        user_ids_in_context = set()

        if message_list_before_now:
            for msg in message_list_before_now:
                # Tolerate messages without user_info instead of letting one
                # malformed entry abort nickname injection for the whole prompt.
                sender_id = (msg.get("user_info") or {}).get("user_id")
                if sender_id:
                    user_ids_in_context.add(str(sender_id))
        else:
            for speaker in chat_stream.get_recent_speakers(limit=5):
                user_ids_in_context.add(str(speaker["user_id"]))

        if not user_ids_in_context:
            logger.warning(f"[{chat_stream.stream_id}] No messages or recent speakers found for nickname injection.")
            return ""

        all_nicknames_data = await relationship_manager.get_users_group_nicknames(
            chat_stream.platform, list(user_ids_in_context), group_id
        )
        if not all_nicknames_data:
            return ""

        selected_nicknames = select_nicknames_for_prompt(all_nicknames_data)
        nickname_injection_str = format_nickname_prompt_injection(selected_nicknames)
        if nickname_injection_str:
            logger.debug(f"[{chat_stream.stream_id}] Generated nickname info for prompt:\n{nickname_injection_str}")
        return nickname_injection_str
    except Exception as e:
        # Best effort: prompt building must continue even if nickname lookup fails.
        logger.error(f"[{chat_stream.stream_id}] Error getting or formatting nickname info for prompt: {e}", exc_info=True)
        return ""
# --- 新增:触发绰号分析的工具函数 ---
def _latest_sender_nicknames(history_messages: List[Dict]) -> Dict[str, Optional[str]]:
    """Map each sender id (as str) to the last non-empty ``user_nickname`` in list order, or None."""
    latest: Dict[str, Optional[str]] = {}
    for msg in history_messages:
        user_info = msg.get("user_info") or {}
        sender_id = user_info.get("user_id")
        if not sender_id:
            continue
        key = str(sender_id)
        nickname = user_info.get("user_nickname")
        if nickname:
            # Later messages overwrite earlier ones, so the last one wins.
            latest[key] = nickname
        else:
            latest.setdefault(key, None)
    return latest


async def trigger_nickname_analysis_if_needed(
    anchor_message: MessageRecv,
    bot_reply: List[str],
    chat_stream: Optional[ChatStream] = None,
    history_limit: int = 30,
):
    """
    Queue a nickname-analysis task when the feature is enabled and the chat is a group chat.

    Args:
        anchor_message: The message that triggered the reply; its
            ``chat_stream`` is used when ``chat_stream`` is not given.
        bot_reply: Reply segments produced by the bot; joined with spaces.
        chat_stream: Optional chat stream that takes precedence over the one
            on ``anchor_message``.
        history_limit: Number of most recent messages to fetch for analysis.

    Returns:
        None. All errors are logged and swallowed so that a failed analysis
        never affects the reply flow.
    """
    if not global_config.ENABLE_NICKNAME_MAPPING:
        return

    # getattr keeps a missing anchor_message from raising before the validity check.
    current_chat_stream = chat_stream or getattr(anchor_message, "chat_stream", None)

    if not current_chat_stream or not current_chat_stream.group_info:
        logger.debug(f"[{current_chat_stream.stream_id if current_chat_stream else 'Unknown'}] Skipping nickname analysis: Not a group chat or invalid chat stream.")
        return

    log_prefix = f"[{current_chat_stream.stream_id}]"

    try:
        # 1. Fetch and format the recent history.
        history_messages = get_raw_msg_before_timestamp_with_chat(
            chat_id=current_chat_stream.stream_id,
            timestamp=time.time(),
            limit=history_limit,
        )
        chat_history_str = await build_readable_messages(
            messages=history_messages,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
            truncate=False,
        )

        # 2. Bot reply as a single string (empty or missing list -> "").
        bot_reply_str = " ".join(bot_reply) if bot_reply else ""

        # 3. Group id and platform.
        group_id = str(current_chat_stream.group_info.group_id)
        platform = current_chat_stream.platform

        # 4. Map user id -> display name. One pass over the history yields both
        #    the set of senders and their latest nickname for the fallback.
        latest_nicknames = _latest_sender_nicknames(history_messages)
        user_name_map = {}
        if latest_nicknames:
            try:
                names_data = await relationship_manager.get_person_names_batch(platform, list(latest_nicknames))
            except Exception as e:
                logger.error(f"{log_prefix} Error getting person names batch: {e}", exc_info=True)
                names_data = {}
            names_data = names_data or {}

            for user_id, latest_nickname in latest_nicknames.items():
                if user_id in names_data:
                    user_name_map[user_id] = names_data[user_id]
                else:
                    # Fall back to the nickname seen in history, then to a placeholder.
                    user_name_map[user_id] = latest_nickname or f"未知({user_id})"

        # 5. Hand everything to the analysis queue.
        await add_to_nickname_queue(chat_history_str, bot_reply_str, platform, group_id, user_name_map)
        logger.debug(f"{log_prefix} Triggered nickname analysis for group {group_id}.")

    except Exception as e:
        logger.error(f"{log_prefix} Error triggering nickname analysis: {e}", exc_info=True)

View File

@ -1,38 +1,33 @@
import asyncio
import time
import traceback
import random # <--- 添加导入
import json # <--- 确保导入 json
import random
import json
from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine
from collections import deque
from src.plugins.chat.message import MessageRecv, BaseMessageInfo, MessageThinking, MessageSending
from src.plugins.chat.message import Seg # Local import needed after move
from src.plugins.chat.message import Seg
from src.plugins.chat.chat_stream import ChatStream
from src.plugins.chat.message import UserInfo
from src.plugins.chat.chat_stream import chat_manager
from src.common.logger_manager import get_logger
from src.plugins.models.utils_model import LLMRequest
from src.config.config import global_config
from src.plugins.chat.utils_image import image_path_to_base64 # Local import needed after move
from src.plugins.utils.timer_calculator import Timer # <--- Import Timer
from src.plugins.chat.utils_image import image_path_to_base64
from src.plugins.utils.timer_calculator import Timer
from src.plugins.emoji_system.emoji_manager import emoji_manager
from src.heart_flow.sub_mind import SubMind
from src.heart_flow.observation import Observation
from src.plugins.heartFC_chat.heartflow_prompt_builder import global_prompt_manager, prompt_builder
import contextlib
from src.plugins.utils.chat_message_builder import (
num_new_messages_since,
get_raw_msg_before_timestamp_with_chat,
build_readable_messages,
)
from src.plugins.utils.chat_message_builder import num_new_messages_since
from src.plugins.heartFC_chat.heartFC_Cycleinfo import CycleInfo
from .heartFC_sender import HeartFCSender
from src.plugins.chat.utils import process_llm_response
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from src.plugins.moods.moods import MoodManager
from src.individuality.individuality import Individuality
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.group_nickname.nickname_processor import add_to_nickname_queue # <--- 导入队列添加函数
from src.plugins.group_nickname.nickname_utils import trigger_nickname_analysis_if_needed
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
@ -586,10 +581,9 @@ class HeartFChatting:
send_emoji=emoji_query,
)
print("消息发送成功,准备进入绰号分析")
# --- [新增] 触发绰号分析 ---
# 在发送成功后(或至少尝试发送后)触发
await self._trigger_nickname_analysis(anchor_message, reply)
# --- 结束触发 ---
# 调用工具函数触发绰号分析
await trigger_nickname_analysis_if_needed(anchor_message, reply, self.chat_stream)
return True, thinking_id
@ -697,90 +691,6 @@ class HeartFChatting:
# 发生意外错误时,可以选择是否重置计数器,这里选择不重置
return False # 表示动作未成功
# 触发绰号分析的函数
async def _trigger_nickname_analysis(self, anchor_message: MessageRecv, reply: List[str]):
    """
    Queue a nickname-analysis task for the group chat of ``anchor_message``.

    Args:
        anchor_message: Anchor message; its ``chat_stream`` supplies the
            stream id, platform and group info.
        reply: Reply segments produced by the bot; joined with spaces.

    Returns:
        None. Does nothing when the feature is disabled or the anchor is not
        from a group chat. Errors are logged and swallowed.
    """
    if not global_config.ENABLE_NICKNAME_MAPPING:
        return

    if not anchor_message or not anchor_message.chat_stream or not anchor_message.chat_stream.group_info:
        logger.debug(f"{self.log_prefix} Skipping nickname analysis: Not a group chat or invalid anchor.")
        return

    try:
        chat_stream = anchor_message.chat_stream

        # 1. Fetch and format the recent history.
        history_limit = 30
        history_messages = get_raw_msg_before_timestamp_with_chat(
            chat_id=chat_stream.stream_id,
            timestamp=time.time(),
            limit=history_limit,
        )
        chat_history_str = await build_readable_messages(
            messages=history_messages,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
            truncate=False,
        )

        # 2. Bot reply as a single string.
        bot_reply_str = " ".join(reply) if reply else ""

        # 3. Group id and platform. The platform is bound here, outside the
        #    "any users?" branch, because the queue call below always needs it.
        group_id = str(chat_stream.group_info.group_id)
        platform = chat_stream.platform

        # 4. Collect sender ids together with their latest nickname. Both are
        #    read from the message's "user_info" mapping.
        latest_nicknames = {}
        for msg in history_messages:
            user_info = msg.get("user_info") or {}
            sender_id = user_info.get("user_id")
            if not sender_id:
                continue
            key = str(sender_id)
            nickname = user_info.get("user_nickname")
            if nickname:
                latest_nicknames[key] = nickname  # later messages win
            else:
                latest_nicknames.setdefault(key, None)

        user_name_map = {}
        if latest_nicknames:
            try:
                names_data = await relationship_manager.get_person_names_batch(platform, list(latest_nicknames))
            except Exception as e:
                logger.error(f"Error getting person names: {e}", exc_info=True)
                names_data = {}
            names_data = names_data or {}
            logger.debug(f"{self.log_prefix} names_data: {names_data}")

            for user_id, latest_nickname in latest_nicknames.items():
                if user_id in names_data:
                    user_name_map[user_id] = names_data[user_id]
                else:
                    user_name_map[user_id] = latest_nickname or f"未知({user_id})"

        # 5. Hand everything to the analysis queue.
        await add_to_nickname_queue(chat_history_str, bot_reply_str, platform, group_id, user_name_map)
        logger.debug(f"{self.log_prefix} Triggered nickname analysis for group {group_id}.")

    except Exception as e:
        logger.error(f"{self.log_prefix} Error triggering nickname analysis: {e}", exc_info=True)
async def _wait_for_new_message(self, observation, planner_start_db_time: float, log_prefix: str) -> bool:
"""
等待新消息 检测到关闭信号

View File

@ -1,4 +1,6 @@
import random
import time
from typing import Union, Optional, List, Dict, Any # 引入 List, Dict, Any
from ...config.config import global_config
from src.common.logger_manager import get_logger
from ...individuality.individuality import Individuality
@ -6,15 +8,13 @@ from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_embedding
import time
from typing import Union, Optional
from ...common.database import db
from ..chat.utils import get_recent_group_speaker
from ..moods.moods import MoodManager
from ..memory_system.Hippocampus import HippocampusManager
from ..schedule.schedule_generator import bot_schedule
from ..knowledge.knowledge_lib import qa_manager
from src.plugins.group_nickname.nickname_utils import select_nicknames_for_prompt, format_nickname_prompt_injection
from src.plugins.group_nickname.nickname_utils import get_nickname_injection_for_prompt
logger = get_logger("prompt")
@ -23,6 +23,7 @@ def init_prompt():
Prompt(
"""
{info_from_tools}
{nickname_info}
{chat_target}
{chat_talking_prompt}
现在你想要在群里发言或者回复\n
@ -131,6 +132,7 @@ JSON 结构如下,包含三个字段 "action", "reasoning", "emoji_query":
{relation_prompt}
{prompt_info}
{schedule_prompt}
{nickname_info}
{chat_target}
{chat_talking_prompt}
现在"{sender_name}"说的:{message_txt}引起了你的注意你想要在群里发言或者回复这条消息\n
@ -214,40 +216,13 @@ async def _build_prompt_focus(reason, current_mind_info, structured_info, chat_s
logger.debug("开始构建prompt")
# 注入绰号信息
nickname_injection_str = ""
if global_config.ENABLE_NICKNAME_MAPPING and chat_stream.group_info:
try:
group_id = str(chat_stream.group_info.group_id)
user_ids_in_context = set()
if message_list_before_now:
for msg in message_list_before_now:
sender_id = msg["user_info"].get("user_id")
if sender_id:
user_ids_in_context.add(str(sender_id))
else:
logger.warning("Variable 'message_list_before_now' not found for nickname injection in focus prompt.")
if user_ids_in_context:
platform = chat_stream.platform
# --- 调用批量获取群组绰号的方法 ---
all_nicknames_data = await relationship_manager.get_users_group_nicknames(
platform, list(user_ids_in_context), group_id
)
if all_nicknames_data:
selected_nicknames = select_nicknames_for_prompt(all_nicknames_data)
nickname_injection_str = format_nickname_prompt_injection(selected_nicknames)
if nickname_injection_str:
logger.debug(f"Injecting nickname info into focus prompt:\n{nickname_injection_str}")
except Exception as e:
logger.error(f"Error getting or formatting nickname info for focus prompt: {e}", exc_info=True)
logger.debug(f"-------------------nickname_injection_str_______________________\n{nickname_injection_str}\n\n")
# 调用新的工具函数获取绰号信息
nickname_injection_str = await get_nickname_injection_for_prompt(chat_stream, message_list_before_now)
prompt = await global_prompt_manager.format_prompt(
"heart_flow_prompt",
info_from_tools=structured_info_prompt,
nickname_info=nickname_injection_str,
chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group
else await global_prompt_manager.get_prompt_async("chat_target_private1"),
@ -299,7 +274,7 @@ class PromptBuilder:
)
return None
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> tuple[str, str]:
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str: # 返回值改为 str
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=2)
@ -430,38 +405,8 @@ class PromptBuilder:
else:
schedule_prompt = ""
# 注入绰号信息
nickname_injection_str = ""
if global_config.ENABLE_NICKNAME_MAPPING and chat_stream.group_info:
try:
group_id = str(chat_stream.group_info.group_id)
user_ids_in_context = set()
if message_list_before_now:
for msg in message_list_before_now:
sender_id = msg["user_info"].get("user_id")
if sender_id:
user_ids_in_context.add(str(sender_id))
else:
logger.warning(
"Variable 'message_list_before_now' not found for nickname injection in focus prompt."
)
if user_ids_in_context:
platform = chat_stream.platform
# --- 调用批量获取群组绰号的方法 ---
all_nicknames_data = await relationship_manager.get_users_group_nicknames(
platform, list(user_ids_in_context), group_id
)
if all_nicknames_data:
selected_nicknames = select_nicknames_for_prompt(all_nicknames_data)
nickname_injection_str = format_nickname_prompt_injection(selected_nicknames)
if nickname_injection_str:
logger.debug(f"Injecting nickname info into focus prompt:\n{nickname_injection_str}")
except Exception as e:
logger.error(f"Error getting or formatting nickname info for focus prompt: {e}", exc_info=True)
logger.debug(f"-------------------nickname_injection_str_______________________\n{nickname_injection_str}\n\n")
# 调用新的工具函数获取绰号信息
nickname_injection_str = await get_nickname_injection_for_prompt(chat_stream, message_list_before_now)
prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main",
@ -470,6 +415,7 @@ class PromptBuilder:
memory_prompt=memory_prompt,
prompt_info=prompt_info,
schedule_prompt=schedule_prompt,
nickname_info=nickname_injection_str, # <--- 注入绰号信息
chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group
else await global_prompt_manager.get_prompt_async("chat_target_private1"),

View File

@ -19,6 +19,7 @@ from src.plugins.chat.chat_stream import ChatStream, chat_manager
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from src.plugins.utils.timer_calculator import Timer
from src.plugins.group_nickname.nickname_utils import trigger_nickname_analysis_if_needed
logger = get_logger("chat")
@ -286,6 +287,7 @@ class NormalChat:
# 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况)
if first_bot_msg:
info_catcher.catch_after_response(timing_results["消息发送"], response_set, first_bot_msg)
await trigger_nickname_analysis_if_needed(message, response_set, self.chat_stream)
else:
logger.warning(f"[{self.stream_name}] 思考消息 {thinking_id} 在发送前丢失,无法记录 info_catcher")