From dc365424036b3082e3f4fcc965ca61aa35ca0ca8 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 14 Feb 2026 21:17:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=87=E4=BB=B6=E7=9B=91?= =?UTF-8?q?=E8=A7=86=E5=99=A8=E5=9C=B0=E5=9F=BA=E6=A8=A1=E5=9D=97=EF=BC=8C?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=A8=A1=E5=9E=8B=E8=AF=B7=E6=B1=82=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E4=BD=BF=E7=94=A8=E6=96=B0=E7=89=88=E6=9C=AC=E7=9A=84?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=83=AD=E9=87=8D=E8=BD=BD=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9Ewatchfiles=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + requirements.txt | 3 +- src/config/config.py | 90 +++++++++++++++++++++++++++++++---- src/config/config_utils.py | 2 +- src/config/file_watcher.py | 68 ++++++++++++++++++++++++++ src/llm_models/utils_model.py | 53 ++++++++++++++++++--- src/main.py | 15 ++++-- 7 files changed, 210 insertions(+), 22 deletions(-) create mode 100644 src/config/file_watcher.py diff --git a/pyproject.toml b/pyproject.toml index dcee3892..b84324fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "urllib3>=2.5.0", "uvicorn>=0.35.0", "msgpack>=1.1.2", + "watchfiles>=1.1.1", ] diff --git a/requirements.txt b/requirements.txt index 65105819..2ddfe645 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,5 @@ toml>=0.10.2 tomlkit>=0.13.3 urllib3>=2.5.0 uvicorn>=0.35.0 -msgpack>=1.1.2 \ No newline at end of file +msgpack>=1.1.2 +watchfiles>=1.1.1 diff --git a/src/config/config.py b/src/config/config.py index 5495d424..c1d23018 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,7 +1,7 @@ from pathlib import Path -from typing import TypeVar +from typing import Any, Callable, Mapping, Sequence, TypeVar from datetime import datetime -from typing import Any +import asyncio import copy import tomlkit @@ -38,6 +38,7 @@ from .config_base import ConfigBase, Field, AttributeData from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions from src.common.logger import get_logger +from src.config.file_watcher import FileChange, FileWatcher """ 如果你想要修改配置文件,请递增version的值 @@ -126,7 +127,7 @@ class Config(ConfigBase): webui: WebUIConfig = Field(default_factory=WebUIConfig) """WebUI配置类""" - + database: DatabaseConfig = Field(default_factory=DatabaseConfig) """数据库配置类""" @@ -176,12 +177,17 @@ class ConfigManager: self.bot_config_path: Path = BOT_CONFIG_PATH self.model_config_path: Path = MODEL_CONFIG_PATH CONFIG_DIR.mkdir(parents=True, exist_ok=True) + self.global_config: Config | None = None + self.model_config: ModelConfig | None = None + self._reload_lock: asyncio.Lock = asyncio.Lock() + self._reload_callbacks: list[Callable[[], object]] = [] + self._file_watcher: FileWatcher | None = None def initialize(self): logger.info(f"MaiCore当前版本: {MMC_VERSION}") logger.info("正在品鉴配置文件...") - self.global_config: Config = self.load_global_config() - self.model_config: ModelConfig = self.load_model_config() + self.global_config = self.load_global_config() + self.model_config = self.load_model_config() logger.info("非常的新鲜,非常的美味!") def load_global_config(self) -> Config: @@ -197,11 +203,74 @@ class ConfigManager: return config def get_global_config(self) -> Config: + if self.global_config is None: + raise RuntimeError("global_config 未初始化") return self.global_config def get_model_config(self) -> ModelConfig: + if self.model_config is None: + raise RuntimeError("model_config 未初始化") return self.model_config + def register_reload_callback(self, callback: Callable[[], object]) -> None: + self._reload_callbacks.append(callback) + + async def reload_config(self) -> bool: + async with self._reload_lock: + try: + global_config_new, global_updated = load_config_from_file( + Config, + self.bot_config_path, + CONFIG_VERSION, + ) + model_config_new, model_updated = load_config_from_file( + ModelConfig, + self.model_config_path, + MODEL_CONFIG_VERSION, + True, + ) + except Exception as exc: + logger.error(f"配置重载失败: {exc}") + return False + + if global_updated or model_updated: + logger.warning("检测到配置版本更新,热重载仅更新内存数据") + + self.global_config = global_config_new + self.model_config = model_config_new + global global_config, model_config + global_config = global_config_new + model_config = model_config_new + logger.info("配置热重载完成") + + for callback in list(self._reload_callbacks): + try: + result = callback() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.warning(f"配置重载回调执行失败: {exc}") + return True + + async def start_file_watcher(self) -> None: + if self._file_watcher is not None and self._file_watcher.running: + return + self._file_watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path]) + await self._file_watcher.start(self._handle_file_changes) + logger.info("配置文件监视器已启动") + + async def stop_file_watcher(self) -> None: + if self._file_watcher is None: + return + await self._file_watcher.stop() + self._file_watcher = None + + async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None: + if not changes: + return + logger.info("检测到配置文件变更,触发热重载") + await self.reload_config() + def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None: """生成新的配置文件 @@ -220,7 +289,13 @@ def load_config_from_file( attribute_data = AttributeData() with open(config_path, "r", encoding="utf-8") as f: config_data = tomlkit.load(f) - old_ver: str = config_data["inner"]["version"] # type: ignore + inner_table = config_data.get("inner") + if not isinstance(inner_table, Mapping): + raise TypeError("配置文件缺少 inner 版本信息") + inner_version = inner_table.get("version") + if not isinstance(inner_version, str): + raise TypeError("配置文件 inner.version 类型错误") + old_ver: str = inner_version config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理 config_data = config_data.unwrap() # 转换为普通字典,方便后续处理 # 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改 @@ -236,8 +311,7 @@ def load_config_from_file( mig = try_migrate_legacy_bot_config_dict(original_data) if mig.migrated: logger.warning( - f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。" - f"建议稍后检查并保存生成的新配置文件。" + f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。建议稍后检查并保存生成的新配置文件。" ) migrated_data = mig.data target_config = config_class.from_dict(attribute_data, migrated_data) diff --git a/src/config/config_utils.py b/src/config/config_utils.py index c0edfc3f..90c5741f 100644 --- a/src/config/config_utils.py +++ b/src/config/config_utils.py @@ -7,7 +7,7 @@ import tomlkit from .config_base import ConfigBase if TYPE_CHECKING: - from .config import AttributeData + from .config_base import AttributeData def recursive_parse_item_to_table( diff --git a/src/config/file_watcher.py b/src/config/file_watcher.py new file mode 100644 index 00000000..fd0dc0f1 --- /dev/null +++ b/src/config/file_watcher.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Awaitable, Callable, Iterable, Sequence + +from watchfiles import Change, awatch + +import asyncio + +from src.common.logger import get_logger + + +logger = get_logger("file_watcher") + + +@dataclass(frozen=True) +class FileChange: + change_type: Change + path: Path + + +ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None]] + + +class FileWatcher: + def __init__(self, paths: Iterable[Path], debounce_ms: int = 600) -> None: + self._paths = [path.resolve() for path in paths] + self._debounce_ms = debounce_ms + self._running = False + self._task: asyncio.Task[None] | None = None + + @property + def running(self) -> bool: + return self._running + + async def start(self, callback: ChangeCallback) -> None: + if self._running: + return + self._running = True + self._task = asyncio.create_task(self._run(callback)) + + async def stop(self) -> None: + if not self._running: + return + self._running = False + if self._task is None: + return + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + return + + async def _run(self, callback: ChangeCallback) -> None: + try: + async for changes in awatch(*self._paths, debounce=self._debounce_ms): + if not self._running: + break + try: + await callback(self._normalize_changes(changes)) + except Exception as exc: + logger.warning(f"文件变更回调执行失败: {exc}") + except asyncio.CancelledError: + return + except Exception as exc: + logger.error(f"文件监视器运行异常: {exc}") + + def _normalize_changes(self, changes: set[tuple[Change, str]]) -> list[FileChange]: + return [FileChange(change_type=change, path=Path(path)) for change, path in changes] diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 295ec271..ad0460f3 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set import traceback from src.common.logger import get_logger -from src.config.config import model_config +from src.config.config import config_manager from src.config.model_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat @@ -43,11 +43,44 @@ class LLMRequest: self.task_name = request_type self.model_for_task = model_set self.request_type = request_type + self._task_config_name = self._resolve_task_config_name(model_set) self.model_usage: Dict[str, Tuple[int, int, int]] = { model: (0, 0, 0) for model in self.model_for_task.model_list } """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" + def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]: + try: + model_task_config = config_manager.get_model_config().model_task_config + except Exception: + return None + for attr in dir(model_task_config): + if attr.startswith("__"): + continue + value = getattr(model_task_config, attr, None) + if isinstance(value, TaskConfig) and value is model_set: + return attr + return None + + def _get_latest_task_config(self) -> TaskConfig: + if self._task_config_name: + try: + model_task_config = config_manager.get_model_config().model_task_config + value = getattr(model_task_config, self._task_config_name, None) + if isinstance(value, TaskConfig): + return value + except Exception: + return self.model_for_task + return self.model_for_task + + def _refresh_task_config(self) -> TaskConfig: + latest = self._get_latest_task_config() + if latest is not self.model_for_task: + self.model_for_task = latest + if list(self.model_usage.keys()) != latest.model_list: + self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list} + return self.model_for_task + def _check_slow_request(self, time_cost: float, model_name: str) -> None: """检查请求是否过慢并输出警告日志 @@ -80,6 +113,7 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ + self._refresh_task_config() start_time = time.time() def message_factory(client: BaseClient) -> List[Message]: @@ -123,6 +157,7 @@ class LLMRequest: Returns: (Optional[str]): 生成的文本描述或None """ + self._refresh_task_config() response, _ = await self._execute_request( request_type=RequestType.AUDIO, audio_base64=voice_base64, @@ -148,6 +183,7 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ + self._refresh_task_config() start_time = time.time() def message_factory(client: BaseClient) -> List[Message]: @@ -204,6 +240,7 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ + self._refresh_task_config() start_time = time.time() tool_built = self._build_tool_options(tools) @@ -246,6 +283,7 @@ class LLMRequest: Returns: (Tuple[List[float], str]): (嵌入向量,使用的模型名称) """ + self._refresh_task_config() start_time = time.time() response, model_info = await self._execute_request( request_type=RequestType.EMBEDDING, @@ -269,6 +307,7 @@ class LLMRequest: """ 根据配置的策略选择模型:balance(负载均衡)或 random(随机选择) """ + self._refresh_task_config() available_models = { model: scores for model, scores in self.model_usage.items() @@ -314,8 +353,8 @@ class LLMRequest: message_list: List[Message], tool_options: list[ToolOption] | None, response_format: RespFormat | None, - stream_response_handler: Optional[Callable], - async_response_parser: Optional[Callable], + stream_response_handler: Optional[Callable[..., Any]], + async_response_parser: Optional[Callable[..., Any]], temperature: Optional[float], max_tokens: Optional[int], embedding_input: str | None, @@ -466,8 +505,8 @@ class LLMRequest: message_factory: Optional[Callable[[BaseClient], List[Message]]] = None, tool_options: list[ToolOption] | None = None, response_format: RespFormat | None = None, - stream_response_handler: Optional[Callable] = None, - async_response_parser: Optional[Callable] = None, + stream_response_handler: Optional[Callable[..., Any]] = None, + async_response_parser: Optional[Callable[..., Any]] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, embedding_input: str | None = None, @@ -595,7 +634,7 @@ class TempMethodsLLMUtils: Raises: ValueError: 未找到指定模型 """ - for model in model_config.models: + for model in config_manager.get_model_config().models: if model.name == model_name: return model raise ValueError(f"未找到名为 '{model_name}' 的模型") @@ -614,7 +653,7 @@ class TempMethodsLLMUtils: Raises: ValueError: 未找到指定提供商 """ - for provider in model_config.api_providers: + for provider in config_manager.get_model_config().api_providers: if provider.name == provider_name: return provider raise ValueError(f"未找到名为 '{provider_name}' 的API提供商") diff --git a/src/main.py b/src/main.py index 97e58578..df2d069a 100644 --- a/src/main.py +++ b/src/main.py @@ -9,7 +9,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask # from src.chat.utils.token_statistics import TokenStatisticsTask from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager -from src.config.config import global_config +from src.config.config import config_manager, global_config from src.chat.message_receive.bot import chat_bot from src.common.logger import get_logger from src.common.message_server.server import get_global_server, Server @@ -84,6 +84,8 @@ class MainSystem: """初始化其他组件""" init_start_time = time.time() + await config_manager.start_file_watcher() + # 添加在线时间统计任务 await async_task_manager.add_task(OnlineTimeRecordTask()) @@ -168,10 +170,13 @@ class MainSystem: async def main(): """主函数""" system = MainSystem() - await asyncio.gather( - system.initialize(), - system.schedule_tasks(), - ) + try: + await asyncio.gather( + system.initialize(), + system.schedule_tasks(), + ) + finally: + await config_manager.stop_file_watcher() if __name__ == "__main__":