better:优化了表达方式采样

pull/1299/head
SengokuCola 2025-10-14 12:36:23 +08:00
parent d5f17b1f89
commit cb500e069a
5 changed files with 128 additions and 77 deletions

View File

@ -112,26 +112,40 @@ class BetterFrequencyPlugin(BasePlugin):
# 配置节描述 # 配置节描述
config_section_descriptions = { config_section_descriptions = {
"plugin": "插件基本信息", "plugin": "插件基本信息",
"frequency": "频率控制配置" "frequency": "频率控制配置",
"features": "功能开关配置"
} }
# 配置Schema定义 # 配置Schema定义
config_schema: dict = { config_schema: dict = {
"plugin": { "plugin": {
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"), "name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"), "version": ConfigField(type=str, default="1.0.1", description="插件版本"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"), "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
}, },
"frequency": { "frequency": {
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"), "default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"), "max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"), "min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
},
"features": {
"enable_frequency_adjust_action": ConfigField(type=bool, default=False, description="是否启用频率调节动作FrequencyAdjustAction"),
"enable_commands": ConfigField(type=bool, default=True, description="是否启用命令功能(/chat命令"),
} }
} }
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [ components = []
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand), # 根据配置决定是否注册命令组件
(FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction), if self.config.get("features", {}).get("enable_commands", True):
] components.extend([
(SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
(ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
])
# 根据配置决定是否注册频率调节动作组件
if self.config.get("features", {}).get("enable_frequency_adjust_action", True):
components.append((FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction))
return components

View File

@ -0,0 +1,93 @@
import re
import difflib
import random
from datetime import datetime
from typing import Optional, List, Dict
from collections import defaultdict
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
# 移除@<...>格式的内容
content = re.sub(r'@<[^>]*>', '', content)
# 移除[picid:...]格式的图片ID
content = re.sub(r'\[picid:[^\]]*\]', '', content)
# 移除[表情包:...]格式的内容
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值范围0-1
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 使用随机抽样
selected = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
return selected

View File

@ -18,6 +18,7 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, calculate_similarity, format_create_date
from json_repair import repair_json from json_repair import repair_json
@ -26,24 +27,6 @@ from json_repair import repair_json
logger = get_logger("expressor") logger = get_logger("expressor")
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def init_prompt() -> None: def init_prompt() -> None:
learn_style_prompt = """ learn_style_prompt = """
{chat_str} {chat_str}
@ -457,7 +440,7 @@ class ExpressionLearner:
continue continue
prev_original_idx = bare_lines[pos - 1][0] prev_original_idx = bare_lines[pos - 1][0]
up_content = self._filter_message_content(random_msg[prev_original_idx].processed_plain_text or "") up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content: if not up_content:
# 上一句为空,跳过该表达 # 上一句为空,跳过该表达
continue continue
@ -499,30 +482,6 @@ class ExpressionLearner:
expressions.append((situation, style)) expressions.append((situation, style))
return expressions return expressions
def _filter_message_content(self, content: str) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
# 移除@<...>格式的内容
content = re.sub(r'@<[^>]*>', '', content)
# 移除[picid:...]格式的图片ID
content = re.sub(r'\[picid:[^\]]*\]', '', content)
# 移除[表情包:...]格式的内容
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
return content.strip()
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]: def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
""" """
为每条消息构建精简文本列表保留到原消息索引的映射 为每条消息构建精简文本列表保留到原消息索引的映射
@ -537,7 +496,7 @@ class ExpressionLearner:
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
content = msg.processed_plain_text or "" content = msg.processed_plain_text or ""
content = self._filter_message_content(content) content = filter_message_content(content)
# 即使content为空也要记录防止错位 # 即使content为空也要记录防止错位
bare_lines.append((idx, content)) bare_lines.append((idx, content))

View File

@ -2,6 +2,7 @@ import json
import time import time
import random import random
import hashlib import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json from json_repair import repair_json
@ -12,6 +13,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, weighted_sample
logger = get_logger("expression_selector") logger = get_logger("expression_selector")
@ -44,29 +46,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""随机抽样"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 使用随机抽样
selected = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
chosen_idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(chosen_idx))
return selected
class ExpressionSelector: class ExpressionSelector:
def __init__(self): def __init__(self):
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
@ -149,6 +128,9 @@ class ExpressionSelector:
List[Dict[str, Any]]: 预测的表达方式列表 List[Dict[str, Any]]: 预测的表达方式列表
""" """
try: try:
# 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message)
# 支持多chat_id合并预测 # 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
@ -160,7 +142,7 @@ class ExpressionSelector:
try: try:
# 使用 style_learner 预测最合适的风格 # 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style( best_style, scores = style_learner_manager.predict_style(
related_chat_id, target_message, top_k=total_num related_chat_id, filtered_target_message, top_k=total_num
) )
if best_style and scores: if best_style and scores:
@ -186,7 +168,7 @@ class ExpressionSelector:
"source_id": expr.chat_id, "source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0), "prediction_score": scores.get(best_style, 0.0),
"prediction_input": target_message "prediction_input": filtered_target_message
}) })
else: else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式") logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")

View File

@ -8,6 +8,8 @@ except Exception:
_HAS_JIEBA = False _HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+") _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# 匹配纯符号的正则表达式
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
def simple_en_tokenize(text: str) -> List[str]: def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower()) return _WORD_RE.findall(text.lower())
@ -25,4 +27,5 @@ class Tokenizer:
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()] toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
else: else:
toks = simple_en_tokenize(text) toks = simple_en_tokenize(text)
return [t for t in toks if t not in self.stopwords] # 过滤掉纯符号和停用词
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]