65 lines
1.5 KiB
Python
65 lines
1.5 KiB
Python
from pathlib import Path
|
|
|
|
import yaml
|
|
from django.conf import settings
|
|
|
|
|
|
# Keys every scenario file must define, written as dotted paths and split
# into key tuples (e.g. "agent.role" -> ("agent", "role")).
# validate_scenario() checks each tuple for presence via _get_nested().
REQUIRED_FIELDS = [
    tuple(dotted.split("."))
    for dotted in (
        "id",
        "name",
        "description",
        "agent.role",
        "agent.goal",
        "agent.instructions",
        "rag.enabled",
        "tools",
        "output.type",
        "audit.enabled",
    )
]
|
|
|
|
|
|
class ScenarioNotFound(KeyError):
    """Raised when no scenario with the requested id exists.

    Subclasses ``KeyError`` so callers can treat a missing scenario like a
    missing mapping key.

    ``KeyError.__str__`` returns ``repr()`` of a single argument, because it is
    meant to display the missing key. That would wrap the human-readable message
    passed here in quotes. ``__str__`` is overridden so ``str(exc)`` yields the
    message verbatim.
    """

    def __str__(self) -> str:
        # Single message argument: show it as-is instead of its repr.
        if len(self.args) == 1:
            return str(self.args[0])
        # Zero or several args: keep the standard KeyError formatting.
        return super().__str__()
|
|
|
|
|
|
class ScenarioValidationError(ValueError):
    """Raised when a scenario configuration fails validation (e.g. a required field is absent)."""
|
|
|
|
|
|
def _get_nested(config: dict, path: tuple[str, ...]):
|
|
value = config
|
|
for key in path:
|
|
if not isinstance(value, dict) or key not in value:
|
|
raise ScenarioValidationError("缺失必填字段: " + ".".join(path))
|
|
value = value[key]
|
|
return value
|
|
|
|
|
|
def validate_scenario(config: dict) -> dict:
    """Ensure *config* contains every key path listed in ``REQUIRED_FIELDS``.

    Only presence is checked. A required key may hold any value, including
    ``None`` or ``False``.

    Returns:
        The same *config* object, unmodified.

    Raises:
        ScenarioValidationError: raised by ``_get_nested`` for the first
            required path that is missing.
    """
    for required_path in REQUIRED_FIELDS:
        # The looked-up value is discarded; only the lookup succeeding matters.
        _get_nested(config, required_path)
    return config
|
|
|
|
|
|
def _scenario_files() -> list[Path]:
    """Return the ``*.yaml`` and ``*.yml`` files in ``settings.SCENARIO_CONFIG_DIR``, sorted.

    Returns an empty list when the configured directory does not exist.
    """
    directory = Path(settings.SCENARIO_CONFIG_DIR)
    if directory.exists():
        found = list(directory.glob("*.yaml"))
        found.extend(directory.glob("*.yml"))
        found.sort()
        return found
    return []
|
|
|
|
|
|
def list_scenarios() -> list[dict]:
    """Load and validate every scenario file in the configured directory.

    Files are processed in the order returned by ``_scenario_files`` (sorted
    by path). A falsy YAML document, such as an empty file, is replaced with
    ``{}`` and therefore fails validation.

    Returns:
        The validated scenario mappings, one per file.

    Raises:
        ScenarioValidationError: if a file is missing a required field. The
            message is prefixed with the offending file's path, and the
            original error is chained as ``__cause__``.
        yaml.YAMLError: if a file is not well-formed YAML (propagated from
            ``yaml.safe_load``).
    """
    scenarios = []
    for path in _scenario_files():
        with path.open("r", encoding="utf-8") as file:
            # safe_load (not yaml.load) is used, so a config file cannot
            # instantiate arbitrary Python objects.
            config = yaml.safe_load(file) or {}
        try:
            scenarios.append(validate_scenario(config))
        except ScenarioValidationError as exc:
            # The underlying message only names the missing field; add the
            # file path so the broken file can be located.
            raise ScenarioValidationError(f"{path}: {exc}") from exc
    return scenarios
|
|
|
|
|
|
def get_scenario(scenario_id: str) -> dict:
    """Return the scenario whose ``id`` equals *scenario_id*.

    All scenario files are loaded and validated on every call via
    ``list_scenarios``. If several scenarios share the same id, the first
    one in that list wins.

    Raises:
        ScenarioNotFound: if no scenario has a matching ``id``.
    """
    # NOTE(review): ids are compared with ``==``; an unquoted numeric id in
    # YAML (e.g. ``id: 42``) loads as int and will not match the string "42".
    # Confirm scenario ids are always quoted strings.
    matches = (item for item in list_scenarios() if item["id"] == scenario_id)
    found = next(matches, None)
    # Validated scenarios are always dicts, so None unambiguously means "no match".
    if found is None:
        raise ScenarioNotFound(f"场景不存在: {scenario_id}")
    return found
|