feat:优化表达方式效果,现在考虑count

pull/1364/head
SengokuCola 2025-11-18 20:55:33 +08:00
parent e90513bf23
commit 3ed2ddeb91
2 changed files with 49 additions and 8 deletions

View File

@ -61,6 +61,37 @@ def format_create_date(timestamp: float) -> str:
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~3之间
count越高权重越高但最多为基础权重的3倍
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
return [1.0 for _ in counts]
weights = []
for count_value in counts:
# 线性映射到[1,3]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 2.0) # 1~3
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
@ -78,15 +109,24 @@ def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
if len(population) <= k:
return population.copy()
# 使用随机抽样
selected = []
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected

View File

@ -139,6 +139,7 @@ class ExpressionSelector:
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
}
for expr in style_query
]