mirror of https://github.com/Mai-with-u/MaiBot.git
better:优化了表达方式采样
parent
d5f17b1f89
commit
cb500e069a
|
|
@ -112,26 +112,40 @@ class BetterFrequencyPlugin(BasePlugin):
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
config_section_descriptions = {
|
config_section_descriptions = {
|
||||||
"plugin": "插件基本信息",
|
"plugin": "插件基本信息",
|
||||||
"frequency": "频率控制配置"
|
"frequency": "频率控制配置",
|
||||||
|
"features": "功能开关配置"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 配置Schema定义
|
# 配置Schema定义
|
||||||
config_schema: dict = {
|
config_schema: dict = {
|
||||||
"plugin": {
|
"plugin": {
|
||||||
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
|
"name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
|
||||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
"version": ConfigField(type=str, default="1.0.1", description="插件版本"),
|
||||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||||
},
|
},
|
||||||
"frequency": {
|
"frequency": {
|
||||||
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
|
"default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
|
||||||
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
|
"max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
|
||||||
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
|
"min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
|
||||||
|
},
|
||||||
|
"features": {
|
||||||
|
"enable_frequency_adjust_action": ConfigField(type=bool, default=False, description="是否启用频率调节动作(FrequencyAdjustAction)"),
|
||||||
|
"enable_commands": ConfigField(type=bool, default=True, description="是否启用命令功能(/chat命令)"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
    """
    Build the list of components this plugin registers, honoring the feature switches.

    The ``features`` section of ``self.config`` controls what is registered:

    - ``enable_commands`` registers ``SetTalkFrequencyCommand`` and
      ``ShowFrequencyCommand``.
    - ``enable_frequency_adjust_action`` registers ``FrequencyAdjustAction``.

    When a switch (or the whole section) is missing, the fallback used here is
    the same default the plugin's ``config_schema`` declares for that switch
    (commands on, frequency-adjust action off).

    Returns:
        List[Tuple[ComponentInfo, Type]]: ``(component info, component class)`` pairs.
    """
    # Read the section once instead of re-fetching it for every switch.
    features = self.config.get("features", {})
    components = []

    # Command components are registered only when commands are enabled.
    if features.get("enable_commands", True):
        components.extend(
            [
                (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
                (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
            ]
        )

    # The schema default for this switch is False; falling back to True here
    # would silently enable the action whenever the key is absent.
    if features.get("enable_frequency_adjust_action", False):
        components.append((FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction))

    return components
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
import re
|
||||||
|
import difflib
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def filter_message_content(content: Optional[str]) -> str:
    """
    Strip chat markup (reply headers, mentions, image ids, stickers) from a message.

    Args:
        content: Raw message text; may be None or empty.

    Returns:
        str: The text with the markup removed and surrounding whitespace
            trimmed, or an empty string when ``content`` is None or empty.
    """
    if not content:
        return ""

    # Applied in this order, each pattern deletes one kind of markup.
    markup_patterns = (
        r'\[回复.*?\],说:\s*',  # reply header "[回复 ...]" including the trailing ",说:" part
        r'@<[^>]*>',  # mentions written as @<...>
        r'\[picid:[^\]]*\]',  # image ids written as [picid:...]
        r'\[表情包:[^\]]*\]',  # sticker markers written as [表情包:...]
    )

    cleaned = content
    for pattern in markup_patterns:
        cleaned = re.sub(pattern, '', cleaned)

    return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Measure how similar two texts are using ``difflib.SequenceMatcher``.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        float: The matcher's ``ratio()``, a value between 0 and 1.
    """
    matcher = difflib.SequenceMatcher(None, text1, text2)
    return matcher.ratio()
|
||||||
|
|
||||||
|
|
||||||
|
def format_create_date(timestamp: float) -> str:
    """
    Format a timestamp as a readable ``YYYY-MM-DD HH:MM:SS`` string.

    Args:
        timestamp: Timestamp passed to ``datetime.fromtimestamp``.

    Returns:
        str: The formatted date, or "未知时间" when the timestamp cannot be
            converted (missing, not a number, or out of the supported range).
    """
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError, OverflowError, OSError):
        # fromtimestamp raises OverflowError for values outside the platform
        # range (including infinity) and TypeError for None; both should
        # degrade to the placeholder instead of propagating to the caller.
        return "未知时间"
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Draw up to ``k`` distinct items from ``population`` uniformly at random.

    Despite the name, no weights are involved: every item is equally likely
    to be picked. The input list is never modified.

    Args:
        population: Items to draw from.
        k: Number of items wanted.

    Returns:
        List[Dict]: An empty list when ``population`` is empty or ``k <= 0``;
            a shallow copy of ``population`` (original order) when it holds
            at most ``k`` items; otherwise ``k`` items drawn without
            replacement.
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    # random.sample performs the same uniform draw without replacement as
    # repeatedly popping a random index from a copy, without the manual loop.
    return random.sample(population, k)
|
||||||
|
|
@ -18,6 +18,7 @@ from src.chat.utils.chat_message_builder import (
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.express.style_learner import style_learner_manager
|
from src.express.style_learner import style_learner_manager
|
||||||
|
from src.express.express_utils import filter_message_content, calculate_similarity, format_create_date
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -26,24 +27,6 @@ from json_repair import repair_json
|
||||||
logger = get_logger("expressor")
|
logger = get_logger("expressor")
|
||||||
|
|
||||||
|
|
||||||
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Return the similarity of two texts as a value between 0 and 1.

    The value is the ``ratio()`` of a ``difflib.SequenceMatcher`` built from
    the two texts with no junk filter.
    """
    comparison = difflib.SequenceMatcher(isjunk=None, a=text1, b=text2)
    similarity = comparison.ratio()
    return similarity
|
|
||||||
|
|
||||||
|
|
||||||
def format_create_date(timestamp: float) -> str:
    """
    Turn a timestamp into a readable ``YYYY-MM-DD HH:MM:SS`` date string.

    Returns "未知时间" when the timestamp cannot be converted: a missing or
    non-numeric value, or one outside the range ``datetime.fromtimestamp``
    supports.
    """
    try:
        moment = datetime.fromtimestamp(timestamp)
        return moment.strftime("%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError, OverflowError, OSError):
        # OverflowError (out-of-range or infinite timestamps) and TypeError
        # (None) were previously uncaught and bypassed the placeholder.
        return "未知时间"
|
|
||||||
|
|
||||||
|
|
||||||
def init_prompt() -> None:
|
def init_prompt() -> None:
|
||||||
learn_style_prompt = """
|
learn_style_prompt = """
|
||||||
{chat_str}
|
{chat_str}
|
||||||
|
|
@ -457,7 +440,7 @@ class ExpressionLearner:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prev_original_idx = bare_lines[pos - 1][0]
|
prev_original_idx = bare_lines[pos - 1][0]
|
||||||
up_content = self._filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||||
if not up_content:
|
if not up_content:
|
||||||
# 上一句为空,跳过该表达
|
# 上一句为空,跳过该表达
|
||||||
continue
|
continue
|
||||||
|
|
@ -499,30 +482,6 @@ class ExpressionLearner:
|
||||||
expressions.append((situation, style))
|
expressions.append((situation, style))
|
||||||
return expressions
|
return expressions
|
||||||
|
|
||||||
def _filter_message_content(self, content: str) -> str:
|
|
||||||
"""
|
|
||||||
过滤消息内容,移除回复、@、图片等格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 原始消息内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 过滤后的内容
|
|
||||||
"""
|
|
||||||
if not content:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
|
||||||
content = re.sub(r'\[回复.*?\],说:\s*', '', content)
|
|
||||||
# 移除@<...>格式的内容
|
|
||||||
content = re.sub(r'@<[^>]*>', '', content)
|
|
||||||
# 移除[picid:...]格式的图片ID
|
|
||||||
content = re.sub(r'\[picid:[^\]]*\]', '', content)
|
|
||||||
# 移除[表情包:...]格式的内容
|
|
||||||
content = re.sub(r'\[表情包:[^\]]*\]', '', content)
|
|
||||||
|
|
||||||
return content.strip()
|
|
||||||
|
|
||||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||||
"""
|
"""
|
||||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||||
|
|
@ -537,7 +496,7 @@ class ExpressionLearner:
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
content = msg.processed_plain_text or ""
|
content = msg.processed_plain_text or ""
|
||||||
content = self._filter_message_content(content)
|
content = filter_message_content(content)
|
||||||
# 即使content为空也要记录,防止错位
|
# 即使content为空也要记录,防止错位
|
||||||
bare_lines.append((idx, content))
|
bare_lines.append((idx, content))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import json
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import re
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
@ -12,6 +13,7 @@ from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.express.style_learner import style_learner_manager
|
from src.express.style_learner import style_learner_manager
|
||||||
|
from src.express.express_utils import filter_message_content, weighted_sample
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|
||||||
|
|
@ -44,29 +46,6 @@ def init_prompt():
|
||||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Pick up to ``k`` distinct items from ``population`` uniformly at random.

    No weighting is applied even though the name suggests it. The input list
    is left untouched.

    Args:
        population: Items to choose from.
        k: Number of items wanted.

    Returns:
        List[Dict]: ``[]`` for an empty population or ``k <= 0``; a shallow
            copy of the whole population when it has at most ``k`` items;
            otherwise ``k`` items chosen without replacement.
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    # Standard-library equivalent of popping k random indices from a copy.
    return random.sample(population, k)
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionSelector:
|
class ExpressionSelector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm_model = LLMRequest(
|
self.llm_model = LLMRequest(
|
||||||
|
|
@ -149,6 +128,9 @@ class ExpressionSelector:
|
||||||
List[Dict[str, Any]]: 预测的表达方式列表
|
List[Dict[str, Any]]: 预测的表达方式列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 过滤目标消息内容,移除回复、表情包等特殊格式
|
||||||
|
filtered_target_message = filter_message_content(target_message)
|
||||||
|
|
||||||
# 支持多chat_id合并预测
|
# 支持多chat_id合并预测
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
|
|
@ -160,7 +142,7 @@ class ExpressionSelector:
|
||||||
try:
|
try:
|
||||||
# 使用 style_learner 预测最合适的风格
|
# 使用 style_learner 预测最合适的风格
|
||||||
best_style, scores = style_learner_manager.predict_style(
|
best_style, scores = style_learner_manager.predict_style(
|
||||||
related_chat_id, target_message, top_k=total_num
|
related_chat_id, filtered_target_message, top_k=total_num
|
||||||
)
|
)
|
||||||
|
|
||||||
if best_style and scores:
|
if best_style and scores:
|
||||||
|
|
@ -186,7 +168,7 @@ class ExpressionSelector:
|
||||||
"source_id": expr.chat_id,
|
"source_id": expr.chat_id,
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
"prediction_score": scores.get(best_style, 0.0),
|
"prediction_score": scores.get(best_style, 0.0),
|
||||||
"prediction_input": target_message
|
"prediction_input": filtered_target_message
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ except Exception:
|
||||||
_HAS_JIEBA = False
|
_HAS_JIEBA = False
|
||||||
|
|
||||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
# Matches a token that consists only of symbols: no word characters and no
# characters in the \u4e00-\u9fff range.
_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')


def simple_en_tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and return every run of ASCII letters, digits or underscores."""
    lowered = text.lower()
    return _WORD_RE.findall(lowered)
|
||||||
|
|
@ -25,4 +27,5 @@ class Tokenizer:
|
||||||
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
|
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
|
||||||
else:
|
else:
|
||||||
toks = simple_en_tokenize(text)
|
toks = simple_en_tokenize(text)
|
||||||
return [t for t in toks if t not in self.stopwords]
|
# 过滤掉纯符号和停用词
|
||||||
|
return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]
|
||||||
Loading…
Reference in New Issue