Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 45 additions & 72 deletions wanda/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,14 @@ func (g *depGraph) loadSpec(specPath string, isRoot bool) error {
if err != nil {
return fmt.Errorf("parse %s: %w", specPath, err)
}

if err := spec.ValidateParams(g.lookup); err != nil {
return fmt.Errorf("%s: %w", specPath, err)
}

spec = spec.expandVar(g.lookup)

if err := checkUnexpandedVars(spec, specPath); err != nil {
if err := checkUnexpandedVars(spec, specPath, spec.Params); err != nil {
return err
}

Expand Down Expand Up @@ -252,69 +257,38 @@ func (g *depGraph) validateDeps() error {
}

// checkUnexpandedVars checks if a spec has any unexpanded environment variables
// and returns a helpful error message if so.
func checkUnexpandedVars(spec *Spec, specPath string) error {
var missing []string

if vars := findUnexpandedVars(spec.Name); len(vars) > 0 {
missing = append(missing, vars...)
}
for _, s := range spec.Froms {
if vars := findUnexpandedVars(s); len(vars) > 0 {
missing = append(missing, vars...)
}
}

if len(missing) == 0 {
// and returns a helpful error message. If params are declared for a missing var,
// the valid values are included in the error message.
func checkUnexpandedVars(spec *Spec, specPath string, params map[string][]string) error {
vars := spec.UnexpandedVars()
if len(vars) == 0 {
return nil
}

// Deduplicate
seen := make(map[string]bool)
var unique []string
for _, v := range missing {
for _, v := range vars {
if !seen[v] {
seen[v] = true
unique = append(unique, v)
}
}

if len(unique) == 1 {
return fmt.Errorf("%s: environment variable %s is not set", specPath, unique[0])
// Build error message with param hints where available
var parts []string
for _, v := range unique {
if allowed, ok := params[v]; ok && len(allowed) > 0 {
parts = append(parts, fmt.Sprintf("$%s (valid values: %s)", v, strings.Join(allowed, ", ")))
} else {
parts = append(parts, "$"+v)
}
}
return fmt.Errorf("%s: environment variables not set: %s", specPath, strings.Join(unique, ", "))
}

// findUnexpandedVars finds $VAR patterns in a string that were not expanded.
func findUnexpandedVars(s string) []string {
var vars []string
for i := 0; i < len(s); i++ {
if s[i] == '$' && i+1 < len(s) {
// Skip $$
if s[i+1] == '$' {
i++
continue
}
// Find the variable name
j := i + 1
for j < len(s) {
c := s[j]
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '_' {
j++
continue
}
if c >= '0' && c <= '9' && j > i+1 {
j++
continue
}
break
}
if j > i+1 {
vars = append(vars, s[i:j])
}
i = j - 1
}
if len(parts) == 1 {
return fmt.Errorf("%s: environment variable %s is not set", specPath, parts[0])
}
return vars
return fmt.Errorf("%s: environment variables not set: %s", specPath, strings.Join(parts, "; "))
}

// readWandaSpecs reads the wandaSpecsFile.
Expand Down Expand Up @@ -345,7 +319,8 @@ func readWandaSpecs(wandaSpecsFile string) ([]string, error) {
type specIndex map[string]string

// discoverSpecs scans searchRoot for *.wanda.yaml files and builds a name index.
// Names are expanded using the provided lookup function.
// Names are expanded using declared params first, then the lookup function.
// Specs with params will have all param combinations indexed.
// Returns an error if two specs expand to the same name.
func discoverSpecs(searchRoot string, lookup lookupFunc) (specIndex, error) {
index := make(specIndex)
Expand All @@ -368,30 +343,28 @@ func discoverSpecs(searchRoot string, lookup lookupFunc) (specIndex, error) {
return nil // skip unparseable files
}

// Expand the name using env lookup and index it.
expanded := spec.expandVar(lookup)
name := expanded.Name

// Skip specs with unexpanded variables.
if strings.Contains(name, "$") {
return nil
}

if existing, exists := index[name]; exists && existing != path {
// Record conflict.
m := conflicts[name]
if m == nil {
m = make(map[string]struct{}, 2)
conflicts[name] = m
// Index under all expanded names (using params, then env lookup).
for _, name := range spec.ExpandedNames() {
expanded, ok := tryFullyExpand(name, lookup)
if !ok {
continue
}
m[existing] = struct{}{}
m[path] = struct{}{}
if minConflictName == "" || name < minConflictName {
minConflictName = name
if existing, exists := index[expanded]; exists && existing != path {
// Record conflict.
m := conflicts[expanded]
if m == nil {
m = make(map[string]struct{}, 2)
conflicts[expanded] = m
}
m[existing] = struct{}{}
m[path] = struct{}{}
if minConflictName == "" || expanded < minConflictName {
minConflictName = expanded
}
continue
}
return nil
index[expanded] = path
}
index[name] = path
return nil
})
if err != nil {
Expand Down
195 changes: 195 additions & 0 deletions wanda/deps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,70 @@ func TestDiscoverSpecs_WithVariables(t *testing.T) {
}
}

func TestDiscoverSpecs_WithParams(t *testing.T) {
tmpDir := t.TempDir()

// Spec with params - should be indexed under all expanded names
writeSpec(t, tmpDir, "base.wanda.yaml", strings.Join([]string{
"name: base$PY",
"params:",
" PY:",
" - '3.10'",
" - '3.11'",
" - '3.12'",
"dockerfile: Dockerfile",
}, "\n"))

// No env vars needed - params provide the values
index, err := discoverSpecs(tmpDir, noopLookup)
if err != nil {
t.Fatalf("discoverSpecs: %v", err)
}

// All three expanded names should be indexed
for _, name := range []string{"base3.10", "base3.11", "base3.12"} {
if _, ok := index[name]; !ok {
t.Errorf("index missing %q, got: %v", name, index)
}
}

// All should point to the same spec file
path := index["base3.10"]
if index["base3.11"] != path || index["base3.12"] != path {
t.Errorf("all names should map to same path, got: %v", index)
}
}

func TestDiscoverSpecs_ParamsAndEnvFallback(t *testing.T) {
tmpDir := t.TempDir()

// Spec with partial params - one var has params, one needs env
writeSpec(t, tmpDir, "base.wanda.yaml", strings.Join([]string{
"name: base$PY-$ARCH",
"params:",
" PY:",
" - '3.10'",
"dockerfile: Dockerfile",
}, "\n"))

lookup := func(key string) (string, bool) {
if key == "ARCH" {
return "amd64", true
}
return "", false
}

index, err := discoverSpecs(tmpDir, lookup)
if err != nil {
t.Fatalf("discoverSpecs: %v", err)
}

// Should be indexed as base3.10-amd64
if _, ok := index["base3.10-amd64"]; !ok {
t.Errorf("index missing base3.10-amd64, got: %v", index)
}
}

func TestBuildDepGraph_Discovery(t *testing.T) {
tmpDir := t.TempDir()

Expand Down Expand Up @@ -585,3 +649,134 @@ func TestBuildDepGraph_TransitiveDeps(t *testing.T) {
t.Error("expected c in graph (transitive dep)")
}
}

func TestBuildDepGraph_ParamsValidation(t *testing.T) {
tmpDir := t.TempDir()

writeSpec(t, tmpDir, "spec.wanda.yaml", strings.Join([]string{
"name: myimage$PY_VERSION",
"params:",
" PY_VERSION:",
" - '3.10'",
" - '3.11'",
"dockerfile: Dockerfile",
}, "\n"))

specsFile := filepath.Join(tmpDir, testWandaSpecsFile)

t.Run("valid param value", func(t *testing.T) {
lookup := func(k string) (string, bool) {
if k == "PY_VERSION" {
return "3.10", true
}
return "", false
}
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), lookup, "", specsFile)
if err != nil {
t.Errorf("unexpected error with valid param: %v", err)
}
})

t.Run("invalid param value", func(t *testing.T) {
lookup := func(k string) (string, bool) {
if k == "PY_VERSION" {
return "3.9", true
}
return "", false
}
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), lookup, "", specsFile)
if err == nil {
t.Error("expected error for invalid param value")
}
if !strings.Contains(err.Error(), "3.9") {
t.Errorf("error should mention invalid value '3.9': %v", err)
}
})
}

func TestBuildDepGraph_UnexpandedWithParamsHint(t *testing.T) {
tmpDir := t.TempDir()

// Spec with params but env var not set
writeSpec(t, tmpDir, "spec.wanda.yaml", strings.Join([]string{
"name: myimage$PY_VERSION",
"params:",
" PY_VERSION:",
" - '3.10'",
" - '3.11'",
"dockerfile: Dockerfile",
}, "\n"))

specsFile := filepath.Join(tmpDir, testWandaSpecsFile)
// No env var set - should get helpful error with valid values
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), noopLookup, "", specsFile)
if err == nil {
t.Fatal("expected error for unexpanded env var")
}

// Error should mention valid values from params
errStr := err.Error()
if !strings.Contains(errStr, "PY_VERSION") {
t.Errorf("error should mention PY_VERSION: %v", err)
}
if !strings.Contains(errStr, "valid values") {
t.Errorf("error should mention valid values: %v", err)
}
if !strings.Contains(errStr, "3.10") || !strings.Contains(errStr, "3.11") {
t.Errorf("error should list valid values 3.10, 3.11: %v", err)
}
}

func TestBuildDepGraph_DiscoveryWithParams(t *testing.T) {
tmpDir := t.TempDir()

specsFile := writeWandaSpecs(t, tmpDir, []string{"."})

// Base spec with params - discoverable via params, loadable with env var
baseDir := filepath.Join(tmpDir, "base")
if err := os.MkdirAll(baseDir, 0755); err != nil {
t.Fatal(err)
}
writeSpec(t, baseDir, "base.wanda.yaml", strings.Join([]string{
"name: base$PY",
"params:",
" PY:",
" - '3.10'",
" - '3.11'",
"dockerfile: Dockerfile",
}, "\n"))

// App spec depends on base3.10
appDir := filepath.Join(tmpDir, "app")
if err := os.MkdirAll(appDir, 0755); err != nil {
t.Fatal(err)
}
writeSpec(t, appDir, "app.wanda.yaml", strings.Join([]string{
"name: app",
`froms: ["cr.ray.io/rayproject/base3.10"]`,
"dockerfile: Dockerfile",
}, "\n"))

// Discovery finds base3.10 via params (no env var needed for discovery).
// Loading the spec requires env var to be set for expansion.
lookup := func(key string) (string, bool) {
if key == "PY" {
return "3.10", true
}
return "", false
}

graph, err := buildDepGraph(filepath.Join(appDir, "app.wanda.yaml"), lookup, testPrefix, specsFile)
if err != nil {
t.Fatalf("buildDepGraph: %v", err)
}

// base3.10 was discovered via params and loaded with PY=3.10
if graph.Specs["base3.10"] == nil {
t.Error("expected base3.10 in graph")
}

if len(graph.Order) != 2 {
t.Errorf("Order has %d items, want 2", len(graph.Order))
}
}
Loading