mirror of https://github.com/Mai-with-u/MaiBot.git
添加文件监视器地基模块,重构模型请求模块使用新版本的配置热重载模块,新增watchfiles依赖
parent
daad0ba2f0
commit
dc36542403
|
|
@ -36,6 +36,7 @@ dependencies = [
|
|||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"msgpack>=1.1.2",
|
||||
"watchfiles>=1.1.1",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,4 +29,5 @@ toml>=0.10.2
|
|||
tomlkit>=0.13.3
|
||||
urllib3>=2.5.0
|
||||
uvicorn>=0.35.0
|
||||
msgpack>=1.1.2
|
||||
msgpack>=1.1.2
|
||||
watchfiles>=1.1.1
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, Callable, Mapping, Sequence, TypeVar
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
import asyncio
|
||||
import copy
|
||||
|
||||
import tomlkit
|
||||
|
|
@ -38,6 +38,7 @@ from .config_base import ConfigBase, Field, AttributeData
|
|||
from .config_utils import recursive_parse_item_to_table, output_config_changes, compare_versions
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
|
||||
"""
|
||||
如果你想要修改配置文件,请递增version的值
|
||||
|
|
@ -126,7 +127,7 @@ class Config(ConfigBase):
|
|||
|
||||
webui: WebUIConfig = Field(default_factory=WebUIConfig)
|
||||
"""WebUI配置类"""
|
||||
|
||||
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||
"""数据库配置类"""
|
||||
|
||||
|
|
@ -176,12 +177,17 @@ class ConfigManager:
|
|||
self.bot_config_path: Path = BOT_CONFIG_PATH
|
||||
self.model_config_path: Path = MODEL_CONFIG_PATH
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
self.global_config: Config | None = None
|
||||
self.model_config: ModelConfig | None = None
|
||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._reload_callbacks: list[Callable[[], object]] = []
|
||||
self._file_watcher: FileWatcher | None = None
|
||||
|
||||
def initialize(self):
|
||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||
logger.info("正在品鉴配置文件...")
|
||||
self.global_config: Config = self.load_global_config()
|
||||
self.model_config: ModelConfig = self.load_model_config()
|
||||
self.global_config = self.load_global_config()
|
||||
self.model_config = self.load_model_config()
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
||||
def load_global_config(self) -> Config:
|
||||
|
|
@ -197,11 +203,74 @@ class ConfigManager:
|
|||
return config
|
||||
|
||||
def get_global_config(self) -> Config:
    """Return the loaded global Config; raise if initialize() has not run yet."""
    config = self.global_config
    if config is None:
        raise RuntimeError("global_config 未初始化")
    return config
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
    """Return the loaded ModelConfig; raise if initialize() has not run yet."""
    config = self.model_config
    if config is None:
        raise RuntimeError("model_config 未初始化")
    return config
|
||||
|
||||
def register_reload_callback(self, callback: Callable[[], object]) -> None:
    """Register a callable (sync or coroutine-returning) to run after each successful hot reload."""
    self._reload_callbacks.append(callback)
|
||||
|
||||
async def reload_config(self) -> bool:
    """Hot-reload both config files from disk.

    Re-parses the bot config and the model config, swaps the in-memory
    instances (including the module-level ``global_config`` /
    ``model_config`` aliases), then invokes all registered reload
    callbacks.

    Returns:
        True on success, False when either file failed to load (previous
        configs are kept in that case).
    """
    # Serialize concurrent reloads (e.g. several watcher events in flight).
    async with self._reload_lock:
        try:
            global_config_new, global_updated = load_config_from_file(
                Config,
                self.bot_config_path,
                CONFIG_VERSION,
            )
            model_config_new, model_updated = load_config_from_file(
                ModelConfig,
                self.model_config_path,
                MODEL_CONFIG_VERSION,
                True,
            )
        except Exception as exc:
            # On any parse/validation failure, keep the previous configs.
            logger.error(f"配置重载失败: {exc}")
            return False

        if global_updated or model_updated:
            logger.warning("检测到配置版本更新,热重载仅更新内存数据")

        self.global_config = global_config_new
        self.model_config = model_config_new
        # Also refresh the legacy module-level aliases so call sites that
        # imported them directly keep seeing the latest values.
        global global_config, model_config
        global_config = global_config_new
        model_config = model_config_new
        logger.info("配置热重载完成")

        # NOTE(review): callbacks are assumed to run while still holding the
        # reload lock (original indentation was lost in this view) — confirm.
        for callback in list(self._reload_callbacks):
            try:
                result = callback()
                if asyncio.iscoroutine(result):
                    await result
            except Exception as exc:
                # One failing callback must not abort the others.
                logger.warning(f"配置重载回调执行失败: {exc}")
        return True
|
||||
|
||||
async def start_file_watcher(self) -> None:
    """Start watching the bot/model config files; no-op if a watcher is already running."""
    current = self._file_watcher
    if current is not None and current.running:
        return
    self._file_watcher = watcher = FileWatcher(paths=[self.bot_config_path, self.model_config_path])
    await watcher.start(self._handle_file_changes)
    logger.info("配置文件监视器已启动")
|
||||
|
||||
async def stop_file_watcher(self) -> None:
    """Stop and discard the config-file watcher, if one exists."""
    watcher = self._file_watcher
    if watcher is None:
        return
    await watcher.stop()
    self._file_watcher = None
|
||||
|
||||
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
    """Watcher callback: any non-empty batch of changes triggers a hot reload."""
    if changes:
        logger.info("检测到配置文件变更,触发热重载")
        await self.reload_config()
|
||||
|
||||
|
||||
def generate_new_config_file(config_class: type[T], config_path: Path, inner_config_version: str) -> None:
|
||||
"""生成新的配置文件
|
||||
|
|
@ -220,7 +289,13 @@ def load_config_from_file(
|
|||
attribute_data = AttributeData()
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
old_ver: str = config_data["inner"]["version"] # type: ignore
|
||||
inner_table = config_data.get("inner")
|
||||
if not isinstance(inner_table, Mapping):
|
||||
raise TypeError("配置文件缺少 inner 版本信息")
|
||||
inner_version = inner_table.get("version")
|
||||
if not isinstance(inner_version, str):
|
||||
raise TypeError("配置文件 inner.version 类型错误")
|
||||
old_ver: str = inner_version
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
|
|
@ -236,8 +311,7 @@ def load_config_from_file(
|
|||
mig = try_migrate_legacy_bot_config_dict(original_data)
|
||||
if mig.migrated:
|
||||
logger.warning(
|
||||
f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。"
|
||||
f"建议稍后检查并保存生成的新配置文件。"
|
||||
f"检测到旧版配置结构,已尝试自动修复: {mig.reason}。建议稍后检查并保存生成的新配置文件。"
|
||||
)
|
||||
migrated_data = mig.data
|
||||
target_config = config_class.from_dict(attribute_data, migrated_data)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import tomlkit
|
|||
from .config_base import ConfigBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import AttributeData
|
||||
from .config_base import AttributeData
|
||||
|
||||
|
||||
def recursive_parse_item_to_table(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, Iterable, Sequence
|
||||
|
||||
from watchfiles import Change, awatch
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("file_watcher")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FileChange:
    """Immutable record of one filesystem event reported by watchfiles."""

    # Kind of event (added / modified / deleted), as a watchfiles.Change enum member.
    change_type: Change
    # Filesystem path the event refers to.
    path: Path


# Async callback invoked with one debounced batch of file changes.
ChangeCallback = Callable[[Sequence[FileChange]], Awaitable[None]]
|
||||
|
||||
|
||||
class FileWatcher:
    """Watch a fixed set of paths and dispatch debounced change batches to an async callback.

    Thin wrapper around ``watchfiles.awatch`` running in a background
    ``asyncio.Task``; ``start``/``stop`` manage the task lifecycle.
    """

    def __init__(self, paths: Iterable[Path], debounce_ms: int = 600) -> None:
        """
        Args:
            paths: Files (or directories) to watch; resolved to absolute paths.
            debounce_ms: Debounce window forwarded to ``awatch``, in milliseconds.
        """
        self._paths = [path.resolve() for path in paths]
        self._debounce_ms = debounce_ms
        self._running = False
        self._task: asyncio.Task[None] | None = None

    @property
    def running(self) -> bool:
        """True while the background watch task is alive."""
        return self._running

    async def start(self, callback: ChangeCallback) -> None:
        """Start watching in a background task; no-op if already running."""
        if self._running:
            return
        self._running = True
        self._task = asyncio.create_task(self._run(callback))

    async def stop(self) -> None:
        """Cancel the background task and wait for it to finish; no-op if not running."""
        if not self._running:
            return
        self._running = False
        # BUGFIX: also drop the task reference so a stopped watcher does not
        # keep a stale (finished) Task around.
        task, self._task = self._task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            return

    async def _run(self, callback: ChangeCallback) -> None:
        """Background loop: forward each debounced batch of changes to the callback."""
        try:
            async for changes in awatch(*self._paths, debounce=self._debounce_ms):
                if not self._running:
                    break
                try:
                    await callback(self._normalize_changes(changes))
                except Exception as exc:
                    # A failing callback must not kill the watch loop.
                    logger.warning(f"文件变更回调执行失败: {exc}")
        except asyncio.CancelledError:
            return
        except Exception as exc:
            logger.error(f"文件监视器运行异常: {exc}")
        finally:
            # BUGFIX: previously a crashed/finished loop left _running == True,
            # so ``running`` lied and callers could never restart the watcher.
            self._running = False

    def _normalize_changes(self, changes: set[tuple[Change, str]]) -> list[FileChange]:
        """Convert raw awatch tuples into FileChange records."""
        return [FileChange(change_type=change, path=Path(path)) for change, path in changes]
||||
|
|
@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
|
|||
import traceback
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
||||
from .payload_content.message import MessageBuilder, Message
|
||||
from .payload_content.resp_format import RespFormat
|
||||
|
|
@ -43,11 +43,44 @@ class LLMRequest:
|
|||
self.task_name = request_type
|
||||
self.model_for_task = model_set
|
||||
self.request_type = request_type
|
||||
self._task_config_name = self._resolve_task_config_name(model_set)
|
||||
self.model_usage: Dict[str, Tuple[int, int, int]] = {
|
||||
model: (0, 0, 0) for model in self.model_for_task.model_list
|
||||
}
|
||||
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
|
||||
|
||||
def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
    """Find the attribute name on model_task_config whose value IS this model_set.

    Recording the name (rather than the object) lets later lookups survive
    config hot reloads. Returns None when the model config is unavailable
    or no attribute matches by identity.
    """
    try:
        task_configs = config_manager.get_model_config().model_task_config
    except Exception:
        # Config manager not initialized yet — resolution is best-effort.
        return None
    public_names = (name for name in dir(task_configs) if not name.startswith("__"))
    for name in public_names:
        candidate = getattr(task_configs, name, None)
        if isinstance(candidate, TaskConfig) and candidate is model_set:
            return name
    return None
|
||||
|
||||
def _get_latest_task_config(self) -> TaskConfig:
    """Return the current TaskConfig for this request's task.

    Looks up the attribute recorded by _resolve_task_config_name on the
    (possibly hot-reloaded) model config. Falls back to the TaskConfig
    captured at construction time when the name is unknown, the lookup
    raises, or the attribute is no longer a TaskConfig.
    """
    if self._task_config_name:
        try:
            model_task_config = config_manager.get_model_config().model_task_config
            value = getattr(model_task_config, self._task_config_name, None)
            if isinstance(value, TaskConfig):
                return value
        except Exception:
            # config_manager may be unavailable — keep the captured config.
            return self.model_for_task
    return self.model_for_task
|
||||
|
||||
def _refresh_task_config(self) -> TaskConfig:
    """Sync self.model_for_task with the latest hot-reloaded TaskConfig.

    When the config object has been replaced by a reload, adopt the new
    one and re-key the usage-tracking table, preserving counters for
    models that are still present. Returns the (possibly updated) config.
    """
    latest = self._get_latest_task_config()
    if latest is not self.model_for_task:
        self.model_for_task = latest
        # NOTE(review): nesting of this re-key under the identity check is
        # inferred — original indentation was lost in this view; confirm.
        if list(self.model_usage.keys()) != latest.model_list:
            # Keep existing (total_tokens, penalty, usage_penalty) for surviving models.
            self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list}
    return self.model_for_task
|
||||
|
||||
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
|
||||
"""检查请求是否过慢并输出警告日志
|
||||
|
||||
|
|
@ -80,6 +113,7 @@ class LLMRequest:
|
|||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
|
|
@ -123,6 +157,7 @@ class LLMRequest:
|
|||
Returns:
|
||||
(Optional[str]): 生成的文本描述或None
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
response, _ = await self._execute_request(
|
||||
request_type=RequestType.AUDIO,
|
||||
audio_base64=voice_base64,
|
||||
|
|
@ -148,6 +183,7 @@ class LLMRequest:
|
|||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
|
|
@ -204,6 +240,7 @@ class LLMRequest:
|
|||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
|
@ -246,6 +283,7 @@ class LLMRequest:
|
|||
Returns:
|
||||
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.EMBEDDING,
|
||||
|
|
@ -269,6 +307,7 @@ class LLMRequest:
|
|||
"""
|
||||
根据配置的策略选择模型:balance(负载均衡)或 random(随机选择)
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
available_models = {
|
||||
model: scores
|
||||
for model, scores in self.model_usage.items()
|
||||
|
|
@ -314,8 +353,8 @@ class LLMRequest:
|
|||
message_list: List[Message],
|
||||
tool_options: list[ToolOption] | None,
|
||||
response_format: RespFormat | None,
|
||||
stream_response_handler: Optional[Callable],
|
||||
async_response_parser: Optional[Callable],
|
||||
stream_response_handler: Optional[Callable[..., Any]],
|
||||
async_response_parser: Optional[Callable[..., Any]],
|
||||
temperature: Optional[float],
|
||||
max_tokens: Optional[int],
|
||||
embedding_input: str | None,
|
||||
|
|
@ -466,8 +505,8 @@ class LLMRequest:
|
|||
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[Callable] = None,
|
||||
async_response_parser: Optional[Callable] = None,
|
||||
stream_response_handler: Optional[Callable[..., Any]] = None,
|
||||
async_response_parser: Optional[Callable[..., Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
embedding_input: str | None = None,
|
||||
|
|
@ -595,7 +634,7 @@ class TempMethodsLLMUtils:
|
|||
Raises:
|
||||
ValueError: 未找到指定模型
|
||||
"""
|
||||
for model in model_config.models:
|
||||
for model in config_manager.get_model_config().models:
|
||||
if model.name == model_name:
|
||||
return model
|
||||
raise ValueError(f"未找到名为 '{model_name}' 的模型")
|
||||
|
|
@ -614,7 +653,7 @@ class TempMethodsLLMUtils:
|
|||
Raises:
|
||||
ValueError: 未找到指定提供商
|
||||
"""
|
||||
for provider in model_config.api_providers:
|
||||
for provider in config_manager.get_model_config().api_providers:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||
|
|
|
|||
15
src/main.py
15
src/main.py
|
|
@ -9,7 +9,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
|||
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_server.server import get_global_server, Server
|
||||
|
|
@ -84,6 +84,8 @@ class MainSystem:
|
|||
"""初始化其他组件"""
|
||||
init_start_time = time.time()
|
||||
|
||||
await config_manager.start_file_watcher()
|
||||
|
||||
# 添加在线时间统计任务
|
||||
await async_task_manager.add_task(OnlineTimeRecordTask())
|
||||
|
||||
|
|
@ -168,10 +170,13 @@ class MainSystem:
|
|||
async def main():
|
||||
"""主函数"""
|
||||
system = MainSystem()
|
||||
await asyncio.gather(
|
||||
system.initialize(),
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
try:
|
||||
await asyncio.gather(
|
||||
system.initialize(),
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
finally:
|
||||
await config_manager.stop_file_watcher()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue