更好的随机选择

pull/914/head
Bakadax 2025-05-02 06:34:33 +08:00
parent d62e9f0a87
commit 68cef9a725
1 changed files with 68 additions and 50 deletions

View File

@ -27,81 +27,56 @@ def select_nicknames_for_prompt(all_nicknames_info: Dict[str, List[Dict[str, int
按次数降序排序
"""
if not all_nicknames_info:
# 如果输入为空,直接返回空列表
return []
candidates = [] # 候选绰号列表,包含 (用户名, 绰号, 次数, 权重)
candidates = []
for user_name, nicknames in all_nicknames_info.items():
if nicknames:
for nickname_entry in nicknames:
# nickname_entry 应该是 {"绰号": 次数} 格式
if isinstance(nickname_entry, dict) and len(nickname_entry) == 1:
nickname, count = list(nickname_entry.items())[0]
# 确保次数是正整数
if isinstance(count, int) and count > 0:
# 添加平滑因子避免概率为0并让低频词也有机会
weight = count + global_config.NICKNAME_PROBABILITY_SMOOTHING
candidates.append((user_name, nickname, count, weight))
else:
# 日志:记录无效的绰号次数
logger.warning(f"用户 '{user_name}' 的绰号 '{nickname}' 次数无效: {count}。已跳过。")
else:
# 日志:记录无效的绰号条目格式
logger.warning(f"用户 '{user_name}' 的绰号条目格式无效: {nickname_entry}。已跳过。")
if not candidates:
# 如果没有有效的候选绰号,返回空列表
return []
# 计算总权重
total_weight = sum(c[3] for c in candidates)
# 确定需要选择的数量
num_to_select = min(global_config.MAX_NICKNAMES_IN_PROMPT, len(candidates))
if total_weight <= 0:
# 如果所有权重都无效或为0则按原始次数排序选择前 N 个
logger.warning("所有候选绰号的总权重为0或负数将按原始次数选择 Top N。")
candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
else:
# 计算归一化概率
probabilities = [c[3] / total_weight for c in candidates]
try:
# 调用新的辅助函数进行不重复加权抽样
selected_candidates_with_weight = weighted_sample_without_replacement(candidates, num_to_select)
# 使用概率分布进行加权随机选择(不重复)
num_to_select = min(global_config.MAX_NICKNAMES_IN_PROMPT, len(candidates))
try:
# 实现不重复加权抽样
selected_indices = set()
selected = []
attempts = 0
max_attempts = num_to_select * 5 # 设置最大尝试次数,防止无限循环
# 如果抽样结果数量不足(例如权重问题导致提前退出),可以考虑是否需要补充
if len(selected_candidates_with_weight) < num_to_select:
logger.debug(f"加权随机选择后数量不足 ({len(selected_candidates_with_weight)}/{num_to_select}),补充选择次数最多的。")
# 筛选出未被选中的候选
selected_ids = set((c[0], c[1]) for c in selected_candidates_with_weight) # 使用 (用户名, 绰号) 作为唯一标识
remaining_candidates = [c for c in candidates if (c[0], c[1]) not in selected_ids]
remaining_candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
needed = num_to_select - len(selected_candidates_with_weight)
selected_candidates_with_weight.extend(remaining_candidates[:needed])
while len(selected) < num_to_select and attempts < max_attempts:
# 每次只选一个
chosen_index = random.choices(range(len(candidates)), weights=probabilities, k=1)[0]
if chosen_index not in selected_indices:
selected_indices.add(chosen_index)
selected.append(candidates[chosen_index])
attempts += 1
except Exception as e:
# 日志:记录加权随机选择时发生的错误,并回退到简单选择
logger.error(f"绰号加权随机选择时出错: {e}。将回退到选择次数最多的 Top N。", exc_info=True)
# 出错时回退到选择次数最多的 N 个
candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
# 注意:这里需要选择包含权重的元组,或者调整后续处理
selected_candidates_with_weight = candidates[:num_to_select]
# 如果尝试多次后仍未选够,补充出现次数最多的
if len(selected) < num_to_select:
logger.debug(f"加权随机选择后数量不足 ({len(selected)}/{num_to_select}),补充选择次数最多的。")
remaining_candidates = [c for i, c in enumerate(candidates) if i not in selected_indices]
remaining_candidates.sort(key=lambda x: x[2], reverse=True) # 按原始次数排序
needed = num_to_select - len(selected)
selected.extend(remaining_candidates[:needed])
except Exception as e:
# 日志:记录加权随机选择时发生的错误,并回退到简单选择
logger.error(f"绰号加权随机选择时出错: {e}。将回退到选择次数最多的 Top N。", exc_info=True)
# 出错时回退到选择次数最多的 N 个
candidates.sort(key=lambda x: x[2], reverse=True)
selected = candidates[: global_config.MAX_NICKNAMES_IN_PROMPT]
# 格式化输出结果为 (用户名, 绰号, 次数),移除权重
result = [(user, nick, count) for user, nick, count, _weight in selected_candidates_with_weight]
# 格式化输出结果为 (用户名, 绰号, 次数)
result = [(user, nick, count) for user, nick, count, _weight in selected]
result.sort(key=lambda x: x[2], reverse=True) # 按次数降序
result.sort(key=lambda x: x[2], reverse=True) # 按次数降序
# 日志:记录最终选中的用于 Prompt 的绰号
logger.debug(f"为 Prompt 选择的绰号: {result}")
return result
@ -309,3 +284,46 @@ async def trigger_nickname_analysis_if_needed(
except Exception as e:
# 日志:记录触发分析过程中发生的任何其他错误
logger.error(f"{log_prefix} 触发绰号分析时出错: {e}", exc_info=True)
def weighted_sample_without_replacement(candidates: List[Tuple[str, str, int, float]], k: int) -> List[Tuple[str, str, int, float]]:
"""
执行不重复的加权随机抽样
Args:
candidates: 候选列表每个元素为 (用户名, 绰号, 次数, 权重)
k: 需要选择的数量
Returns:
List[Tuple[str, str, int, float]]: 选中的元素列表
"""
if k <= 0:
return []
if k >= len(candidates):
# 如果需要选择的数量大于或等于候选数量,直接返回所有候选
return candidates[:] # 返回副本以避免修改原始列表
pool = candidates[:] # 创建候选列表的副本进行操作
selected = []
# 注意:原评论代码中计算 total_weight 但未使用,这里也省略。
# random.choices 内部会处理权重的归一化。
for _ in range(min(k, len(pool))): # 确保迭代次数不超过池中剩余元素
if not pool: # 如果池已空,提前结束
break
weights = [c[3] for c in pool] # 获取当前池中所有元素的权重
# 检查权重是否有效
if sum(weights) <= 0:
# 如果所有剩余权重无效,随机选择一个(或根据需要采取其他策略)
logger.warning("加权抽样池中剩余权重总和为0或负数随机选择一个。")
chosen_index = random.randrange(len(pool))
chosen = pool.pop(chosen_index)
else:
# 使用 random.choices 进行加权抽样,选择 1 个
# random.choices 返回一个列表,所以取第一个元素 [0]
chosen = random.choices(pool, weights=weights, k=1)[0]
pool.remove(chosen) # 从池中移除选中的元素,实现不重复抽样
selected.append(chosen)
return selected