# MaiBot/src/chat/maibot_llmreq/tests/test_config_load.py

# 85 lines
# 2.4 KiB
# Python

import pytest
from packaging.version import InvalidVersion
from src import maibot_llmreq
from src.maibot_llmreq.config.parser import _get_config_version, load_config
class TestConfigLoad:
    """Tests for config-version detection and TOML config loading.

    Covers ``_get_config_version`` (valid, missing and malformed version
    values) and ``load_config`` (a complete config file and a file that
    lacks the required sections).
    """

    def setup_method(self):
        """Initialise the package logger before each test.

        pytest invokes ``setup_method`` ahead of every test method of the
        class, so the individual tests do not have to repeat this call.
        """
        maibot_llmreq.init_logger()

    def test_loads_valid_version_from_toml(self):
        """A well-formed ``inner.version`` value is returned as that version."""
        toml_data = {"inner": {"version": "1.2.3"}}

        version = _get_config_version(toml_data)

        assert str(version) == "1.2.3"

    def test_handles_missing_version_key(self):
        """With no ``inner.version`` entry, the version is expected to be 0.0.0."""
        toml_data = {}

        version = _get_config_version(toml_data)

        assert str(version) == "0.0.0"

    def test_raises_error_for_invalid_version(self):
        """A version string that cannot be parsed raises ``InvalidVersion``."""
        toml_data = {"inner": {"version": "invalid_version"}}

        with pytest.raises(InvalidVersion):
            _get_config_version(toml_data)

    def test_loads_complete_config_successfully(self, tmp_path):
        """A config with every section loads and exposes the parsed values.

        ``tmp_path`` is pytest's per-test temporary directory fixture; the
        TOML document is written there and loaded back through
        ``load_config``.
        """
        config_path = tmp_path / "config.toml"
        # The TOML text is kept flush-left so the file written to disk is
        # exactly this document, without any Python-level indentation.
        # An explicit encoding keeps the fixture independent of the locale.
        config_path.write_text(
            """
[inner]
version = "0.1.0"
[request_conf]
max_retry = 5
timeout = 10
[[api_providers]]
name = "provider1"
base_url = "https://api.example.com"
api_key = "key123"
[[api_providers]]
name = "provider2"
base_url = "https://api.example2.com"
api_key = "key456"
[[models]]
model_identifier = "model1"
api_provider = "provider1"
[[models]]
model_identifier = "model2"
api_provider = "provider2"
[task_model_usage]
task1 = { model = "model1" }
task2 = "model1"
task3 = [
"model1",
{ model = "model2", temperature = 0.5 }
]
""",
            encoding="utf-8",
        )

        config = load_config(str(config_path))

        # Request settings come from the [request_conf] table.
        assert config.req_conf.max_retry == 5
        assert config.req_conf.timeout == 10
        # Providers, models and tasks are looked up by their names.
        assert "provider1" in config.api_providers
        assert "model1" in config.models
        assert "task1" in config.task_model_arg_map

    def test_raises_error_for_missing_required_field(self, tmp_path):
        """A config containing only ``[inner]`` makes ``load_config`` raise ``KeyError``."""
        config_path = tmp_path / "config.toml"
        config_path.write_text(
            """
[inner]
version = "1.0.0"
""",
            encoding="utf-8",
        )

        with pytest.raises(KeyError):
            load_config(str(config_path))