Skip to content

Commit 3a439d3

Browse files
feat(wanda): Add params field for env var validation and discovery
Add `params` field to wanda spec for declaring allowed values of environment variables used in templated fields (name, froms). This enables: - Strict validation: reject builds where env var values don't match declared params - Dependency discovery: specs with params are indexed under all their expanded names (e.g., base$PY with params [3.10, 3.11] is indexed as both base3.10 and base3.11), enabling dependency resolution without requiring all env vars to be set Validation runs before variable expansion in buildDepGraph, so errors reference the original $VARNAME. Topic: wanda-params Relative: wanda-build-deps Signed-off-by: andrew <andrew@anyscale.com>
1 parent 1a7af24 commit 3a439d3

File tree

4 files changed

+628
-72
lines changed

4 files changed

+628
-72
lines changed

wanda/deps.go

Lines changed: 45 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,14 @@ func (g *depGraph) loadSpec(specPath string, isRoot bool) error {
7070
if err != nil {
7171
return fmt.Errorf("parse %s: %w", specPath, err)
7272
}
73+
74+
if err := spec.ValidateParams(g.lookup); err != nil {
75+
return fmt.Errorf("%s: %w", specPath, err)
76+
}
77+
7378
spec = spec.expandVar(g.lookup)
7479

75-
if err := checkUnexpandedVars(spec, specPath); err != nil {
80+
if err := checkUnexpandedVars(spec, specPath, spec.Params); err != nil {
7681
return err
7782
}
7883

@@ -239,69 +244,38 @@ func (g *depGraph) validateDeps() error {
239244
}
240245

241246
// checkUnexpandedVars checks if a spec has any unexpanded environment variables
242-
// and returns a helpful error message if so.
243-
func checkUnexpandedVars(spec *Spec, specPath string) error {
244-
var missing []string
245-
246-
if vars := findUnexpandedVars(spec.Name); len(vars) > 0 {
247-
missing = append(missing, vars...)
248-
}
249-
for _, s := range spec.Froms {
250-
if vars := findUnexpandedVars(s); len(vars) > 0 {
251-
missing = append(missing, vars...)
252-
}
253-
}
254-
255-
if len(missing) == 0 {
247+
// and returns a helpful error message. If params are declared for a missing var,
248+
// the valid values are included in the error message.
249+
func checkUnexpandedVars(spec *Spec, specPath string, params map[string][]string) error {
250+
vars := spec.UnexpandedVars()
251+
if len(vars) == 0 {
256252
return nil
257253
}
258254

255+
// Deduplicate
259256
seen := make(map[string]bool)
260257
var unique []string
261-
for _, v := range missing {
258+
for _, v := range vars {
262259
if !seen[v] {
263260
seen[v] = true
264261
unique = append(unique, v)
265262
}
266263
}
267264

268-
if len(unique) == 1 {
269-
return fmt.Errorf("%s: environment variable %s is not set", specPath, unique[0])
265+
// Build error message with param hints where available
266+
var parts []string
267+
for _, v := range unique {
268+
if allowed, ok := params[v]; ok && len(allowed) > 0 {
269+
parts = append(parts, fmt.Sprintf("$%s (valid values: %s)", v, strings.Join(allowed, ", ")))
270+
} else {
271+
parts = append(parts, "$"+v)
272+
}
270273
}
271-
return fmt.Errorf("%s: environment variables not set: %s", specPath, strings.Join(unique, ", "))
272-
}
273274

274-
// findUnexpandedVars finds $VAR patterns in a string that were not expanded.
275-
func findUnexpandedVars(s string) []string {
276-
var vars []string
277-
for i := 0; i < len(s); i++ {
278-
if s[i] == '$' && i+1 < len(s) {
279-
// Skip $$
280-
if s[i+1] == '$' {
281-
i++
282-
continue
283-
}
284-
// Find the variable name
285-
j := i + 1
286-
for j < len(s) {
287-
c := s[j]
288-
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '_' {
289-
j++
290-
continue
291-
}
292-
if c >= '0' && c <= '9' && j > i+1 {
293-
j++
294-
continue
295-
}
296-
break
297-
}
298-
if j > i+1 {
299-
vars = append(vars, s[i:j])
300-
}
301-
i = j - 1
302-
}
275+
if len(parts) == 1 {
276+
return fmt.Errorf("%s: environment variable %s is not set", specPath, parts[0])
303277
}
304-
return vars
278+
return fmt.Errorf("%s: environment variables not set: %s", specPath, strings.Join(parts, "; "))
305279
}
306280

307281
// findRepoRoot walks up from startDir looking for a .git directory.
@@ -324,7 +298,8 @@ func findRepoRoot(startDir string) string {
324298
type specIndex map[string]string
325299

326300
// discoverSpecs scans searchRoot for *.wanda.yaml files and builds a name index.
327-
// Names are expanded using the provided lookup function.
301+
// Names are expanded using declared params first, then the lookup function.
302+
// Specs with params will have all param combinations indexed.
328303
// Returns an error if two specs expand to the same name.
329304
func discoverSpecs(searchRoot string, lookup lookupFunc) (specIndex, error) {
330305
index := make(specIndex)
@@ -347,30 +322,28 @@ func discoverSpecs(searchRoot string, lookup lookupFunc) (specIndex, error) {
347322
return nil // skip unparseable files
348323
}
349324

350-
// Expand the name using env lookup and index it.
351-
expanded := spec.expandVar(lookup)
352-
name := expanded.Name
353-
354-
// Skip specs with unexpanded variables.
355-
if strings.Contains(name, "$") {
356-
return nil
357-
}
358-
359-
if existing, exists := index[name]; exists && existing != path {
360-
// Record conflict.
361-
m := conflicts[name]
362-
if m == nil {
363-
m = make(map[string]struct{}, 2)
364-
conflicts[name] = m
325+
// Index under all expanded names (using params, then env lookup).
326+
for _, name := range spec.ExpandedNames() {
327+
expanded, ok := tryFullyExpand(name, lookup)
328+
if !ok {
329+
continue
365330
}
366-
m[existing] = struct{}{}
367-
m[path] = struct{}{}
368-
if minConflictName == "" || name < minConflictName {
369-
minConflictName = name
331+
if existing, exists := index[expanded]; exists && existing != path {
332+
// Record conflict.
333+
m := conflicts[expanded]
334+
if m == nil {
335+
m = make(map[string]struct{}, 2)
336+
conflicts[expanded] = m
337+
}
338+
m[existing] = struct{}{}
339+
m[path] = struct{}{}
340+
if minConflictName == "" || expanded < minConflictName {
341+
minConflictName = expanded
342+
}
343+
continue
370344
}
371-
return nil
345+
index[expanded] = path
372346
}
373-
index[name] = path
374347
return nil
375348
})
376349
if err != nil {

wanda/deps_test.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,70 @@ func TestDiscoverSpecs_WithVariables(t *testing.T) {
515515
}
516516
}
517517

518+
func TestDiscoverSpecs_WithParams(t *testing.T) {
519+
tmpDir := t.TempDir()
520+
521+
// Spec with params - should be indexed under all expanded names
522+
writeSpec(t, tmpDir, "base.wanda.yaml", strings.Join([]string{
523+
"name: base$PY",
524+
"params:",
525+
" PY:",
526+
" - '3.10'",
527+
" - '3.11'",
528+
" - '3.12'",
529+
"dockerfile: Dockerfile",
530+
}, "\n"))
531+
532+
// No env vars needed - params provide the values
533+
index, err := discoverSpecs(tmpDir, noopLookup)
534+
if err != nil {
535+
t.Fatalf("discoverSpecs: %v", err)
536+
}
537+
538+
// All three expanded names should be indexed
539+
for _, name := range []string{"base3.10", "base3.11", "base3.12"} {
540+
if _, ok := index[name]; !ok {
541+
t.Errorf("index missing %q, got: %v", name, index)
542+
}
543+
}
544+
545+
// All should point to the same spec file
546+
path := index["base3.10"]
547+
if index["base3.11"] != path || index["base3.12"] != path {
548+
t.Errorf("all names should map to same path, got: %v", index)
549+
}
550+
}
551+
552+
func TestDiscoverSpecs_ParamsAndEnvFallback(t *testing.T) {
553+
tmpDir := t.TempDir()
554+
555+
// Spec with partial params - one var has params, one needs env
556+
writeSpec(t, tmpDir, "base.wanda.yaml", strings.Join([]string{
557+
"name: base$PY-$ARCH",
558+
"params:",
559+
" PY:",
560+
" - '3.10'",
561+
"dockerfile: Dockerfile",
562+
}, "\n"))
563+
564+
lookup := func(key string) (string, bool) {
565+
if key == "ARCH" {
566+
return "amd64", true
567+
}
568+
return "", false
569+
}
570+
571+
index, err := discoverSpecs(tmpDir, lookup)
572+
if err != nil {
573+
t.Fatalf("discoverSpecs: %v", err)
574+
}
575+
576+
// Should be indexed as base3.10-amd64
577+
if _, ok := index["base3.10-amd64"]; !ok {
578+
t.Errorf("index missing base3.10-amd64, got: %v", index)
579+
}
580+
}
581+
518582
func TestBuildDepGraph_Discovery(t *testing.T) {
519583
tmpDir := t.TempDir()
520584

@@ -652,3 +716,133 @@ func TestBuildDepGraph_TransitiveDeps(t *testing.T) {
652716
t.Error("expected c in graph (transitive dep)")
653717
}
654718
}
719+
720+
func TestBuildDepGraph_ParamsValidation(t *testing.T) {
721+
tmpDir := t.TempDir()
722+
723+
writeSpec(t, tmpDir, "spec.wanda.yaml", strings.Join([]string{
724+
"name: myimage$PY_VERSION",
725+
"params:",
726+
" PY_VERSION:",
727+
" - '3.10'",
728+
" - '3.11'",
729+
"dockerfile: Dockerfile",
730+
}, "\n"))
731+
732+
t.Run("valid param value", func(t *testing.T) {
733+
lookup := func(k string) (string, bool) {
734+
if k == "PY_VERSION" {
735+
return "3.10", true
736+
}
737+
return "", false
738+
}
739+
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), lookup, "", nil)
740+
if err != nil {
741+
t.Errorf("unexpected error with valid param: %v", err)
742+
}
743+
})
744+
745+
t.Run("invalid param value", func(t *testing.T) {
746+
lookup := func(k string) (string, bool) {
747+
if k == "PY_VERSION" {
748+
return "3.9", true
749+
}
750+
return "", false
751+
}
752+
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), lookup, "", nil)
753+
if err == nil {
754+
t.Error("expected error for invalid param value")
755+
}
756+
if !strings.Contains(err.Error(), "3.9") {
757+
t.Errorf("error should mention invalid value '3.9': %v", err)
758+
}
759+
})
760+
}
761+
762+
func TestBuildDepGraph_UnexpandedWithParamsHint(t *testing.T) {
763+
tmpDir := t.TempDir()
764+
765+
// Spec with params but env var not set
766+
writeSpec(t, tmpDir, "spec.wanda.yaml", strings.Join([]string{
767+
"name: myimage$PY_VERSION",
768+
"params:",
769+
" PY_VERSION:",
770+
" - '3.10'",
771+
" - '3.11'",
772+
"dockerfile: Dockerfile",
773+
}, "\n"))
774+
775+
// No env var set - should get helpful error with valid values
776+
_, err := buildDepGraph(filepath.Join(tmpDir, "spec.wanda.yaml"), noopLookup, "", nil)
777+
if err == nil {
778+
t.Fatal("expected error for unexpanded env var")
779+
}
780+
781+
// Error should mention valid values from params
782+
errStr := err.Error()
783+
if !strings.Contains(errStr, "PY_VERSION") {
784+
t.Errorf("error should mention PY_VERSION: %v", err)
785+
}
786+
if !strings.Contains(errStr, "valid values") {
787+
t.Errorf("error should mention valid values: %v", err)
788+
}
789+
if !strings.Contains(errStr, "3.10") || !strings.Contains(errStr, "3.11") {
790+
t.Errorf("error should list valid values 3.10, 3.11: %v", err)
791+
}
792+
}
793+
794+
func TestBuildDepGraph_DiscoveryWithParams(t *testing.T) {
795+
tmpDir := t.TempDir()
796+
797+
if err := os.Mkdir(filepath.Join(tmpDir, ".git"), 0755); err != nil {
798+
t.Fatal(err)
799+
}
800+
801+
// Base spec with params - discoverable via params, loadable with env var
802+
baseDir := filepath.Join(tmpDir, "base")
803+
if err := os.MkdirAll(baseDir, 0755); err != nil {
804+
t.Fatal(err)
805+
}
806+
writeSpec(t, baseDir, "base.wanda.yaml", strings.Join([]string{
807+
"name: base$PY",
808+
"params:",
809+
" PY:",
810+
" - '3.10'",
811+
" - '3.11'",
812+
"dockerfile: Dockerfile",
813+
}, "\n"))
814+
815+
// App spec depends on base3.10
816+
appDir := filepath.Join(tmpDir, "app")
817+
if err := os.MkdirAll(appDir, 0755); err != nil {
818+
t.Fatal(err)
819+
}
820+
writeSpec(t, appDir, "app.wanda.yaml", strings.Join([]string{
821+
"name: app",
822+
`froms: ["cr.ray.io/rayproject/base3.10"]`,
823+
"dockerfile: Dockerfile",
824+
}, "\n"))
825+
826+
// Discovery finds base3.10 via params (no env var needed for discovery).
827+
// Loading the spec requires env var to be set for expansion.
828+
lookup := func(key string) (string, bool) {
829+
if key == "PY" {
830+
return "3.10", true
831+
}
832+
return "", false
833+
}
834+
835+
graph, err := buildDepGraph(filepath.Join(appDir, "app.wanda.yaml"), lookup, testPrefix, []string{tmpDir})
836+
if err != nil {
837+
t.Fatalf("buildDepGraph: %v", err)
838+
}
839+
840+
// base3.10 was discovered via params and loaded with PY=3.10
841+
if graph.Specs["base3.10"] == nil {
842+
t.Error("expected base3.10 in graph")
843+
}
844+
845+
if len(graph.Order) != 2 {
846+
t.Errorf("Order has %d items, want 2", len(graph.Order))
847+
}
848+
}

0 commit comments

Comments
 (0)