添加文件监视器地基模块,重构模型请求模块使用新版本的配置热重载模块,新增watchfiles依赖

pull/1496/head
DrSmoothl 2026-02-14 21:17:24 +08:00
parent daad0ba2f0
commit dc36542403
No known key found for this signature in database
7 changed files with 210 additions and 22 deletions

View File

@ -36,6 +36,7 @@ dependencies = [
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"msgpack>=1.1.2",
"watchfiles>=1.1.1",
]

View File

@ -29,4 +29,5 @@ toml>=0.10.2
tomlkit>=0.13.3
urllib3>=2.5.0
uvicorn>=0.35.0
msgpack>=1.1.2
msgpack>=1.1.2
watchfiles>=1.1.1

View File

@ -1,7 +1,7 @@
from pathlib import Path
from typing import TypeVar
from typing import Any, Callable, Mapping, Sequence, TypeVar
from datetime import datetime
from typing import Any
import asyncio
import copy
import tomlkit
@ -38,6 +38,7 @@ from .config_base import ConfigBase, Field, AttributeData
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
from src.common.logger import get_logger
from src.config.file_watcher import FileChange, FileWatcher
"""
如果你想要修改配置文件请递增version的值
@ -126,7 +127,7 @@ class Config(ConfigBase):
webui: WebUIConfig = Field(default_factory=WebUIConfig)
"""WebUI配置类"""
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
"""数据库配置类"""
@ -176,12 +177,17 @@ class ConfigManager:
self.bot_config_path: Path = BOT_CONFIG_PATH
self.model_config_path: Path = MODEL_CONFIG_PATH
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
self.global_config: Config | None = None
self.model_config: ModelConfig | None = None
self._reload_lock: asyncio.Lock = asyncio.Lock()
self._reload_callbacks: list[Callable[[], object]] = []
self._file_watcher: FileWatcher | None = None
def initialize(self):
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
logger.info("正在品鉴配置文件...")
self.global_config: Config = self.load_global_config()
self.model_config: ModelConfig = self.load_model_config()
self.global_config = self.load_global_config()
self.model_config = self.load_model_config()
logger.info("非常的新鲜,非常的美味!")
def load_global_config(self) -> Config:
@ -197,11 +203,74 @@ class ConfigManager:
return config
def get_global_config(self) -> Config:
if self.global_config is None:
raise RuntimeError("global_config 未初始化")
return self.global_config
def get_model_config(self) -> ModelConfig:
if self.model_config is None:
raise RuntimeError("model_config 未初始化")
return self.model_config
def register_reload_callback(self, callback: Callable[[], object]) -> None:
self._reload_callbacks.append(callback)
async def reload_config(self) -> bool:
async with self._reload_lock:
try:
global_config_new, global_updated = load_config_from_file(
Config,
self.bot_config_path,
CONFIG_VERSION,
)
model_config_new, model_updated = load_config_from_file(
ModelConfig,
self.model_config_path,
MODEL_CONFIG_VERSION,
True,
)
except Exception as exc:
logger.error(f"配置重载失败: {exc}")
return False
if global_updated or model_updated:
logger.warning("检测到配置版本更新,热重载仅更新内存数据")
self.global_config = global_config_new
self.model_config = model_config_new
global global_config, model_config
global_config = global_config_new
model_config = model_config_new
logger.info("配置热重载完成")
for callback in list(self._reload_callbacks):
try:
result = callback()
if asyncio.iscoroutine(result):
await result
except Exception as exc:
logger.warning(f"配置重载回调执行失败: {exc}")
return True
async def start_file_watcher(self) -> None:
if self._file_watcher is not None and self._file_watcher.running:
return
self._file_watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path])
await self._file_watcher.start(self._handle_file_changes)
logger.info("配置文件监视器已启动")
async def stop_file_watcher(self) -> None:
if self._file_watcher is None:
return
await self._file_watcher.stop()
self._file_watcher = None
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
if not changes:
return
logger.info("检测到配置文件变更,触发热重载")
await self.reload_config()
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
"""生成新的配置文件
@ -220,7 +289,13 @@ def load_config_from_file(
attribute_data = AttributeData()
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
old_ver: str = config_data["inner"]["version"] # type: ignore
inner_table = config_data.get("inner")
if not isinstance(inner_table, Mapping):
raise TypeError("配置文件缺少 inner 版本信息")
inner_version = inner_table.get("version")
if not isinstance(inner_version, str):
raise TypeError("配置文件 inner.version 类型错误")
old_ver: str = inner_version
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
@ -236,8 +311,7 @@ def load_config_from_file(
mig = try_migrate_legacy_bot_config_dict(original_data)
if mig.migrated:
logger.warning(
f"检测到旧版配置结构,已尝试自动修复: {mig.reason}"
f"建议稍后检查并保存生成的新配置文件。"
f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。建议稍后检查并保存生成的新配置文件。"
)
migrated_data = mig.data
target_config = config_class.from_dict(attribute_data, migrated_data)

View File

@ -7,7 +7,7 @@ import tomlkit
from .config_base import ConfigBase
if TYPE_CHECKING:
from .config import AttributeData
from .config_base import AttributeData
def recursive_parse_item_to_table(

View File

@ -0,0 +1,68 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Awaitable, Callable, Iterable, Sequence
from watchfiles import Change, awatch
import asyncio
from src.common.logger import get_logger
logger = get_logger("file_watcher")
@dataclass(frozen=True)
class FileChange:
change_type: Change
path: Path
ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None]]
class FileWatcher:
def __init__(self, paths: Iterable[Path], debounce_ms: int = 600) -> None:
self._paths = [path.resolve() for path in paths]
self._debounce_ms = debounce_ms
self._running = False
self._task: asyncio.Task[None] | None = None
@property
def running(self) -> bool:
return self._running
async def start(self, callback: ChangeCallback) -> None:
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._run(callback))
async def stop(self) -> None:
if not self._running:
return
self._running = False
if self._task is None:
return
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
return
async def _run(self, callback: ChangeCallback) -> None:
try:
async for changes in awatch(*self._paths, debounce=self._debounce_ms):
if not self._running:
break
try:
await callback(self._normalize_changes(changes))
except Exception as exc:
logger.warning(f"文件变更回调执行失败: {exc}")
except asyncio.CancelledError:
return
except Exception as exc:
logger.error(f"文件监视器运行异常: {exc}")
def _normalize_changes(self, changes: set[tuple[Change, str]]) -> list[FileChange]:
return [FileChange(change_type=change, path=Path(path)) for change, path in changes]

View File

@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.config import config_manager
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
@ -43,11 +43,44 @@ class LLMRequest:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self._task_config_name = self._resolve_task_config_name(model_set)
self.model_usage: Dict[str, Tuple[int, int, int]] = {
model: (0, 0, 0) for model in self.model_for_task.model_list
}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
try:
model_task_config = config_manager.get_model_config().model_task_config
except Exception:
return None
for attr in dir(model_task_config):
if attr.startswith("__"):
continue
value = getattr(model_task_config, attr, None)
if isinstance(value, TaskConfig) and value is model_set:
return attr
return None
def _get_latest_task_config(self) -> TaskConfig:
if self._task_config_name:
try:
model_task_config = config_manager.get_model_config().model_task_config
value = getattr(model_task_config, self._task_config_name, None)
if isinstance(value, TaskConfig):
return value
except Exception:
return self.model_for_task
return self.model_for_task
def _refresh_task_config(self) -> TaskConfig:
latest = self._get_latest_task_config()
if latest is not self.model_for_task:
self.model_for_task = latest
if list(self.model_usage.keys()) != latest.model_list:
self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list}
return self.model_for_task
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志
@ -80,6 +113,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容推理内容模型名称工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
def message_factory(client: BaseClient) -> List[Message]:
@ -123,6 +157,7 @@ class LLMRequest:
Returns:
(Optional[str]): 生成的文本描述或None
"""
self._refresh_task_config()
response, _ = await self._execute_request(
request_type=RequestType.AUDIO,
audio_base64=voice_base64,
@ -148,6 +183,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容推理内容模型名称工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
def message_factory(client: BaseClient) -> List[Message]:
@ -204,6 +240,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容推理内容模型名称工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
tool_built = self._build_tool_options(tools)
@ -246,6 +283,7 @@ class LLMRequest:
Returns:
(Tuple[List[float], str]): (嵌入向量使用的模型名称)
"""
self._refresh_task_config()
start_time = time.time()
response, model_info = await self._execute_request(
request_type=RequestType.EMBEDDING,
@ -269,6 +307,7 @@ class LLMRequest:
"""
根据配置的策略选择模型balance负载均衡 random随机选择
"""
self._refresh_task_config()
available_models = {
model: scores
for model, scores in self.model_usage.items()
@ -314,8 +353,8 @@ class LLMRequest:
message_list: List[Message],
tool_options: list[ToolOption] | None,
response_format: RespFormat | None,
stream_response_handler: Optional[Callable],
async_response_parser: Optional[Callable],
stream_response_handler: Optional[Callable[..., Any]],
async_response_parser: Optional[Callable[..., Any]],
temperature: Optional[float],
max_tokens: Optional[int],
embedding_input: str | None,
@ -466,8 +505,8 @@ class LLMRequest:
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[Callable] = None,
async_response_parser: Optional[Callable] = None,
stream_response_handler: Optional[Callable[..., Any]] = None,
async_response_parser: Optional[Callable[..., Any]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str | None = None,
@ -595,7 +634,7 @@ class TempMethodsLLMUtils:
Raises:
ValueError: 未找到指定模型
"""
for model in model_config.models:
for model in config_manager.get_model_config().models:
if model.name == model_name:
return model
raise ValueError(f"未找到名为 '{model_name}' 的模型")
@ -614,7 +653,7 @@ class TempMethodsLLMUtils:
Raises:
ValueError: 未找到指定提供商
"""
for provider in model_config.api_providers:
for provider in config_manager.get_model_config().api_providers:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")

View File

@ -9,7 +9,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.config.config import global_config
from src.config.config import config_manager, global_config
from src.chat.message_receive.bot import chat_bot
from src.common.logger import get_logger
from src.common.message_server.server import get_global_server, Server
@ -84,6 +84,8 @@ class MainSystem:
"""初始化其他组件"""
init_start_time = time.time()
await config_manager.start_file_watcher()
# 添加在线时间统计任务
await async_task_manager.add_task(OnlineTimeRecordTask())
@ -168,10 +170,13 @@ class MainSystem:
async def main():
"""主函数"""
system = MainSystem()
await asyncio.gather(
system.initialize(),
system.schedule_tasks(),
)
try:
await asyncio.gather(
system.initialize(),
system.schedule_tasks(),
)
finally:
await config_manager.stop_file_watcher()
if __name__ == "__main__":