diff --git a/plugins/ChatFrequency/plugin.py b/plugins/ChatFrequency/plugin.py
index 2c50cb5e..e30b43a7 100644
--- a/plugins/ChatFrequency/plugin.py
+++ b/plugins/ChatFrequency/plugin.py
@@ -112,26 +112,40 @@ class BetterFrequencyPlugin(BasePlugin):
     # 配置节描述
     config_section_descriptions = {
         "plugin": "插件基本信息",
-        "frequency": "频率控制配置"
+        "frequency": "频率控制配置",
+        "features": "功能开关配置"
     }
 
     # 配置Schema定义
     config_schema: dict = {
         "plugin": {
             "name": ConfigField(type=str, default="better_frequency_plugin", description="插件名称"),
-            "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
+            "version": ConfigField(type=str, default="1.0.1", description="插件版本"),
             "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
         },
         "frequency": {
             "default_talk_adjust": ConfigField(type=float, default=1.0, description="默认talk_frequency调整值"),
             "max_adjust_value": ConfigField(type=float, default=1.0, description="最大调整值"),
             "min_adjust_value": ConfigField(type=float, default=0.0, description="最小调整值"),
+        },
+        "features": {
+            "enable_frequency_adjust_action": ConfigField(type=bool, default=False, description="是否启用频率调节动作(FrequencyAdjustAction)"),
+            "enable_commands": ConfigField(type=bool, default=True, description="是否启用命令功能(/chat命令)"),
         }
     }
 
     def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
-        return [
-            (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
-            (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
-            (FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction),
-        ]
+        components = []
+
+        # Register the command components only when enabled (schema default: True)
+        if self.config.get("features", {}).get("enable_commands", True):
+            components.extend([
+                (SetTalkFrequencyCommand.get_command_info(), SetTalkFrequencyCommand),
+                (ShowFrequencyCommand.get_command_info(), ShowFrequencyCommand),
+            ])
+
+        # Register the frequency-adjust action only when enabled; the fallback mirrors the schema default (False)
+        if self.config.get("features", {}).get("enable_frequency_adjust_action", False):
+            components.append((FrequencyAdjustAction.get_action_info(), FrequencyAdjustAction))
+
+        return components
diff --git a/src/express/express_utils.py b/src/express/express_utils.py
new file mode 100644
index 00000000..bf065495
--- /dev/null
+++ b/src/express/express_utils.py
@@ -0,0 +1,93 @@
+import re
+import difflib
+import random
+from datetime import datetime
+from typing import Optional, List, Dict
+from collections import defaultdict
+
+
+def filter_message_content(content: Optional[str]) -> str:
+    """
+    Strip reply prefixes, @-mentions, picture IDs and sticker tags from a message.
+
+    Args:
+        content: Raw message text; None or an empty string yields "".
+
+    Returns:
+        str: The filtered text with leading/trailing whitespace removed.
+    """
+    if not content:
+        return ""
+
+    # Drop the span from "[回复" through the closing "]" plus the ",说:" marker that follows
+    content = re.sub(r'\[回复.*?\],说:\s*', '', content)
+    # Drop @<...> mentions
+    content = re.sub(r'@<[^>]*>', '', content)
+    # Drop [picid:...] picture IDs
+    content = re.sub(r'\[picid:[^\]]*\]', '', content)
+    # Drop [表情包:...] sticker tags
+    content = re.sub(r'\[表情包:[^\]]*\]', '', content)
+
+    return content.strip()
+
+
+def calculate_similarity(text1: str, text2: str) -> float:
+    """
+    Return the similarity of two texts as a value between 0 and 1.
+    Computed with difflib.SequenceMatcher(...).ratio().
+
+    Args:
+        text1: First text.
+        text2: Second text.
+
+    Returns:
+        float: Similarity ratio in the range 0-1.
+    """
+    return difflib.SequenceMatcher(None, text1, text2).ratio()
+
+
+def format_create_date(timestamp: float) -> str:
+    """
+    Format a timestamp as a "%Y-%m-%d %H:%M:%S" local-time string.
+
+    Args:
+        timestamp: Timestamp passed to datetime.fromtimestamp.
+
+    Returns:
+        str: The formatted date, or "未知时间" if the timestamp cannot be converted.
+    """
+    try:
+        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
+    except (ValueError, OSError, OverflowError):
+        return "未知时间"
+
+
+def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
+    """
+    Draw up to k items at random without replacement (uniform; no weights are used).
+
+    Args:
+        population: Items to sample from; the list itself is not modified.
+        k: Number of items to draw.
+
+    Returns:
+        List[Dict]: The drawn items; a copy of the whole population when len(population) <= k.
+    """
+    if not population or k <= 0:
+        return []
+
+    if len(population) <= k:
+        return population.copy()
+
+    # Sample from a copy so the caller's list stays intact
+    selected = []
+    population_copy = population.copy()
+
+    for _ in range(k):
+        if not population_copy:
+            break
+        # Pick one remaining element at random and remove it from the pool
+        idx = random.randint(0, len(population_copy) - 1)
+        selected.append(population_copy.pop(idx))
+
+    return selected
diff --git a/src/express/expression_learner.py b/src/express/expression_learner.py
index e0bc6d71..1fb37324 100644
--- a/src/express/expression_learner.py
+++ b/src/express/expression_learner.py
@@ -18,6 +18,7 @@ from src.chat.utils.chat_message_builder import (
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.express.style_learner import style_learner_manager
+from src.express.express_utils import filter_message_content, calculate_similarity, format_create_date
 from json_repair import repair_json
 
 
@@ -26,24 +27,6 @@ from json_repair import repair_json
 logger = get_logger("expressor")
 
 
-def calculate_similarity(text1: str, text2: str) -> float:
-    """
-    计算两个文本的相似度,返回0-1之间的值
-    使用SequenceMatcher计算相似度
-    """
-    return difflib.SequenceMatcher(None, text1, text2).ratio()
-
-
-def format_create_date(timestamp: float) -> str:
-    """
-    将时间戳格式化为可读的日期字符串
-    """
-    try:
-        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
-    except (ValueError, OSError):
-        return "未知时间"
-
-
 def init_prompt() -> None:
     learn_style_prompt = """
 {chat_str}
@@ -457,7 +440,7 @@ class ExpressionLearner:
                     continue
 
                 prev_original_idx = bare_lines[pos - 1][0]
-                up_content = self._filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
+                up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
                 if not up_content:
                     # 上一句为空,跳过该表达
                     continue
@@ -499,30 +482,6 @@ class ExpressionLearner:
             expressions.append((situation, style))
         return expressions
 
-    def _filter_message_content(self, content: str) -> str:
-        """
-        过滤消息内容,移除回复、@、图片等格式
-
-        Args:
-            content: 原始消息内容
-
-        Returns:
-            str: 过滤后的内容
-        """
-        if not content:
-            return ""
-
-        # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
-        content = re.sub(r'\[回复.*?\],说:\s*', '', content)
-        # 移除@<...>格式的内容
-        content = re.sub(r'@<[^>]*>', '', content)
-        # 移除[picid:...]格式的图片ID
-        content = re.sub(r'\[picid:[^\]]*\]', '', content)
-        # 移除[表情包:...]格式的内容
-        content = re.sub(r'\[表情包:[^\]]*\]', '', content)
-
-        return content.strip()
-
     def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
         """
         为每条消息构建精简文本列表,保留到原消息索引的映射
@@ -537,7 +496,7 @@ class ExpressionLearner:
 
         for idx, msg in enumerate(messages):
             content = msg.processed_plain_text or ""
-            content = self._filter_message_content(content)
+            content = filter_message_content(content)
             # 即使content为空也要记录,防止错位
             bare_lines.append((idx, content))
 
diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py
index 9ebef43c..41f8c57e 100644
--- a/src/express/expression_selector.py
+++ b/src/express/expression_selector.py
@@ -2,6 +2,7 @@ import json
 import time
 import random
 import hashlib
+import re
 
 from typing import List, Dict, Optional, Any, Tuple
 from json_repair import repair_json
@@ -12,6 +13,7 @@ from src.common.logger import get_logger
 from src.common.database.database_model import Expression
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.express.style_learner import style_learner_manager
+from src.express.express_utils import filter_message_content, weighted_sample
 
 logger = get_logger("expression_selector")
 
@@ -44,29 +46,6 @@ def init_prompt():
     Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
 
 
-def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
-    """随机抽样"""
-    if not population or k <= 0:
-        return []
-
-    if len(population) <= k:
-        return population.copy()
-
-    # 使用随机抽样
-    selected = []
-    population_copy = population.copy()
-
-    for _ in range(k):
-        if not population_copy:
-            break
-
-        # 随机选择一个元素
-        chosen_idx = random.randint(0, len(population_copy) - 1)
-        selected.append(population_copy.pop(chosen_idx))
-
-    return selected
-
-
 class ExpressionSelector:
     def __init__(self):
         self.llm_model = LLMRequest(
@@ -149,6 +128,9 @@ class ExpressionSelector:
             List[Dict[str, Any]]: 预测的表达方式列表
         """
         try:
+            # Filter the target message first: strip reply prefixes, sticker tags and similar markup
+            filtered_target_message = filter_message_content(target_message)
+
             # 支持多chat_id合并预测
             related_chat_ids = self.get_related_chat_ids(chat_id)
 
@@ -160,7 +142,7 @@ class ExpressionSelector:
                 try:
                     # 使用 style_learner 预测最合适的风格
                     best_style, scores = style_learner_manager.predict_style(
-                        related_chat_id, target_message, top_k=total_num
+                        related_chat_id, filtered_target_message, top_k=total_num
                     )
 
                     if best_style and scores:
@@ -186,7 +168,7 @@ class ExpressionSelector:
                                 "source_id": expr.chat_id,
                                 "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                                 "prediction_score": scores.get(best_style, 0.0),
-                                "prediction_input": target_message
+                                "prediction_input": filtered_target_message
                             })
                         else:
                             logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
diff --git a/src/express/expressor_model/tokenizer.py b/src/express/expressor_model/tokenizer.py
index 709e6a54..5fd915ae 100644
--- a/src/express/expressor_model/tokenizer.py
+++ b/src/express/expressor_model/tokenizer.py
@@ -8,6 +8,8 @@ except Exception:
     _HAS_JIEBA = False
 
 _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
+# Matches tokens that contain no word characters at all, i.e. symbol-only tokens
+_SYMBOL_RE = re.compile(r'^[^\w\u4e00-\u9fff]+$')
 
 def simple_en_tokenize(text: str) -> List[str]:
     return _WORD_RE.findall(text.lower())
@@ -25,4 +27,5 @@ class Tokenizer:
             toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
         else:
             toks = simple_en_tokenize(text)
-        return [t for t in toks if t not in self.stopwords]
\ No newline at end of file
+        # Drop stop words and symbol-only tokens
+        return [t for t in toks if t not in self.stopwords and not _SYMBOL_RE.match(t)]
\ No newline at end of file