From 3ed2ddeb916e7e74a73396ccb458cd1aa81719d9 Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Tue, 18 Nov 2025 20:55:33 +0800
Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BC=98=E5=8C=96=E8=A1=A8?=
 =?UTF-8?q?=E8=BE=BE=E6=96=B9=E5=BC=8F=E6=95=88=E6=9E=9C=EF=BC=8C=E7=8E=B0?=
 =?UTF-8?q?=E5=9C=A8=E8=80=83=E8=99=91count?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/express/express_utils.py       | 56 +++++++++++++++++++++++++-----
 src/express/expression_selector.py |  1 +
 2 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/src/express/express_utils.py b/src/express/express_utils.py
index c27306d1..20095561 100644
--- a/src/express/express_utils.py
+++ b/src/express/express_utils.py
@@ -61,6 +61,37 @@ def format_create_date(timestamp: float) -> str:
         return "未知时间"
 
 
+def _compute_weights(population: List[Dict]) -> List[float]:
+    """
+    根据表达的count计算权重,范围限定在1~3之间。
+    count越高,权重越高,但最多为基础权重的3倍。
+    """
+    if not population:
+        return []
+
+    counts = []
+    for item in population:
+        count = item.get("count", 1)
+        try:
+            count_value = float(count)
+        except (TypeError, ValueError):
+            count_value = 1.0
+        counts.append(max(count_value, 0.0))
+
+    min_count = min(counts)
+    max_count = max(counts)
+
+    if max_count == min_count:
+        return [1.0 for _ in counts]
+
+    weights = []
+    for count_value in counts:
+        # 线性映射到[1,3]区间
+        normalized = (count_value - min_count) / (max_count - min_count)
+        weights.append(1.0 + normalized * 2.0)  # 1~3
+    return weights
+
+
 def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
     """
     随机抽样函数
@@ -78,15 +109,24 @@ def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
     if len(population) <= k:
         return population.copy()
 
-    # 使用随机抽样
-    selected = []
+    selected: List[Dict] = []
     population_copy = population.copy()
-    for _ in range(k):
-        if not population_copy:
-            break
-        # 随机选择一个元素
-        idx = random.randint(0, len(population_copy) - 1)
-        selected.append(population_copy.pop(idx))
+    for _ in range(min(k, len(population_copy))):
+        weights = _compute_weights(population_copy)
+        total_weight = sum(weights)
+        if total_weight <= 0:
+            # 回退到均匀随机
+            idx = random.randint(0, len(population_copy) - 1)
+            selected.append(population_copy.pop(idx))
+            continue
+
+        threshold = random.uniform(0, total_weight)
+        cumulative = 0.0
+        for idx, weight in enumerate(weights):
+            cumulative += weight
+            if threshold <= cumulative:
+                selected.append(population_copy.pop(idx))
+                break
 
     return selected
 
diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py
index 031cc714..66c49def 100644
--- a/src/express/expression_selector.py
+++ b/src/express/expression_selector.py
@@ -139,6 +139,7 @@ class ExpressionSelector:
                     "last_active_time": expr.last_active_time,
                     "source_id": expr.chat_id,
                     "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
+                    "count": expr.count if getattr(expr, "count", None) is not None else 1,
                 }
                 for expr in style_query
             ]