diff --git a/bot.py b/bot.py index 02dee81c..3f3a4e9c 100644 --- a/bot.py +++ b/bot.py @@ -44,7 +44,11 @@ logger = get_logger("main") # 定义重启退出码 RESTART_EXIT_CODE = 42 - +print("-----------------------------------------") +print("\n\n\n\n\n") +print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能无法正常工作!") +print("\n\n\n\n\n") +print("-----------------------------------------") def run_runner_process(): """ diff --git a/changelogs/mai_next_todo.md b/changelogs/mai_next_todo.md index 573056a5..32243caf 100644 --- a/changelogs/mai_next_todo.md +++ b/changelogs/mai_next_todo.md @@ -26,11 +26,11 @@ version 0.3.0 - 2026-01-11 - [x] 其他上面两个依赖的函数已经合并到这两个函数中 ### ExpressionConfig -- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsGroupGenerator`中 +- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsExpression`中 - [x] get_expression_config_for_chat - [x] 其他上面依赖的函数已经合并到这个函数中 ### ModelConfig -- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsModelConfig`中 - - [ ] get_model_info - - [ ] get_provider \ No newline at end of file +- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsLLMUtils`中 + - [x] get_model_info + - [x] get_provider \ No newline at end of file diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 457e5610..4bcd0dcd 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -11,6 +11,7 @@ from src.common.database.database_model import Expression from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.bw_learner.learner_utils import weighted_sample from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") @@ -59,7 +60,7 @@ class ExpressionSelector: bool: 是否允许使用表达 """ try: - use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id) return use_expression except Exception as e: logger.error(f"检查表达使用权限失败: {e}") diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py index fc570909..8e15ab43 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder.py @@ -5,6 +5,7 @@ from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive +from src.chat.utils.common_utils import TempMethodsExpression from src.bw_learner.expression_learner import expression_learner_manager from src.bw_learner.jargon_miner import miner_manager @@ -38,7 +39,7 @@ class MessageRecorder: """初始化提取参数""" # 获取 expression 配置 _, self.enable_expression_learning, self.enable_jargon_learning = ( - global_config.expression.get_expression_config_for_chat(self.chat_id) + TempMethodsExpression.get_expression_config_for_chat(self.chat_id) ) self.min_messages_for_extraction = 30 self.min_extraction_interval = 60 diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 029f2be3..360ab088 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -37,6 +37,7 @@ from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon +from src.chat.utils.common_utils import TempMethodsExpression init_lpmm_prompt() init_replyer_prompt() @@ -358,7 +359,7 @@ class DefaultReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) if not use_expression: return "", [] style_habits = [] @@ -1254,66 +1255,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list: pool.pop(idx) break return selected - - -class TempMethodsGroupGenerator: - """用于临时存放一些方法的类""" - - @staticmethod - def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]: - """ - 根据聊天流ID获取表达配置 - - Args: - chat_stream_id: 聊天流ID,格式为哈希值 - - Returns: - tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习) - """ - if not global_config.expression.learning_list: - return True, True, True - - if chat_stream_id: - for config_item in global_config.expression.learning_list: - if not config_item.platform and not config_item.item_id: - continue # 这是全局的 - stream_id = TempMethodsGroupGenerator._get_stream_id( - config_item.platform, - str(config_item.item_id), - (config_item.rule_type == "group"), - ) - if stream_id is None: - continue - if stream_id == chat_stream_id: - continue - return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning - for config_item in global_config.expression.learning_list: - if not config_item.platform and not config_item.item_id: - return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning - - return True, True, True - - @staticmethod - def _get_stream_id( - platform: str, - id_str: str, - is_group: bool = False, - ) -> Optional[str]: - """ - 根据平台、ID字符串和是否为群聊生成聊天流ID - - Args: - platform: 平台名称 - id_str: 用户或群组的原始ID字符串 - is_group: 是否为群聊 - - Returns: - str: 生成的聊天流ID(哈希值) - """ - try: - from src.chat.message_receive.chat_stream import get_chat_manager - - return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) - except Exception as e: - logger.error(f"生成聊天流ID失败: {e}") - return None diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index d73883eb..52bf9c68 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -18,6 +18,7 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.prompt_builder import global_prompt_manager +from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.chat_message_builder import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, @@ -260,7 +261,7 @@ class PrivateReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) if not use_expression: return "", [] style_habits = [] diff --git a/src/chat/utils/common_utils.py b/src/chat/utils/common_utils.py new file mode 100644 index 00000000..e751381e --- /dev/null +++ b/src/chat/utils/common_utils.py @@ -0,0 +1,68 @@ +from typing import Optional + +from src.config.config import global_config +from src.common.logger import get_logger + +logger = get_logger("common_utils") + +class TempMethodsExpression: + """用于临时存放一些方法的类""" + + @staticmethod + def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]: + """ + 根据聊天流ID获取表达配置 + + Args: + chat_stream_id: 聊天流ID,格式为哈希值 + + Returns: + tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习) + """ + if not global_config.expression.learning_list: + return True, True, True + + if chat_stream_id: + for config_item in global_config.expression.learning_list: + if not config_item.platform and not config_item.item_id: + continue # 这是全局的 + stream_id = TempMethodsExpression._get_stream_id( + config_item.platform, + str(config_item.item_id), + (config_item.rule_type == "group"), + ) + if stream_id is None: + continue + if stream_id == chat_stream_id: + continue + return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning + for config_item in global_config.expression.learning_list: + if not config_item.platform and not config_item.item_id: + return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning + + return True, True, True + + @staticmethod + def _get_stream_id( + platform: str, + id_str: str, + is_group: bool = False, + ) -> Optional[str]: + """ + 根据平台、ID字符串和是否为群聊生成聊天流ID + + Args: + platform: 平台名称 + id_str: 用户或群组的原始ID字符串 + is_group: 是否为群聊 + + Returns: + str: 生成的聊天流ID(哈希值) + """ + try: + from src.chat.message_receive.chat_stream import get_chat_manager + + return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + except Exception as e: + logger.error(f"生成聊天流ID失败: {e}") + return None diff --git a/src/config/config.py b/src/config/config.py index d6397a66..24f58bcb 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -162,44 +162,6 @@ class ModelConfig(ConfigBase): return super().model_post_init(context) -def get_model_info_by_name(model_config: ModelConfig, model_name: str) -> ModelInfo: - """根据模型名称获取模型信息 - - Args: - model_config: ModelConfig实例 - model_name: 模型名称 - - Returns: - ModelInfo: 模型信息 - - Raises: - ValueError: 未找到指定模型 - """ - for model in model_config.models: - if model.name == model_name: - return model - raise ValueError(f"未找到名为 '{model_name}' 的模型") - - -def get_provider_by_name(model_config: ModelConfig, provider_name: str) -> APIProvider: - """根据提供商名称获取提供商信息 - - Args: - model_config: ModelConfig实例 - provider_name: 提供商名称 - - Returns: - APIProvider: API提供商信息 - - Raises: - ValueError: 未找到指定提供商 - """ - for provider in model_config.api_providers: - if provider.name == provider_name: - return provider - raise ValueError(f"未找到名为 '{provider_name}' 的API提供商") - - class ConfigManager: """总配置管理类""" @@ -321,4 +283,4 @@ def write_config_to_file( config_manager = ConfigManager() config_manager.initialize() global_config = config_manager.get_global_config() -model_config = config_manager.get_model_config() \ No newline at end of file +model_config = config_manager.get_model_config() diff --git a/src/dream/dream_agent.py b/src/dream/dream_agent.py index a7f4df0d..8770b24d 100644 --- a/src/dream/dream_agent.py +++ b/src/dream/dream_agent.py @@ -558,7 +558,7 @@ async def start_dream_scheduler( start_ts = time.time() # 检查当前时间是否在允许做梦的时间段内 - if not global_config.dream.is_in_dream_time(): + if not TempMethodsDream.is_in_dream_time(): logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行") else: try: @@ -577,3 +577,34 @@ async def start_dream_scheduler( # 初始化提示词 init_dream_prompts() + + +class TempMethodsDream: + @staticmethod + def is_in_dream_time() -> bool: + if not global_config.dream.dream_time_ranges: + return True + now_min = time.localtime() + now_total_min = now_min.tm_hour * 60 + now_min.tm_min + for time_range in global_config.dream.dream_time_ranges: + if parsed := TempMethodsDream._parse_range(time_range): + start_min, end_min = parsed + if TempMethodsDream._in_range(now_total_min, start_min, end_min): + return True + return False + + @staticmethod + def _in_range(now_min, start_min, end_min) -> bool: + if start_min <= end_min: + return start_min <= now_min <= end_min + return now_min >= start_min or now_min <= end_min + + @staticmethod + def _parse_range(range_str: str) -> Optional[Tuple[int, int]]: + try: + start_str, end_str = [s.strip() for s in range_str.split("-")] + sh, sm = [int(x) for x in start_str.split(":")] + eh, em = [int(x) for x in end_str.split(":")] + return sh * 60 + sm, eh * 60 + em + except Exception: + return None diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4cde5c1a..295ec271 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set import traceback from src.common.logger import get_logger -from src.config.config import model_config, get_model_info_by_name, get_provider_by_name +from src.config.config import model_config from src.config.model_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat @@ -278,7 +278,7 @@ class LLMRequest: raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。") strategy = self.model_for_task.selection_strategy.lower() - + if strategy == "random": # 随机选择策略 selected_model_name = random.choice(list(available_models.keys())) @@ -295,9 +295,9 @@ class LLMRequest: available_models, key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000, ) - - model_info = get_model_info_by_name(model_config, selected_model_name) - api_provider = get_provider_by_name(model_config, model_info.api_provider) + + model_info = TempMethodsLLMUtils.get_model_info_by_name(selected_model_name) + api_provider = TempMethodsLLMUtils.get_provider_by_name(model_info.api_provider) force_new_client = self.request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})") @@ -456,7 +456,9 @@ class LLMRequest: ) raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e - raise ModelAttemptFailed(f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。") + raise ModelAttemptFailed( + f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。" + ) async def _execute_request( self, @@ -576,3 +578,43 @@ class LLMRequest: original_error_msg = str(e.__cause__) return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" return "" + + +class TempMethodsLLMUtils: + @staticmethod + def get_model_info_by_name(model_name: str) -> ModelInfo: + """根据模型名称获取模型信息 + + Args: + model_config: ModelConfig实例 + model_name: 模型名称 + + Returns: + ModelInfo: 模型信息 + + Raises: + ValueError: 未找到指定模型 + """ + for model in model_config.models: + if model.name == model_name: + return model + raise ValueError(f"未找到名为 '{model_name}' 的模型") + + @staticmethod + def get_provider_by_name(provider_name: str) -> APIProvider: + """根据提供商名称获取提供商信息 + + Args: + model_config: ModelConfig实例 + provider_name: 提供商名称 + + Returns: + APIProvider: API提供商信息 + + Raises: + ValueError: 未找到指定提供商 + """ + for provider in model_config.api_providers: + if provider.name == provider_name: + return provider + raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")