80 lines
2.3 KiB
Python
80 lines
2.3 KiB
Python
from pathlib import Path
|
||
|
||
import yaml
|
||
from django.conf import settings
|
||
|
||
|
||
# Key paths that every scenario config must contain. Each entry is a tuple of
# nested keys, e.g. ("agent", "role") refers to config["agent"]["role"].
REQUIRED_FIELDS = [
    tuple(dotted_path.split("."))
    for dotted_path in (
        "id",
        "name",
        "description",
        "agent.role",
        "agent.goal",
        "agent.instructions",
        "rag.enabled",
        "tools",
        "output.type",
        "audit.enabled",
    )
]
|
||
|
||
|
||
class ScenarioNotFound(KeyError):
    """Raised when no scenario config has the requested id.

    Subclasses KeyError so callers can treat it as an ordinary lookup miss.
    KeyError's default ``__str__`` returns the ``repr`` of a single argument,
    which would wrap the message in quotes. This override makes ``str(exc)``
    yield the plain message instead.
    """

    def __str__(self) -> str:
        if len(self.args) == 1:
            return str(self.args[0])
        # Zero or several args: keep the standard KeyError rendering.
        return super().__str__()
|
||
|
||
|
||
class ScenarioValidationError(ValueError):
    """Raised when a scenario config lacks a required field."""
|
||
|
||
|
||
def _get_nested(config: dict, path: tuple[str, ...]):
    """Follow ``path`` through nested dicts in ``config`` and return the final value.

    Raises:
        ScenarioValidationError: if an intermediate value is not a dict or a
            key along ``path`` is absent. The message names the full dotted path.
    """
    current = config
    for segment in path:
        if isinstance(current, dict) and segment in current:
            current = current[segment]
            continue
        raise ScenarioValidationError("缺失必填字段: " + ".".join(path))
    return current
|
||
|
||
|
||
def validate_scenario(config: dict) -> dict:
    """Ensure ``config`` has every REQUIRED_FIELDS path, then return it normalized.

    Raises:
        ScenarioValidationError: on the first required field that is missing.
    """
    # Only fields that actually affect the end-to-end run loop are mandatory.
    # Display-only fields such as applicable_questions may be absent; they get
    # default values during normalization.
    for required_path in REQUIRED_FIELDS:
        _get_nested(config, required_path)

    return normalize_scenario(config)
|
||
|
||
|
||
def _coerce_list(value) -> list:
    """Return ``value`` as a new list: falsy -> [], a lone string -> [value]."""
    if not value:
        return []
    if isinstance(value, str):
        # list("search") would split a single string into characters.
        return [value]
    return list(value)


def normalize_scenario(config: dict) -> dict:
    """Return a shallow copy of ``config`` with derived fields filled in.

    The derived fields are ones that pages and other modules commonly use, so
    templates need fewer conditionals. The input dict is not modified; nested
    values other than ``rag`` are shared with it.

    Resulting keys:
        applicable_questions: list (empty when missing).
        rag: a copy of the ``rag`` mapping with ``enabled`` coerced to bool.
            An empty mapping is used when ``rag`` is missing or null.
        tools: list (empty when missing).
        tool_count: ``len(tools)``.
        is_enabled: always True.
    """
    normalized = dict(config)
    normalized["applicable_questions"] = _coerce_list(config.get("applicable_questions"))

    # `rag: null` in YAML gives None; `or {}` avoids dict(None) raising TypeError.
    rag = dict(config.get("rag") or {})
    rag["enabled"] = bool(rag.get("enabled"))
    normalized["rag"] = rag

    normalized["tools"] = _coerce_list(config.get("tools"))
    normalized["tool_count"] = len(normalized["tools"])
    # NOTE(review): this overrides any `is_enabled` value present in the config;
    # confirm that is intended.
    normalized["is_enabled"] = True
    return normalized
|
||
|
||
|
||
def _scenario_files() -> list[Path]:
    """Return the ``*.yaml`` / ``*.yml`` files in ``settings.SCENARIO_CONFIG_DIR``, sorted.

    Returns an empty list when the configured path does not exist or is not a
    directory.
    """
    config_dir = Path(settings.SCENARIO_CONFIG_DIR)
    if not config_dir.is_dir():
        return []
    # Keep only regular files. A directory named e.g. "draft.yaml" also matches
    # the glob and would make the caller's open() fail.
    return sorted(
        path
        for pattern in ("*.yaml", "*.yml")
        for path in config_dir.glob(pattern)
        if path.is_file()
    )
|
||
|
||
|
||
def list_scenarios() -> list[dict]:
    """Load, validate and normalize every scenario config file.

    Files are re-read on every call; nothing is cached.

    Returns:
        Normalized scenario dicts, in the order produced by ``_scenario_files()``.
        An empty YAML document is treated as ``{}``.

    Raises:
        ScenarioValidationError: if a file is missing a required field. The
            message is prefixed with the offending file's name.
        yaml.YAMLError: if a file is not valid YAML (propagated unchanged).
    """
    # The home page re-reads the latest YAML every time, so questions can be
    # edited quickly on the spot during the follow-up interview.
    scenarios = []
    for path in _scenario_files():
        with path.open("r", encoding="utf-8") as file:
            # safe_load only builds plain YAML types, never arbitrary Python objects.
            config = yaml.safe_load(file) or {}
        try:
            scenarios.append(validate_scenario(config))
        except ScenarioValidationError as exc:
            # The bare message only names the field; add which file is broken.
            raise ScenarioValidationError(f"{path.name}: {exc}") from exc
    return scenarios
|
||
|
||
|
||
def get_scenario(scenario_id: str) -> dict:
    """Return the normalized scenario whose ``id`` equals ``scenario_id``.

    Ids are compared as strings, so a YAML ``id: 101`` (parsed as int) matches
    the URL value ``"101"``. If several files share an id, the first one in
    ``list_scenarios()`` order wins.

    Raises:
        ScenarioNotFound: if no scenario has that id.
        Anything ``list_scenarios()`` raises while loading the configs.
    """
    wanted = str(scenario_id)
    for scenario in list_scenarios():
        # YAML turns unquoted numeric ids into int, while ids from URLs are str.
        if str(scenario["id"]) == wanted:
            return scenario
    raise ScenarioNotFound(f"场景不存在: {scenario_id}")
|