Files
DEMO-AGENT/apps/scenarios/services.py

80 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import yaml
from django.conf import settings
# Fields every scenario YAML must define, written as dotted paths and expanded
# into key tuples; validate_scenario() checks each tuple with _get_nested().
REQUIRED_FIELDS = [
    tuple(dotted.split("."))
    for dotted in (
        "id",
        "name",
        "description",
        "agent.role",
        "agent.goal",
        "agent.instructions",
        "rag.enabled",
        "tools",
        "output.type",
        "audit.enabled",
    )
]
class ScenarioNotFound(KeyError):
    """Raised when no scenario with the requested id exists.

    Subclasses ``KeyError`` so callers may treat it as a failed lookup.
    ``KeyError.__str__`` normally ``repr``s a single argument, so a message
    built from ``str(exc)`` would appear wrapped in quotes. This class returns
    the plain message instead; ``args`` and ``repr`` are unchanged.
    """

    def __str__(self) -> str:
        # Only the single-argument case is quoted by KeyError; defer to the
        # inherited formatting for zero or multiple arguments.
        if len(self.args) == 1:
            return str(self.args[0])
        return super().__str__()
class ScenarioValidationError(ValueError):
    """Raised when a scenario configuration fails validation.

    In this module it signals that a required field is missing.
    """
def _get_nested(config: dict, path: tuple[str, ...]):
    """Walk *config* along the keys in *path* and return the value found.

    Every intermediate value must be a ``dict`` containing the next key.
    An empty *path* returns *config* itself.

    Raises:
        ScenarioValidationError: if any step is not a dict or lacks the key;
            the message names the full dotted path.
    """
    current = config
    for segment in path:
        if isinstance(current, dict) and segment in current:
            current = current[segment]
            continue
        raise ScenarioValidationError("缺失必填字段: " + ".".join(path))
    return current
def validate_scenario(config: dict) -> dict:
    """Check that *config* has every required field, then normalize it.

    Returns:
        The result of ``normalize_scenario(config)``.

    Raises:
        ScenarioValidationError: for the first entry of ``REQUIRED_FIELDS``
            that is missing.
    """
    # Only fields that the run loop actually depends on are mandatory;
    # display-only fields such as applicable_questions may be omitted and
    # receive defaults during normalization.
    for required_path in REQUIRED_FIELDS:
        _get_nested(config, required_path)
    normalized = normalize_scenario(config)
    return normalized
def _as_list(value) -> list:
    """Coerce a list-like config value to a new list; a lone string becomes one item."""
    if not value:
        return []
    # A YAML scalar such as ``tools: web_search`` parses to a plain str;
    # list() would split it into single characters, so wrap it instead.
    if isinstance(value, str):
        return [value]
    return list(value)


def normalize_scenario(config: dict) -> dict:
    """Return a copy of *config* with derived and defaulted fields filled in.

    Pages and other modules read these keys directly, which keeps
    conditionals out of the templates. *config* itself is not mutated: the
    top-level mapping, ``rag``, and the list fields are all copied.

    Args:
        config: A raw scenario mapping; ``validate_scenario`` calls this
            after its required-field check.

    Returns:
        A new dict in which:

        * ``applicable_questions`` and ``tools`` are lists (missing or empty
          becomes ``[]``; a single string becomes a one-element list);
        * ``rag`` is a dict whose ``enabled`` key is a bool (a missing or
          null ``rag`` becomes ``{"enabled": False}``);
        * ``tool_count`` is ``len(tools)``;
        * ``is_enabled`` is ``True``.
    """
    normalized = dict(config)
    normalized["applicable_questions"] = _as_list(config.get("applicable_questions"))
    # ``or {}`` also covers an explicit ``rag: null``; dict(None) would raise.
    normalized["rag"] = dict(config.get("rag") or {})
    normalized["rag"]["enabled"] = bool(normalized["rag"].get("enabled"))
    normalized["tools"] = _as_list(config.get("tools"))
    normalized["tool_count"] = len(normalized["tools"])
    # NOTE(review): always True, overriding any is_enabled value in the YAML.
    normalized["is_enabled"] = True
    return normalized
def _scenario_files() -> list[Path]:
    """Return the ``*.yaml`` and ``*.yml`` files in ``settings.SCENARIO_CONFIG_DIR``, sorted.

    Returns an empty list when the configured directory does not exist.
    """
    directory = Path(settings.SCENARIO_CONFIG_DIR)
    if directory.exists():
        found = list(directory.glob("*.yaml")) + list(directory.glob("*.yml"))
        return sorted(found)
    return []
def list_scenarios() -> list[dict]:
    """Load, validate and normalize every scenario YAML file.

    Nothing is cached: the home page re-reads the latest YAML on each call,
    so scenarios can be edited quickly on site during an interview.

    Returns:
        One normalized scenario dict per file, in ``_scenario_files()`` order.

    Raises:
        ScenarioValidationError: if any file lacks a required field. The
            message is prefixed with the offending file name, and the
            original error is chained as ``__cause__``.
    """
    scenarios = []
    for path in _scenario_files():
        with path.open("r", encoding="utf-8") as file:
            # An empty file loads as None; treat it as an empty mapping.
            config = yaml.safe_load(file) or {}
        try:
            scenarios.append(validate_scenario(config))
        except ScenarioValidationError as err:
            # One broken file fails the whole listing, so say which one it is.
            raise ScenarioValidationError(f"{path.name}: {err}") from err
    return scenarios
def get_scenario(scenario_id: str) -> dict:
    """Return the first scenario whose ``id`` equals *scenario_id*.

    Every file is loaded and validated via ``list_scenarios()`` before the
    lookup.

    Raises:
        ScenarioNotFound: if no scenario has that id.
    """
    match = next(
        (item for item in list_scenarios() if item["id"] == scenario_id),
        None,
    )
    if match is None:
        raise ScenarioNotFound(f"场景不存在: {scenario_id}")
    return match