mirror of https://github.com/Mai-with-u/MaiBot.git
更好的随机选择
parent
d62e9f0a87
commit
68cef9a725
|
|
@ -27,81 +27,56 @@ def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int
|
|||
按次数降序排序。
|
||||
"""
|
||||
if not all_nicknames_info:
|
||||
# 如果输入为空,直接返回空列表
|
||||
return []
|
||||
|
||||
candidates = [] # 候选绰号列表,包含 (用户名, 绰号, 次数, 权重)
|
||||
candidates = []
|
||||
for user_name, nicknames in all_nicknames_info.items():
|
||||
if nicknames:
|
||||
for nickname_entry in nicknames:
|
||||
# nickname_entry 应该是 {"绰号": 次数} 格式
|
||||
if isinstance(nickname_entry, dict) and len(nickname_entry) == 1:
|
||||
nickname, count = list(nickname_entry.items())[0]
|
||||
# 确保次数是正整数
|
||||
if isinstance(count, int) and count > 0:
|
||||
# 添加平滑因子,避免概率为0,并让低频词也有机会
|
||||
weight = count + global_config.NICKNAME_PROBABILITY_SMOOTHING
|
||||
candidates.append((user_name, nickname, count, weight))
|
||||
else:
|
||||
# 日志:记录无效的绰号次数
|
||||
logger.warning(f"用户 '{user_name}' 的绰号 '{nickname}' 次数无效: {count}。已跳过。")
|
||||
else:
|
||||
# 日志:记录无效的绰号条目格式
|
||||
logger.warning(f"用户 '{user_name}' 的绰号条目格式无效: {nickname_entry}。已跳过。")
|
||||
|
||||
if not candidates:
|
||||
# 如果没有有效的候选绰号,返回空列表
|
||||
return []
|
||||
|
||||
# 计算总权重
|
||||
total_weight = sum(c[3] for c in candidates)
|
||||
# 确定需要选择的数量
|
||||
num_to_select = min(global_config.MAX_NICKNAMES_IN_PROMPT, len(candidates))
|
||||
|
||||
if total_weight <= 0:
|
||||
# 如果所有权重都无效或为0,则按原始次数排序选择前 N 个
|
||||
logger.warning("所有候选绰号的总权重为0或负数,将按原始次数选择 Top N。")
|
||||
candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
|
||||
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
|
||||
else:
|
||||
# 计算归一化概率
|
||||
probabilities = [c[3] / total_weight for c in candidates]
|
||||
try:
|
||||
# 调用新的辅助函数进行不重复加权抽样
|
||||
selected_candidates_with_weight = weighted_sample_without_replacement(candidates, num_to_select)
|
||||
|
||||
# 使用概率分布进行加权随机选择(不重复)
|
||||
num_to_select = min(global_config.MAX_NICKNAMES_IN_PROMPT, len(candidates))
|
||||
try:
|
||||
# 实现不重复加权抽样
|
||||
selected_indices = set()
|
||||
selected = []
|
||||
attempts = 0
|
||||
max_attempts = num_to_select * 5 # 设置最大尝试次数,防止无限循环
|
||||
# 如果抽样结果数量不足(例如权重问题导致提前退出),可以考虑是否需要补充
|
||||
if len(selected_candidates_with_weight) < num_to_select:
|
||||
logger.debug(f"加权随机选择后数量不足 ({len(selected_candidates_with_weight)}/{num_to_select}),补充选择次数最多的。")
|
||||
# 筛选出未被选中的候选
|
||||
selected_ids = set((c[0], c[1]) for c in selected_candidates_with_weight) # 使用 (用户名, 绰号) 作为唯一标识
|
||||
remaining_candidates = [c for c in candidates if (c[0], c[1]) not in selected_ids]
|
||||
remaining_candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
|
||||
needed = num_to_select - len(selected_candidates_with_weight)
|
||||
selected_candidates_with_weight.extend(remaining_candidates[:needed])
|
||||
|
||||
while len(selected) < num_to_select and attempts < max_attempts:
|
||||
# 每次只选一个
|
||||
chosen_index = random.choices(range(len(candidates)), weights=probabilities, k=1)[0]
|
||||
if chosen_index not in selected_indices:
|
||||
selected_indices.add(chosen_index)
|
||||
selected.append(candidates[chosen_index])
|
||||
attempts += 1
|
||||
except Exception as e:
|
||||
# 日志:记录加权随机选择时发生的错误,并回退到简单选择
|
||||
logger.error(f"绰号加权随机选择时出错: {e}。将回退到选择次数最多的 Top N。", exc_info=True)
|
||||
# 出错时回退到选择次数最多的 N 个
|
||||
candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
|
||||
# 注意:这里需要选择包含权重的元组,或者调整后续处理
|
||||
selected_candidates_with_weight = candidates[:num_to_select]
|
||||
|
||||
# 如果尝试多次后仍未选够,补充出现次数最多的
|
||||
if len(selected) < num_to_select:
|
||||
logger.debug(f"加权随机选择后数量不足 ({len(selected)}/{num_to_select}),补充选择次数最多的。")
|
||||
remaining_candidates = [c for i, c in enumerate(candidates) if i not in selected_indices]
|
||||
remaining_candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
|
||||
needed = num_to_select - len(selected)
|
||||
selected.extend(remaining_candidates[:needed])
|
||||
|
||||
except Exception as e:
|
||||
# 日志:记录加权随机选择时发生的错误,并回退到简单选择
|
||||
logger.error(f"绰号加权随机选择时出错: {e}。将回退到选择次数最多的 Top N。", exc_info=True)
|
||||
# 出错时回退到选择次数最多的 N 个
|
||||
candidates.sort(key=lambda x: x[2], reverse=True)
|
||||
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
|
||||
# 格式化输出结果为 (用户名, 绰号, 次数),移除权重
|
||||
result = [(user, nick, count) for user, nick, count, _weight in selected_candidates_with_weight]
|
||||
|
||||
# 格式化输出结果为 (用户名, 绰号, 次数)
|
||||
result = [(user, nick, count) for user, nick, count, _weight in selected]
|
||||
result.sort(key=lambda x: x[2], reverse=True) # 按次数降序
|
||||
result.sort(key=lambda x: x[2], reverse=True) # 按次数降序
|
||||
|
||||
# 日志:记录最终选中的用于 Prompt 的绰号
|
||||
logger.debug(f"为 Prompt 选择的绰号: {result}")
|
||||
return result
|
||||
|
||||
|
|
@ -309,3 +284,46 @@ async def trigger_nickname_analysis_if_needed(
|
|||
except Exception as e:
|
||||
# 日志:记录触发分析过程中发生的任何其他错误
|
||||
logger.error(f"{log_prefix} 触发绰号分析时出错: {e}", exc_info=True)
|
||||
|
||||
def weighted_sample_without_replacement(candidates: List[Tuple[str, str, int, float]], k: int) -> List[Tuple[str, str, int, float]]:
|
||||
"""
|
||||
执行不重复的加权随机抽样。
|
||||
|
||||
Args:
|
||||
candidates: 候选列表,每个元素为 (用户名, 绰号, 次数, 权重)。
|
||||
k: 需要选择的数量。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, int, float]]: 选中的元素列表。
|
||||
"""
|
||||
if k <= 0:
|
||||
return []
|
||||
if k >= len(candidates):
|
||||
# 如果需要选择的数量大于或等于候选数量,直接返回所有候选
|
||||
return candidates[:] # 返回副本以避免修改原始列表
|
||||
|
||||
pool = candidates[:] # 创建候选列表的副本进行操作
|
||||
selected = []
|
||||
# 注意:原评论代码中计算 total_weight 但未使用,这里也省略。
|
||||
# random.choices 内部会处理权重的归一化。
|
||||
|
||||
for _ in range(min(k, len(pool))): # 确保迭代次数不超过池中剩余元素
|
||||
if not pool: # 如果池已空,提前结束
|
||||
break
|
||||
|
||||
weights = [c[3] for c in pool] # 获取当前池中所有元素的权重
|
||||
# 检查权重是否有效
|
||||
if sum(weights) <= 0:
|
||||
# 如果所有剩余权重无效,随机选择一个(或根据需要采取其他策略)
|
||||
logger.warning("加权抽样池中剩余权重总和为0或负数,随机选择一个。")
|
||||
chosen_index = random.randrange(len(pool))
|
||||
chosen = pool.pop(chosen_index)
|
||||
else:
|
||||
# 使用 random.choices 进行加权抽样,选择 1 个
|
||||
# random.choices 返回一个列表,所以取第一个元素 [0]
|
||||
chosen = random.choices(pool, weights=weights, k=1)[0]
|
||||
pool.remove(chosen) # 从池中移除选中的元素,实现不重复抽样
|
||||
|
||||
selected.append(chosen)
|
||||
|
||||
return selected
|
||||
Loading…
Reference in New Issue