111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
from pathlib import Path
|
||
|
||
import yaml
|
||
from django.conf import settings
|
||
|
||
|
||
# Required (possibly nested) config keys. They are written here as dotted
# paths for readability and split into key tuples; validation looks each
# tuple up one level at a time.
REQUIRED_FIELDS = [
    tuple(dotted.split("."))
    for dotted in (
        "id",
        "name",
        "description",
        "agent.role",
        "agent.goal",
        "agent.instructions",
        "rag.enabled",
        "tools",
        "output.type",
        "audit.enabled",
    )
]
|
||
|
||
|
||
class ScenarioNotFound(KeyError):
    """Raised when no successfully loaded scenario has the requested id."""
|
||
|
||
|
||
class ScenarioValidationError(ValueError):
    """Raised when a scenario config fails validation (e.g. a required field is missing)."""
|
||
|
||
|
||
def _get_nested(config: dict, path: tuple[str, ...]):
|
||
value = config
|
||
for key in path:
|
||
if not isinstance(value, dict) or key not in value:
|
||
raise ScenarioValidationError("缺失必填字段: " + ".".join(path))
|
||
value = value[key]
|
||
return value
|
||
|
||
|
||
def validate_scenario(config: dict) -> dict:
    """
    Check that *config* has every field in ``REQUIRED_FIELDS`` and return its
    normalized form (via ``normalize_scenario``).

    Only fields that the run loop actually depends on are required;
    page-display fields may be missing and receive defaults during
    normalization.

    Raises:
        ScenarioValidationError: the document is not a mapping, or a required
            (possibly nested) field is missing.
    """
    if not isinstance(config, dict):
        # A YAML file whose top level is a list or a bare string would
        # otherwise be reported as the misleading "missing field: id".
        raise ScenarioValidationError(
            f"场景配置顶层必须是映射(键值对),实际为 {type(config).__name__}"
        )
    for field_path in REQUIRED_FIELDS:
        _get_nested(config, field_path)
    return normalize_scenario(config)
|
||
|
||
|
||
def _as_list(value) -> list:
    """Coerce a list-like config field to a list; a lone scalar becomes a one-item list."""
    if not value:
        # None / "" / [] / 0 -> empty list, matching the old `value or []`.
        return []
    if isinstance(value, (str, bytes)):
        # list("abc") would split a single string into characters.
        return [value]
    try:
        return list(value)
    except TypeError:
        # Non-iterable scalar such as `tools: 5`.
        return [value]


def normalize_scenario(config: dict) -> dict:
    """
    Fill in derived fields used by pages and other modules, so templates do
    not have to repeat the same checks.

    Returns a shallow copy of *config* with:
      - ``applicable_questions`` and ``tools`` coerced to lists (missing or
        null -> ``[]``; a single scalar value -> a one-item list);
      - ``rag`` copied into a new dict whose ``enabled`` is coerced to
        ``bool`` (a missing or null ``rag`` becomes ``{"enabled": False}``);
      - ``tool_count`` set to ``len(tools)``;
      - ``is_enabled`` always set to ``True``.

    Nested values other than ``rag`` are shared with *config*, not copied.
    """
    normalized = dict(config)
    normalized["applicable_questions"] = _as_list(config.get("applicable_questions"))
    # `or {}` also covers an explicit `rag: null`, which dict() would reject.
    normalized["rag"] = dict(config.get("rag") or {})
    normalized["rag"]["enabled"] = bool(normalized["rag"].get("enabled"))
    normalized["tools"] = _as_list(config.get("tools"))
    normalized["tool_count"] = len(normalized["tools"])
    # NOTE(review): this overwrites any `is_enabled` given in the YAML;
    # confirm that is intended.
    normalized["is_enabled"] = True
    return normalized
|
||
|
||
|
||
def _scenario_files() -> list[Path]:
    """
    List the scenario files (``*.yaml`` and ``*.yml``) in
    ``settings.SCENARIO_CONFIG_DIR``, sorted by path.

    Returns an empty list when that path is missing or is not a directory.
    Entries that match the pattern but are not regular files (e.g. a
    sub-directory named ``foo.yaml``) are skipped, because opening them
    would raise instead of yielding a scenario.
    """
    config_dir = Path(settings.SCENARIO_CONFIG_DIR)
    if not config_dir.is_dir():
        return []
    candidates = [*config_dir.glob("*.yaml"), *config_dir.glob("*.yml")]
    return sorted(path for path in candidates if path.is_file())
|
||
|
||
|
||
def _read_yaml_file(path: Path) -> dict:
    """
    Parse one UTF-8 YAML file with ``yaml.safe_load``.

    A falsy document (e.g. an empty file) is returned as ``{}``. Other
    top-level types (list, scalar) are returned unchanged; the annotation
    notwithstanding, callers must not assume a dict.
    """
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document if document else {}
|
||
|
||
|
||
def _collect_scenario_load_result() -> tuple[list[dict], list[dict]]:
    """
    Read every scenario file in the config directory in a single pass.

    Returns:
        scenarios: the scenarios that passed validation, in normalized form.
        issues: one ``{"file_name": ..., "message": ...}`` summary per file
            that could not be loaded (invalid YAML, missing required fields,
            non-UTF-8 content, unreadable path), for display on the home page.
    """
    scenarios: list[dict] = []
    issues: list[dict] = []
    for path in _scenario_files():
        try:
            scenario = validate_scenario(_read_yaml_file(path))
        except (
            yaml.YAMLError,
            ScenarioValidationError,
            # A non-UTF-8 file or an unreadable path (permissions, a file
            # removed between listing and reading, ...) is also a per-file
            # config problem: report it rather than failing the page with a 500.
            UnicodeDecodeError,
            OSError,
        ) as exc:
            issues.append({"file_name": path.name, "message": str(exc)})
        else:
            scenarios.append(scenario)
    return scenarios, issues
|
||
|
||
|
||
def list_scenarios() -> list[dict]:
    """
    Return every scenario that loaded and validated successfully.

    The YAML files are re-read on every call, so the home page always shows
    the latest edits (handy for quick on-site changes to the scenarios).
    """
    loaded, _ = _collect_scenario_load_result()
    return loaded
|
||
|
||
|
||
def list_scenario_issues() -> list[dict]:
    """
    Return the per-file config problems found while loading scenarios, so the
    page can show a clear notice instead of failing with a 500.
    """
    _, problems = _collect_scenario_load_result()
    return problems
|
||
|
||
|
||
def get_scenario(scenario_id: str) -> dict:
    """
    Return the loaded scenario whose ``id`` matches *scenario_id*.

    Both sides are compared as strings. YAML parses an unquoted ``id: 1`` as
    an int, which would otherwise never equal the ``str`` id callers pass in.

    Raises:
        ScenarioNotFound: no successfully loaded scenario has this id.
    """
    wanted = str(scenario_id)
    for scenario in list_scenarios():
        if str(scenario["id"]) == wanted:
            return scenario
    raise ScenarioNotFound(f"场景不存在: {scenario_id}")
|