diff --git a/cli/options.go b/cli/options.go
index 69ea5654..f90001da 100644
--- a/cli/options.go
+++ b/cli/options.go
@@ -517,8 +517,9 @@ func (o *ProjectOptions) prepare(ctx context.Context) (*types.ConfigDetails, err
 	return configDetails, nil
 }
 
-// ProjectFromOptions load a compose project based on command line options
-// Deprecated: use ProjectOptions.LoadProject or ProjectOptions.LoadModel
+// ProjectFromOptions loads a compose project based on command line options.
+//
+// Deprecated: use ProjectOptions.LoadProject or ProjectOptions.LoadModel.
 func ProjectFromOptions(ctx context.Context, options *ProjectOptions) (*types.Project, error) {
 	return options.LoadProject(ctx)
 }
diff --git a/go.mod b/go.mod
index 80abf2c8..86ee059f 100644
--- a/go.mod
+++ b/go.mod
@@ -6,7 +6,6 @@ require (
 	github.com/distribution/reference v0.5.0
 	github.com/docker/go-connections v0.4.0
 	github.com/docker/go-units v0.5.0
-	github.com/go-viper/mapstructure/v2 v2.4.0
 	github.com/google/go-cmp v0.5.9
 	github.com/mattn/go-shellwords v1.0.12
 	github.com/opencontainers/go-digest v1.0.0
diff --git a/go.sum b/go.sum
index b921dc5f..40a2dddc 100644
--- a/go.sum
+++ b/go.sum
@@ -9,8 +9,6 @@ github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKoh
 github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
 github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
-github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
-github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
diff --git
a/interpolation/node.go b/interpolation/node.go new file mode 100644 index 00000000..0e887226 --- /dev/null +++ b/interpolation/node.go @@ -0,0 +1,112 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package interpolation + +import ( + "fmt" + "os" + "strings" + + "go.yaml.in/yaml/v4" + + "github.com/compose-spec/compose-go/v2/template" + "github.com/compose-spec/compose-go/v2/tree" +) + +// InterpolateNode replaces variables in yaml.Node scalar values +func InterpolateNode(node *yaml.Node, opts Options) error { + if opts.LookupValue == nil { + opts.LookupValue = os.LookupEnv + } + if opts.TypeCastMapping == nil { + opts.TypeCastMapping = make(map[tree.Path]Cast) + } + if opts.Substitute == nil { + opts.Substitute = template.Substitute + } + return recursiveInterpolateNode(node, tree.NewPath(), opts) +} + +func recursiveInterpolateNode(node *yaml.Node, path tree.Path, opts Options) error { + switch node.Kind { + case yaml.DocumentNode: + if len(node.Content) > 0 { + return recursiveInterpolateNode(node.Content[0], path, opts) + } + return nil + + case yaml.MappingNode: + for i := 0; i+1 < len(node.Content); i += 2 { + key := node.Content[i] + value := node.Content[i+1] + if err := recursiveInterpolateNode(value, path.Next(key.Value), opts); err != nil { + return err + } + } + return nil + + case yaml.SequenceNode: + for _, item := range node.Content { + if err := recursiveInterpolateNode(item, 
path.Next(tree.PathMatchList), opts); err != nil { + return err + } + } + return nil + + case yaml.ScalarNode: + if node.Tag != "!!str" && node.Tag != "" && !strings.Contains(node.Value, "$") { + return nil + } + newValue, err := opts.Substitute(node.Value, template.Mapping(opts.LookupValue)) + if err != nil { + return newPathError(path, err) + } + caster, ok := opts.getCasterForPath(path) + if !ok { + if newValue != node.Value { + node.Value = newValue + } + return nil + } + casted, err := caster(newValue) + if err != nil { + return newPathError(path, fmt.Errorf("failed to cast to expected type: %w", err)) + } + switch casted.(type) { + case bool: + node.Tag = "!!bool" + node.Value = fmt.Sprint(casted) + case int, int64: + node.Tag = "!!int" + node.Value = fmt.Sprint(casted) + case float64: + node.Tag = "!!float" + node.Value = fmt.Sprint(casted) + case nil: + node.Tag = "!!null" + node.Value = "null" + case string: + node.Value = fmt.Sprint(casted) + default: + node.Value = fmt.Sprint(casted) + } + return nil + + default: + return nil + } +} diff --git a/interpolation/node_test.go b/interpolation/node_test.go new file mode 100644 index 00000000..7f242b7a --- /dev/null +++ b/interpolation/node_test.go @@ -0,0 +1,209 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/

package interpolation

import (
	"encoding/json"
	"strconv"
	"testing"

	"github.com/compose-spec/compose-go/v2/tree"
	"go.yaml.in/yaml/v4"
	"gotest.tools/v3/assert"
)

// TestInterpolateNode_Simple checks basic ${VAR} substitution in a scalar value.
func TestInterpolateNode_Simple(t *testing.T) {
	input := `
services:
  web:
    image: ${IMAGE}
`
	lookup := func(key string) (string, bool) {
		if key == "IMAGE" {
			return "nginx", true
		}
		return "", false
	}

	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))
	err := InterpolateNode(&node, Options{LookupValue: lookup})
	assert.NilError(t, err)

	var result map[string]interface{}
	assert.NilError(t, node.Decode(&result))

	services := result["services"].(map[string]interface{})
	web := services["web"].(map[string]interface{})
	assert.Equal(t, "nginx", web["image"])
}

// TestInterpolateNode_Default checks the ${VAR:-default} fallback when the
// variable is unset.
func TestInterpolateNode_Default(t *testing.T) {
	input := `
services:
  web:
    image: ${IMAGE:-default}
`
	lookup := func(_ string) (string, bool) {
		return "", false
	}

	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))
	err := InterpolateNode(&node, Options{LookupValue: lookup})
	assert.NilError(t, err)

	var result map[string]interface{}
	assert.NilError(t, node.Decode(&result))

	services := result["services"].(map[string]interface{})
	web := services["web"].(map[string]interface{})
	assert.Equal(t, "default", web["image"])
}

// TestInterpolateNode_NoSubstitution checks that a document with no
// variable references is left byte-for-byte unchanged by interpolation.
func TestInterpolateNode_NoSubstitution(t *testing.T) {
	input := `
services:
  web:
    image: nginx
    ports:
      - "8080"
`
	lookup := func(_ string) (string, bool) {
		return "", false
	}

	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))

	// Take a snapshot before interpolation
	var before map[string]interface{}
	assert.NilError(t, node.Decode(&before))

	err := InterpolateNode(&node, Options{LookupValue: lookup})
	assert.NilError(t, err)

	var after map[string]interface{}
	assert.NilError(t, node.Decode(&after))

	// Compare decoded trees via their JSON serialization.
	beforeJSON, _ := json.Marshal(before)
	afterJSON, _ := json.Marshal(after)
	assert.Equal(t, string(beforeJSON), string(afterJSON))
}

// TestInterpolateNode_TypeCast checks that a Cast registered for a tree path
// converts the substituted string into the expected Go type (int here).
func TestInterpolateNode_TypeCast(t *testing.T) {
	input := `
services:
  web:
    ports:
      - ${PORT}
`
	lookup := func(key string) (string, bool) {
		if key == "PORT" {
			return "8080", true
		}
		return "", false
	}

	toInt := func(value string) (interface{}, error) {
		return strconv.Atoi(value)
	}

	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))
	err := InterpolateNode(&node, Options{
		LookupValue: lookup,
		TypeCastMapping: map[tree.Path]Cast{
			tree.NewPath("services", tree.PathMatchAll, "ports", tree.PathMatchList): toInt,
		},
	})
	assert.NilError(t, err)

	var result map[string]interface{}
	assert.NilError(t, node.Decode(&result))

	services := result["services"].(map[string]interface{})
	web := services["web"].(map[string]interface{})
	ports := web["ports"].([]interface{})
	assert.Equal(t, 8080, ports[0])
}

// TestInterpolateNode_Parity checks that node-based interpolation produces
// the same result as the existing map-based Interpolate pipeline.
func TestInterpolateNode_Parity(t *testing.T) {
	input := `
services:
  web:
    image: ${IMAGE}
    environment:
      FOO: ${FOO_VAL}
      BAR: ${BAR_VAL:-default_bar}
    labels:
      version: ${VERSION}
`
	env := map[string]string{
		"IMAGE":   "nginx",
		"FOO_VAL": "hello",
		"VERSION": "1.0",
	}
	testInterpolateParity(t, input, env)
}

// testInterpolateParity runs both the map-based and the node-based
// interpolation over the same input and asserts identical results.
func testInterpolateParity(t *testing.T, input string, env map[string]string) {
	t.Helper()
	lookup := func(key string) (string, bool) {
		v, ok := env[key]
		return v, ok
	}
	opts := Options{
		LookupValue: lookup,
	}

	// Map-based
	var mapData map[string]interface{}
	assert.NilError(t, yaml.Unmarshal([]byte(input), &mapData))
	mapResult, err := Interpolate(mapData, opts)
	assert.NilError(t, err)

	// Node-based
	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))
	err = InterpolateNode(&node, opts)
	assert.NilError(t, err)

	var nodeMap map[string]interface{}
	assert.NilError(t, node.Decode(&nodeMap))

	// Compare via JSON
	mapJSON, _ := json.Marshal(mapResult)
	nodeJSON, _ := json.Marshal(nodeMap)
	assert.Equal(t, string(mapJSON), string(nodeJSON))
}

// TestInterpolateNode_Error checks that a required-but-missing variable
// (${VAR:?}) surfaces an error.
func TestInterpolateNode_Error(t *testing.T) {
	input := `
services:
  web:
    image: ${IMAGE:?}
`
	lookup := func(_ string) (string, bool) {
		return "", false
	}

	var node yaml.Node
	assert.NilError(t, yaml.Unmarshal([]byte(input), &node))
	err := InterpolateNode(&node, Options{LookupValue: lookup})
	assert.Assert(t, err != nil, "expected an error for missing required variable")
}
diff --git a/loader/compose_model.go b/loader/compose_model.go
new file mode 100644
index 00000000..1e64b4dc
--- /dev/null
+++ b/loader/compose_model.go
@@ -0,0 +1,1345 @@
/*
   Copyright 2020 The Compose Specification Authors.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
+*/ + +package loader + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "reflect" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/compose-spec/compose-go/v2/dotenv" + "github.com/compose-spec/compose-go/v2/format" + interp "github.com/compose-spec/compose-go/v2/interpolation" + "github.com/compose-spec/compose-go/v2/override" + "github.com/compose-spec/compose-go/v2/schema" + "github.com/compose-spec/compose-go/v2/template" + "github.com/compose-spec/compose-go/v2/tree" + "github.com/compose-spec/compose-go/v2/types" + "github.com/compose-spec/compose-go/v2/validation" + "github.com/sirupsen/logrus" + "go.yaml.in/yaml/v4" +) + +// NodeContext holds the loading context for a set of yaml nodes. +// Each node parsed from a compose file is associated with a NodeContext +// that captures the environment variables, working directory, and source file +// active at parse time. This allows deferred interpolation and path resolution +// using the correct context for each node, even after nodes from different +// files have been merged. +type NodeContext struct { + Source string + WorkingDir string + Env types.Mapping +} + +// Layer represents a single compose file parsed as a yaml.Node tree +// together with its loading context. +type Layer struct { + Node *yaml.Node + Context *NodeContext +} + +// ComposeModel holds raw yaml layers and resolves them lazily into a types.Project. +// Yaml nodes are kept in their raw (uninterpolated) form as long as possible. +// Interpolation, type casting, and path resolution are deferred until Resolve() +// is called, at which point each node is processed using its own NodeContext. +type ComposeModel struct { + layers []*Layer + configDetails types.ConfigDetails + opts *Options + // nodeContexts maps each yaml.Node to the loading context it was parsed under. 
+ // This survives merging: original nodes keep their pointer identity, so + // after MergeNodes the map still resolves each leaf to its source context. + nodeContexts map[*yaml.Node]*NodeContext + // serviceWorkDirs maps service names to the working directory of the file + // they were defined in. Used for includes where paths are relative to + // the included file's directory. + serviceWorkDirs map[string]string + // loadedFiles tracks files that have been loaded, for cycle detection. + loadedFiles []string +} + +func init() { + // Wire up the volume-parsing hook so that types.ServiceVolumeConfig.UnmarshalYAML + // can parse short syntax without importing the format package directly. + types.ParseVolumeFunc = format.ParseVolume +} + +// LoadLazyModel parses compose files into raw yaml.Node layers without +// performing interpolation or normalization. The resulting ComposeModel +// can later be materialized into a types.Project by calling Resolve(). +func LoadLazyModel(_ context.Context, configDetails types.ConfigDetails, options ...func(*Options)) (*ComposeModel, error) { + opts := ToOptions(&configDetails, options) + + if len(configDetails.ConfigFiles) < 1 { + return nil, errors.New("no compose file specified") + } + + if err := projectName(&configDetails, opts); err != nil { + return nil, err + } + + model := &ComposeModel{ + configDetails: configDetails, + opts: opts, + nodeContexts: make(map[*yaml.Node]*NodeContext), + serviceWorkDirs: make(map[string]string), + } + + for _, file := range configDetails.ConfigFiles { + node, err := loadYamlFileNode(file) + if err != nil { + return nil, err + } + if node == nil { + continue + } + nodeCtx := &NodeContext{ + Source: file.Filename, + WorkingDir: configDetails.WorkingDir, + Env: configDetails.Environment, + } + layer := &Layer{Node: node, Context: nodeCtx} + model.layers = append(model.layers, layer) + model.registerNodes(node, nodeCtx) + model.loadedFiles = append(model.loadedFiles, file.Filename) + } + + if 
len(model.layers) == 0 { + return nil, errors.New("empty compose file") + } + + return model, nil +} + +// registerNodes associates every node in a tree with the given context. +func (m *ComposeModel) registerNodes(node *yaml.Node, ctx *NodeContext) { + m.registerNodesVisited(node, ctx, make(map[*yaml.Node]bool)) +} + +func (m *ComposeModel) registerNodesVisited(node *yaml.Node, ctx *NodeContext, visited map[*yaml.Node]bool) { + if node == nil || visited[node] { + return + } + visited[node] = true + m.nodeContexts[node] = ctx + for _, child := range node.Content { + m.registerNodesVisited(child, ctx, visited) + } + if node.Alias != nil { + m.registerNodesVisited(node.Alias, ctx, visited) + } +} + +// loadYamlFileNode parses a ConfigFile into a *yaml.Node tree, +// processing !reset and !override tags via ResetProcessor. +func loadYamlFileNode(file types.ConfigFile) (*yaml.Node, error) { + content := file.Content + if content == nil && file.Config == nil { + var err error + content, err = os.ReadFile(file.Filename) + if err != nil { + return nil, err + } + } + + if file.Config != nil { + // Config is already a map[string]any — marshal back to yaml then parse as node. + // This path is rare (used in tests) and maintains compatibility. 
+ b, err := yaml.Marshal(file.Config) + if err != nil { + return nil, err + } + content = b + } + + r := bytes.NewReader(content) + decoder := yaml.NewDecoder(r) + + var result *yaml.Node + for { + var doc yaml.Node + err := decoder.Decode(&doc) + if err != nil && errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", file.Filename, err) + } + // Keep !reset and !override tags in the tree — they are handled during merge + result = &doc + } + if result != nil { + if err := checkDuplicateKeys(result); err != nil { + return nil, err + } + } + return result, nil +} + +// checkDuplicateKeys recursively walks a yaml.Node tree and returns an error +// if any MappingNode contains duplicate keys. +func checkDuplicateKeys(node *yaml.Node) error { + if node == nil { + return nil + } + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + if err := checkDuplicateKeys(child); err != nil { + return err + } + } + case yaml.MappingNode: + keys := map[string]int{} + for i := 0; i+1 < len(node.Content); i += 2 { + keyNode := node.Content[i] + key := keyNode.Value + if line, seen := keys[key]; seen { + return fmt.Errorf("line %d: mapping key %#v already defined at line %d", keyNode.Line, key, line) + } + keys[key] = keyNode.Line + if err := checkDuplicateKeys(node.Content[i+1]); err != nil { + return err + } + } + case yaml.SequenceNode: + for _, child := range node.Content { + if err := checkDuplicateKeys(child); err != nil { + return err + } + } + } + return nil +} + +// Resolve materializes the lazy model into a types.Project. +// This is the point where all deferred processing happens: +// extends resolution, includes loading, merging, interpolation, +// type casting, and decoding into Go structs. +func (m *ComposeModel) Resolve() (*types.Project, error) { //nolint:gocyclo + // 1. 
Process extends on raw nodes (before interpolation, same as existing pipeline) + if !m.opts.SkipExtends { + for _, layer := range m.layers { + if err := m.applyExtendsNode(layer); err != nil { + return nil, err + } + } + } + + // 2. Process includes — loads referenced files as additional raw layers + if !m.opts.SkipInclude { + if err := m.applyIncludeNodes(); err != nil { + return nil, err + } + } + + // 3. Merge all raw layers into a single tree + var merged *yaml.Node + mergedSource := "" + for _, layer := range m.layers { + node := resolveDocumentNode(layer.Node) + if merged == nil { + merged = node + mergedSource = layer.Context.Source + continue + } + var err error + merged, err = override.MergeNodes(merged, node, tree.NewPath()) + if err != nil { + return nil, fmt.Errorf("merging %s: %w", layer.Context.Source, err) + } + mergedSource = layer.Context.Source + } + + if merged == nil { + return nil, errors.New("empty compose model") + } + + // 3a. Check top-level node is a mapping + if merged.Kind != yaml.MappingNode { + return nil, fmt.Errorf("top-level object must be a mapping") + } + + // 3b. Check for non-string keys + if err := checkNonStringKeysNode(merged, ""); err != nil { + return nil, err + } + + // 3c. Validate project name + if !m.opts.SkipValidation && m.opts.projectName == "" { + return nil, errors.New("project name must not be empty") + } + + // 3d. Strip remaining !reset tags and enforce sequence unicity + override.StripResetTags(merged) + override.EnforceUnicityNode(merged, tree.NewPath()) + + // 4. Interpolate and type-cast the merged tree. + // Each node is interpolated using the environment from its original + // source file (via nodeContexts). Nodes created during merge inherit + // context from the nearest registered ancestor. + if !m.opts.SkipInterpolation { + defaultCtx := m.layers[0].Context + if err := m.interpolateTree(merged, tree.NewPath(), defaultCtx); err != nil { + return nil, err + } + } + + // 4b. 
Transform deprecated external.name syntax to external: true + name + if err := transformExternalNodes(merged); err != nil { + return nil, err + } + + // 5. Schema validation (on map[string]any representation, for backward compat) + if !m.opts.SkipValidation { + var dict map[string]any + if err := merged.Decode(&dict); err == nil { + if err := schema.Validate(dict); err != nil { + return nil, fmt.Errorf("validating %s: %w", mergedSource, err) + } + if _, ok := dict["version"]; ok { + m.opts.warnObsoleteVersion(mergedSource) + } + } + if err := validation.Validate(dict); err != nil { + return nil, err + } + } + + // 6. Decode into types.Project + project := &types.Project{ + Name: m.opts.projectName, + WorkingDir: m.configDetails.WorkingDir, + Environment: m.configDetails.Environment, + } + + override.DeleteKey(merged, "name") + override.DeleteKey(merged, "include") + override.DeleteKey(merged, "version") + + if err := merged.Decode(project); err != nil { + err = m.enrichError(err, mergedSource) + return nil, fmt.Errorf("decoding compose model: %w", err) + } + + // 6a. Process known extensions — convert raw extension values to registered Go types + if len(m.opts.KnownExtensions) > 0 { + if err := processProjectExtensions(project, m.opts.KnownExtensions); err != nil { + return nil, err + } + } + + // 6b. Set service names from map keys (always, even when SkipNormalization) + for name, svc := range project.Services { + svc.Name = name + project.Services[name] = svc + } + + // 6c. Validate environment variable whitespace (with path context) + for name, svc := range project.Services { + for k := range svc.Environment { + if k != "" && k[len(k)-1] == ' ' { + return nil, fmt.Errorf("'services[%s].environment' environment variable %s is declared with a trailing space", name, k) + } + } + } + + // 7. Normalization (default network, resource names, build defaults, etc.) + if !m.opts.SkipNormalization { + normalizeProject(project) + } + + // 7a. 
Always resolve environment references in secrets/configs + resolveSecretConfigEnvironment(project) + + // 8. Path resolution + if m.opts.ResolvePaths { + if err := resolveProjectPaths(project, m.opts); err != nil { + return nil, err + } + } + + // 9. Windows path conversion + if m.opts.ConvertWindowsPaths { + for name, svc := range project.Services { + for j, vol := range svc.Volumes { + svc.Volumes[j] = convertVolumePath(vol) + } + project.Services[name] = svc + } + } + + // 10. Apply profiles filter + var err error + if project, err = project.WithProfiles(m.opts.Profiles); err != nil { + return nil, err + } + + // 11. Consistency check + if !m.opts.SkipConsistencyCheck { + if err := checkConsistency(project); err != nil { + return nil, err + } + } + + // 12. Resolve environment + if !m.opts.SkipResolveEnvironment { + project, err = project.WithServicesEnvironmentResolved(m.opts.discardEnvFiles) + if err != nil { + return nil, err + } + } + + project, err = project.WithServicesLabelsResolved(m.opts.discardEnvFiles) + if err != nil { + return nil, err + } + + return project, nil +} + +// interpolateTree walks the yaml.Node tree and interpolates scalar values +// using per-node context from nodeContexts. Nodes not found in the map +// inherit context from the nearest registered ancestor during the walk. +func (m *ComposeModel) interpolateTree(node *yaml.Node, p tree.Path, inherited *NodeContext) error { + if node == nil { + return nil + } + + // Use this node's own context if registered, otherwise inherit + ctx := inherited + if c, ok := m.nodeContexts[node]; ok { + ctx = c + } + + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + if err := m.interpolateTree(child, p, ctx); err != nil { + return err + } + } + + case yaml.MappingNode: + for i := 0; i+1 < len(node.Content); i += 2 { + keyNode := node.Content[i] + valNode := node.Content[i+1] + + // Determine context for this key-value pair. 
+ // Prefer the value's own context (it came from a specific layer), + // then the key's context, then the inherited context. + pairCtx := ctx + if c, ok := m.nodeContexts[valNode]; ok { + pairCtx = c + } else if c, ok := m.nodeContexts[keyNode]; ok { + pairCtx = c + } + + next := p.Next(keyNode.Value) + if err := m.interpolateTree(valNode, next, pairCtx); err != nil { + return err + } + } + + case yaml.SequenceNode: + for _, child := range node.Content { + if err := m.interpolateTree(child, p.Next(tree.PathMatchList), ctx); err != nil { + return err + } + } + + case yaml.ScalarNode: + return m.interpolateScalar(node, p, ctx) + } + + return nil +} + +// interpolateScalar substitutes variables and applies type casting on a single scalar node. +func (m *ComposeModel) interpolateScalar(node *yaml.Node, p tree.Path, ctx *NodeContext) error { + if node.Tag != "!!str" && node.Tag != "" && !strings.Contains(node.Value, "$") { + return nil + } + + lookup := func(key string) (string, bool) { + if ctx == nil { + return "", false + } + v, ok := ctx.Env[key] + return v, ok + } + + newValue, err := template.Substitute(node.Value, template.Mapping(lookup)) + if err != nil { + return err + } + + // Type casting based on tree path + var caster interp.Cast + for pattern, c := range interpolateTypeCastMapping { + if p.Matches(pattern) { + caster = c + break + } + } + + if caster == nil { + node.Value = newValue + return nil + } + + casted, err := caster(newValue) + if err != nil { + return types.NodeErrorf(node, "failed to cast to expected type: %v", err) + } + + switch casted.(type) { + case bool: + node.Tag = "!!bool" + case int, int64: + node.Tag = "!!int" + case float64: + node.Tag = "!!float" + case nil: + node.Tag = "!!null" + node.Value = "null" + return nil + } + node.Value = fmt.Sprint(casted) + return nil +} + +// enrichError adds source file information to error messages. 
+// It tries to find the correct source file by matching line numbers +// from the error against nodes tracked in nodeContexts. +func (m *ComposeModel) enrichError(err error, fallbackSource string) error { + if err == nil { + return nil + } + source := m.findSourceForError(err) + if source == "" { + source = fallbackSource + if source == "" && len(m.layers) > 0 { + source = m.layers[0].Context.Source + } + } + return types.WithSource(err, source) +} + +// findSourceForError extracts line/column numbers from yaml error messages and +// looks up which source file contains a node at that position. +func (m *ComposeModel) findSourceForError(err error) string { + msg := err.Error() + // yaml/v4 errors look like: "line N, column M: ..." + // Try matching with both line and column first for precision + re := regexp.MustCompile(`line (\d+), column (\d+)`) + matches := re.FindStringSubmatch(msg) + if len(matches) >= 3 { + lineNum, _ := strconv.Atoi(matches[1]) + colNum, _ := strconv.Atoi(matches[2]) + for node, ctx := range m.nodeContexts { + if node.Line == lineNum && node.Column == colNum { + return ctx.Source + } + } + } + // Fallback: match on line only + reLineOnly := regexp.MustCompile(`line (\d+)`) + lineMatches := reLineOnly.FindStringSubmatch(msg) + if len(lineMatches) < 2 { + return "" + } + lineNum, convErr := strconv.Atoi(lineMatches[1]) + if convErr != nil { + return "" + } + for node, ctx := range m.nodeContexts { + if node.Line == lineNum { + return ctx.Source + } + } + return "" +} + +// applyExtendsNode processes "extends" directives within a single layer's node tree. 
func (m *ComposeModel) applyExtendsNode(layer *Layer) error {
	node := resolveDocumentNode(layer.Node)
	_, services := override.FindKey(node, "services")
	if services == nil || services.Kind != yaml.MappingNode {
		// Nothing to extend: no services section (or it is malformed).
		return nil
	}

	// resolved memoizes services already processed within this layer so each
	// service's extends chain is resolved at most once.
	resolved := map[string]bool{}
	for i := 0; i+1 < len(services.Content); i += 2 {
		name := services.Content[i].Value
		if err := m.resolveServiceExtends(layer, services, name, resolved, nil); err != nil {
			return err
		}
	}
	return nil
}

// resolveServiceExtends resolves the "extends" directive of one service in
// place: it locates the base service (same file or an external file, possibly
// fetched via a remote resource loader), recursively resolves the base's own
// extends, merges the current service on top of a deep clone of the base, and
// replaces the service node with the merged result. "chain" carries
// file:service identifiers for cycle detection across recursion.
func (m *ComposeModel) resolveServiceExtends(layer *Layer, services *yaml.Node, name string, resolved map[string]bool, chain []string) error {
	if resolved[name] {
		return nil
	}

	// cycle detection using file:service identifiers
	chainID := layer.Context.Source + ":" + name
	if slices.Contains(chain, chainID) {
		return fmt.Errorf("circular reference with extends")
	}
	chain = append(chain, chainID)

	_, svcNode := override.FindKey(services, name)
	if svcNode == nil {
		return nil
	}

	_, extendsNode := override.FindKey(svcNode, "extends")
	if extendsNode == nil {
		// No extends directive: mark done and stop.
		resolved[name] = true
		return nil
	}

	var refService string
	var refFile string

	// extends is either the short form (a bare service name) or the long
	// form (a mapping with "service" and optional "file").
	switch extendsNode.Kind {
	case yaml.ScalarNode:
		refService = extendsNode.Value
		m.opts.ProcessEvent("extends", map[string]any{"service": refService})
	case yaml.MappingNode:
		_, sn := override.FindKey(extendsNode, "service")
		if sn == nil {
			return fmt.Errorf("extends.%s.service is required", name)
		}
		refService = sn.Value
		_, fn := override.FindKey(extendsNode, "file")
		if fn != nil {
			refFile = fn.Value
		}
		metadata := map[string]any{"service": refService}
		if refFile != "" {
			metadata["file"] = refFile
		}
		m.opts.ProcessEvent("extends", metadata)
	default:
		return types.NodeErrorf(extendsNode, "extends must be a string or mapping")
	}

	var baseService *yaml.Node

	if refFile != "" {
		// Load from external file, checking remote resource loaders first
		filePath := refFile
		for _, loader := range m.opts.RemoteResourceLoaders() {
			if loader.Accept(refFile) {
				// NOTE(review): "resolved" here shadows the resolved
				// map parameter; it holds the loader's local path.
				resolved, loadErr := loader.Load(context.TODO(), refFile)
				if loadErr != nil {
					return loadErr
				}
				filePath = resolved
				break
			}
		}
		// Not remote: interpret the path relative to the including file's dir.
		if filePath == refFile && !filepath.IsAbs(filePath) {
			filePath = filepath.Join(layer.Context.WorkingDir, filePath)
		}
		extNode, err := loadYamlFileNode(types.ConfigFile{Filename: filePath})
		if err != nil {
			return types.WrapNodeError(extendsNode, fmt.Errorf("loading extends file %s: %w", refFile, err))
		}
		if extNode == nil {
			return types.NodeErrorf(extendsNode, "extends file %s is empty", refFile)
		}

		extDir := filepath.Dir(filePath)

		// Register nodes from the external file with their own context
		extCtx := &NodeContext{
			Source:     filePath,
			WorkingDir: extDir,
			Env:        layer.Context.Env,
		}
		m.registerNodes(extNode, extCtx)

		extRoot := resolveDocumentNode(extNode)
		_, extServices := override.FindKey(extRoot, "services")
		if extServices == nil {
			return types.NodeErrorf(extendsNode, "extends file %s has no services", refFile)
		}
		_, baseService = override.FindKey(extServices, refService)
		if baseService == nil {
			return types.NodeErrorf(extendsNode, "service %q not found in %s", refService, refFile)
		}

		// Recursively resolve extends in the base service's file
		extResolved := map[string]bool{}
		extLayer := &Layer{Node: extNode, Context: extCtx}
		if err := m.resolveServiceExtends(extLayer, extServices, refService, extResolved, chain); err != nil {
			return err
		}
		// Re-fetch after resolution
		_, baseService = override.FindKey(extServices, refService)

		// Resolve relative paths in the base service using the extends file's
		// relative directory, so paths are expressed relative to the main project dir.
		relWorkDir, err := filepath.Rel(layer.Context.WorkingDir, extDir)
		if err != nil {
			// No relative form exists (e.g. different volume); keep absolute.
			relWorkDir = extDir
		}
		resolveServiceNodePaths(baseService, relWorkDir)
	} else {
		// Same file
		if err := m.resolveServiceExtends(layer, services, refService, resolved, chain); err != nil {
			return err
		}
		_, baseService = override.FindKey(services, refService)
		if baseService == nil {
			return types.NodeErrorf(extendsNode, "service %q not found", refService)
		}
	}

	// Deep clone base before merge to avoid mutating it
	baseClone := m.deepCloneNode(baseService)

	// Merge: base extended by current service
	merged, err := override.ExtendServiceNode(baseClone, svcNode)
	if err != nil {
		return types.WrapNodeError(extendsNode, fmt.Errorf("extending service %s: %w", name, err))
	}

	// Remove "extends" from merged result
	override.DeleteKey(merged, "extends")

	// Replace service node in-place
	override.SetKey(services, name, merged)
	resolved[name] = true
	return nil
}

// deepCloneNode creates a deep copy of a yaml.Node tree,
// propagating node contexts from originals to clones.
func (m *ComposeModel) deepCloneNode(node *yaml.Node) *yaml.Node {
	if node == nil {
		return nil
	}
	// Copy every scalar field; Content and Alias are rebuilt below so the
	// clone shares no mutable node with the original tree.
	clone := &yaml.Node{
		Kind:        node.Kind,
		Style:       node.Style,
		Tag:         node.Tag,
		Value:       node.Value,
		Anchor:      node.Anchor,
		HeadComment: node.HeadComment,
		LineComment: node.LineComment,
		FootComment: node.FootComment,
		Line:        node.Line,
		Column:      node.Column,
	}
	// Propagate context from original to clone so source/working-dir/env
	// lookups registered in m.nodeContexts keep resolving for cloned nodes.
	if ctx, ok := m.nodeContexts[node]; ok {
		m.nodeContexts[clone] = ctx
	}
	// NOTE(review): the alias target is cloned as an independent subtree, so
	// clone.Alias no longer points at the cloned anchor node within the new
	// tree — confirm anchors/aliases are never re-resolved after cloning.
	if node.Alias != nil {
		clone.Alias = m.deepCloneNode(node.Alias)
	}
	if len(node.Content) > 0 {
		clone.Content = make([]*yaml.Node, len(node.Content))
		for i, c := range node.Content {
			clone.Content[i] = m.deepCloneNode(c)
		}
	}
	return clone
}

// applyIncludeNodes processes "include" directives from the layers,
// loading referenced files as additional raw layers with their own context.
// Includes are inserted BEFORE their parent layer so the parent takes
// precedence during merge (matching the old pipeline's behavior).
func (m *ComposeModel) applyIncludeNodes() error {
	var newLayers []*Layer
	for _, layer := range m.layers {
		node := resolveDocumentNode(layer.Node)
		_, includeNode := override.FindKey(node, "include")
		if includeNode == nil {
			// No include directive: keep the layer as-is.
			newLayers = append(newLayers, layer)
			continue
		}
		if includeNode.Kind != yaml.SequenceNode {
			return types.NodeErrorf(includeNode, "include must be a sequence")
		}

		for _, entry := range includeNode.Content {
			includeLayers, err := m.loadIncludeEntry(layer, entry)
			if err != nil {
				return err
			}
			// Prepend includes before parent so parent overrides them
			newLayers = append(newLayers, includeLayers...)
		}
		newLayers = append(newLayers, layer)
	}
	m.layers = newLayers
	return nil
}

// loadIncludeEntry loads one entry of an "include" sequence (either a bare
// path string or a mapping with path/project_directory/env_file) and returns
// the resulting raw layers. Paths are resolved against remote resource
// loaders first, then against the parent layer's working directory.
func (m *ComposeModel) loadIncludeEntry(parent *Layer, entry *yaml.Node) ([]*Layer, error) { //nolint:gocyclo
	var paths []string
	var projectDir string
	var envFiles []string

	switch entry.Kind {
	case yaml.ScalarNode:
		// Short syntax: `- path/to/compose.yml`
		paths = []string{entry.Value}
	case yaml.MappingNode:
		// Long syntax: path (scalar or sequence), project_directory, env_file.
		_, pathNode := override.FindKey(entry, "path")
		if pathNode != nil {
			switch pathNode.Kind {
			case yaml.ScalarNode:
				paths = []string{pathNode.Value}
			case yaml.SequenceNode:
				for _, p := range pathNode.Content {
					paths = append(paths, p.Value)
				}
			}
		}
		_, pdNode := override.FindKey(entry, "project_directory")
		if pdNode != nil {
			projectDir = pdNode.Value
		}
		_, efNode := override.FindKey(entry, "env_file")
		if efNode != nil {
			switch efNode.Kind {
			case yaml.ScalarNode:
				envFiles = []string{efNode.Value}
			case yaml.SequenceNode:
				for _, e := range efNode.Content {
					envFiles = append(envFiles, e.Value)
				}
			}
		}
	default:
		return nil, types.NodeErrorf(entry, "include entry must be a string or mapping")
	}

	if len(paths) == 0 {
		return nil, types.NodeErrorf(entry, "include entry has no path")
	}

	// Resolve paths: check remote resource loaders first, then resolve relative to parent
	for i, p := range paths {
		resolved := false
		for _, loader := range m.opts.RemoteResourceLoaders() {
			if !loader.Accept(p) {
				continue
			}
			absPath, loadErr := loader.Load(context.TODO(), p)
			if loadErr != nil {
				return nil, types.WrapNodeError(entry, fmt.Errorf("loading include %s: %w", p, loadErr))
			}
			paths[i] = absPath
			resolved = true
			break
		}
		if !resolved && !filepath.IsAbs(p) {
			paths[i] = filepath.Join(parent.Context.WorkingDir, p)
		}
	}

	// Cycle detection: a path already in m.loadedFiles means this include
	// (transitively) includes itself.
	// NOTE(review): the offending path is appended to m.loadedFiles before
	// erroring so the message below can render the full chain — confirm the
	// mutation on the error path is intentional.
	for _, p := range paths {
		for _, loaded := range m.loadedFiles {
			if loaded == p {
				m.loadedFiles = append(m.loadedFiles, p)
				return nil, fmt.Errorf("include cycle detected:\n%s\n include %s",
					m.loadedFiles[0], strings.Join(m.loadedFiles[1:], "\n include "))
			}
		}
	}

	// Determine working directory for the included files (absolute)
	workDir := projectDir
	if workDir == "" {
		// Default: the directory of the first included file.
		workDir = filepath.Dir(paths[0])
	} else if !filepath.IsAbs(workDir) {
		workDir = filepath.Join(parent.Context.WorkingDir, workDir)
	}

	// Compute relative working dir from main project dir to include dir.
	// This is used to resolve paths in the included file so they are
	// expressed relative to the main project directory.
	mainWorkDir := m.configDetails.WorkingDir
	relWorkDir, err := filepath.Rel(mainWorkDir, workDir)
	if err != nil {
		// Rel can fail (e.g. different drive letters); fall back to absolute.
		relWorkDir = workDir
	}

	// Resolve environment: parent env + env_file
	env := parent.Context.Env.Clone()
	if len(envFiles) > 0 {
		for i, f := range envFiles {
			if !filepath.IsAbs(f) {
				envFiles[i] = filepath.Join(parent.Context.WorkingDir, f)
			}
		}
		envFromFile, err := dotenv.GetEnvFromFile(env, envFiles)
		if err != nil {
			return nil, types.WrapNodeError(entry, err)
		}
		env = env.Merge(envFromFile)
	}

	// Load each path as a raw layer with the include-specific context
	var layers []*Layer
	for _, p := range paths {
		m.loadedFiles = append(m.loadedFiles, p)
		node, err := loadYamlFileNode(types.ConfigFile{Filename: p})
		if err != nil {
			return nil, types.WrapNodeError(entry, fmt.Errorf("loading include %s: %w", p, err))
		}
		if node == nil {
			// Empty file: nothing to merge, skip silently.
			continue
		}

		nodeCtx := &NodeContext{
			Source:     p,
			WorkingDir: workDir,
			Env:        env,
		}
		m.registerNodes(node, nodeCtx)

		incLayer := &Layer{Node: node, Context: nodeCtx}
		if !m.opts.SkipExtends {
			if err := m.applyExtendsNode(incLayer); err != nil {
				return nil, fmt.Errorf("%s: %w", p, err)
			}
		}

		// Resolve relative paths in the included layer using the relative
		// working directory, so paths are expressed relative to the main
		// project directory (matching the old pipeline's behavior).
		resolveLayerNodePaths(node, relWorkDir)

		// Resolve bare environment variables (e.g., "VAR_NAME" without "=")
		// using the include's environment, so they survive merge correctly.
		resolveLayerEnvironment(node, env)

		layers = append(layers, incLayer)
	}
	return layers, nil
}

// resolveDocumentNode unwraps a DocumentNode to get the actual mapping node.
func resolveDocumentNode(node *yaml.Node) *yaml.Node {
	if node != nil && node.Kind == yaml.DocumentNode && len(node.Content) == 1 {
		return node.Content[0]
	}
	return node
}

// resolveLayerEnvironment resolves bare environment variable references
// (entries like "VAR_NAME" without "=") in all services within a layer,
// using the given environment mapping. This must be done before merge so that
// include-specific environment variables are correctly resolved.
// Only sequence-form environment entries are handled here; mapping-form
// environments carry explicit values and need no resolution.
func resolveLayerEnvironment(node *yaml.Node, env types.Mapping) {
	root := resolveDocumentNode(node)
	if root == nil || root.Kind != yaml.MappingNode {
		return
	}
	_, services := override.FindKey(root, "services")
	if services == nil || services.Kind != yaml.MappingNode {
		return
	}
	// Mapping node Content alternates key, value — step by 2 over services.
	for i := 0; i+1 < len(services.Content); i += 2 {
		svc := services.Content[i+1]
		if svc == nil || svc.Kind != yaml.MappingNode {
			continue
		}
		_, envNode := override.FindKey(svc, "environment")
		if envNode == nil || envNode.Kind != yaml.SequenceNode {
			continue
		}
		for _, item := range envNode.Content {
			if item.Kind != yaml.ScalarNode {
				continue
			}
			// Only process bare variable names (no "=" sign)
			if !strings.Contains(item.Value, "=") {
				if val, ok := env[item.Value]; ok {
					item.Value = fmt.Sprintf("%s=%s", item.Value, val)
				}
			}
		}
	}
}

// addBuildContextDefault ensures a build mapping has a "context" key.
// If the build is a mapping without a context, "." is added as default.
// This must be called before path resolution for includes, so the default
// context gets resolved relative to the include's working directory.
func addBuildContextDefault(svc *yaml.Node) {
	if svc == nil || svc.Kind != yaml.MappingNode {
		return
	}
	_, build := override.FindKey(svc, "build")
	if build == nil || build.Kind != yaml.MappingNode {
		// Short syntax (scalar) already carries its own path; nothing to add.
		return
	}
	_, ctx := override.FindKey(build, "context")
	if ctx == nil {
		override.SetKey(build, "context", override.NewScalar("."))
	}
}

// resolveServiceNodePaths adjusts relative paths in a service yaml.Node to be
// expressed relative to workDir. This is used for extends from external files
// to make paths relative to the main project directory.
// Paths containing "://" are treated as URLs and left untouched.
func resolveServiceNodePaths(svc *yaml.Node, workDir string) { //nolint:gocyclo
	// workDir == "." means the service already lives in the project dir.
	if svc == nil || svc.Kind != yaml.MappingNode || workDir == "." {
		return
	}

	// absNodePath prefixes workDir onto relative paths; absolute paths
	// (checked for both OS and slash-separated forms) pass through.
	absNodePath := func(p string) string {
		if filepath.IsAbs(p) || path.IsAbs(p) || p == "" {
			return p
		}
		return filepath.Join(workDir, p)
	}

	// build.context
	_, build := override.FindKey(svc, "build")
	if build != nil {
		switch build.Kind {
		case yaml.ScalarNode:
			// short syntax: build: ./path
			if !strings.Contains(build.Value, "://") {
				build.Value = absNodePath(build.Value)
			}
		case yaml.MappingNode:
			_, ctx := override.FindKey(build, "context")
			if ctx != nil && ctx.Kind == yaml.ScalarNode && !strings.Contains(ctx.Value, "://") {
				ctx.Value = absNodePath(ctx.Value)
			}
			_, addCtx := override.FindKey(build, "additional_contexts")
			if addCtx != nil && addCtx.Kind == yaml.MappingNode {
				for i := 0; i+1 < len(addCtx.Content); i += 2 {
					v := addCtx.Content[i+1]
					if v.Kind == yaml.ScalarNode && !strings.Contains(v.Value, "://") {
						v.Value = absNodePath(v.Value)
					}
				}
			}
		}
	}

	// env_file
	_, envFile := override.FindKey(svc, "env_file")
	if envFile != nil {
		resolveEnvFileNodePaths(envFile, absNodePath)
	}

	// label_file
	_, labelFile := override.FindKey(svc, "label_file")
	if labelFile != nil && labelFile.Kind == yaml.SequenceNode {
		for _, item := range labelFile.Content {
			if item.Kind == yaml.ScalarNode {
				item.Value = absNodePath(item.Value)
			}
		}
	}

	// volumes (only for bind mount sources that are relative paths)
	_, volumes := override.FindKey(svc, "volumes")
	if volumes != nil && volumes.Kind == yaml.SequenceNode {
		for i, item := range volumes.Content {
			switch item.Kind {
			case yaml.MappingNode:
				// Long syntax: resolve source only for bind mounts.
				_, vtype := override.FindKey(item, "type")
				if vtype != nil && vtype.Value == "bind" {
					_, src := override.FindKey(item, "source")
					if src != nil && src.Kind == yaml.ScalarNode {
						src.Value = absNodePath(src.Value)
					}
				}
			case yaml.ScalarNode:
				// Short syntax: parse, resolve source, convert to long syntax mapping
				vol, err := format.ParseVolume(item.Value)
				if err == nil && vol.Type == types.VolumeTypeBind && vol.Source != "" && !filepath.IsAbs(vol.Source) && !path.IsAbs(vol.Source) {
					vol.Source = absNodePath(vol.Source)
					// Convert to long syntax mapping node to preserve bind type
					trueNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"}
					bindNode := override.NewMapping(override.KeyValue{
						Key: "create_host_path", Value: trueNode,
					})
					pairs := []override.KeyValue{
						{Key: "type", Value: override.NewScalar("bind")},
						{Key: "source", Value: override.NewScalar(vol.Source)},
						{Key: "target", Value: override.NewScalar(vol.Target)},
						{Key: "bind", Value: bindNode},
					}
					if vol.ReadOnly {
						pairs = append(pairs, override.KeyValue{Key: "read_only", Value: override.NewScalar("true")})
					}
					volumes.Content[i] = override.NewMapping(pairs...)
				}
			}
		}
	}

	// extends.file
	_, extends := override.FindKey(svc, "extends")
	if extends != nil && extends.Kind == yaml.MappingNode {
		_, file := override.FindKey(extends, "file")
		if file != nil && file.Kind == yaml.ScalarNode {
			file.Value = absNodePath(file.Value)
		}
	}

	// develop.watch[].path
	_, develop := override.FindKey(svc, "develop")
	if develop != nil && develop.Kind == yaml.MappingNode {
		_, watch := override.FindKey(develop, "watch")
		if watch != nil && watch.Kind == yaml.SequenceNode {
			for _, item := range watch.Content {
				if item.Kind == yaml.MappingNode {
					_, p := override.FindKey(item, "path")
					if p != nil && p.Kind == yaml.ScalarNode {
						p.Value = absNodePath(p.Value)
					}
				}
			}
		}
	}
}

// resolveEnvFileNodePaths adjusts env_file paths using the given abs function.
// Handles both the scalar form (env_file: x.env) and the sequence form, whose
// items may be plain scalars or mappings with a "path" key (long syntax).
func resolveEnvFileNodePaths(node *yaml.Node, absPath func(string) string) {
	switch node.Kind {
	case yaml.ScalarNode:
		node.Value = absPath(node.Value)
	case yaml.SequenceNode:
		for _, item := range node.Content {
			switch item.Kind {
			case yaml.ScalarNode:
				item.Value = absPath(item.Value)
			case yaml.MappingNode:
				_, p := override.FindKey(item, "path")
				if p != nil && p.Kind == yaml.ScalarNode {
					p.Value = absPath(p.Value)
				}
			}
		}
	}
}

// transformExternalNodes processes deprecated external.name syntax in volumes,
// networks, secrets, and configs. Converts `external: {name: foo}` to
// `external: true` and sets the name key on the parent resource.
func transformExternalNodes(node *yaml.Node) error {
	for _, section := range []string{"volumes", "networks", "secrets", "configs"} {
		_, sectionNode := override.FindKey(node, section)
		if sectionNode == nil || sectionNode.Kind != yaml.MappingNode {
			continue
		}
		// Mapping node Content alternates key, value — step by 2.
		for i := 0; i+1 < len(sectionNode.Content); i += 2 {
			resourceKey := sectionNode.Content[i].Value
			resource := sectionNode.Content[i+1]
			if resource == nil || resource.Kind != yaml.MappingNode {
				continue
			}
			_, extNode := override.FindKey(resource, "external")
			if extNode == nil || extNode.Kind != yaml.MappingNode {
				// `external: true/false` (scalar) is the modern form — skip.
				continue
			}
			// external is a mapping — deprecated syntax
			_, extNameNode := override.FindKey(extNode, "name")
			if extNameNode != nil && extNameNode.Kind == yaml.ScalarNode {
				extName := extNameNode.Value
				_, nameNode := override.FindKey(resource, "name")
				p := tree.NewPath(section, resourceKey)
				logrus.Warnf("%s: external.name is deprecated. Please set name and external: true", p)
				// Conflicting explicit name and external.name is an error;
				// matching values are tolerated, missing name is filled in.
				if nameNode != nil && nameNode.Kind == yaml.ScalarNode && nameNode.Value != extName {
					return fmt.Errorf("%s: name and external.name conflict; only use name", p)
				}
				if nameNode == nil {
					override.SetKey(resource, "name", override.NewScalar(extName))
				}
			}
			// Replace external mapping with scalar true (rewrite the node
			// in place so references to it stay valid).
			extNode.Kind = yaml.ScalarNode
			extNode.Tag = "!!bool"
			extNode.Value = "true"
			extNode.Content = nil
		}
	}
	return nil
}

// checkNonStringKeysNode walks a yaml.Node tree and returns an error if any
// mapping key is not a string (e.g., an integer key like `123: value`).
// keyPrefix is the dotted path of the enclosing mapping, used only for the
// error message ("" means top level).
func checkNonStringKeysNode(node *yaml.Node, keyPrefix string) error {
	if node == nil {
		return nil
	}
	switch node.Kind {
	case yaml.MappingNode:
		for i := 0; i+1 < len(node.Content); i += 2 {
			keyNode := node.Content[i]
			valNode := node.Content[i+1]
			// Untagged ("") keys and merge keys (<<) are allowed alongside
			// explicit !!str keys; anything else (e.g. !!int) is rejected.
			if keyNode.Tag != "" && keyNode.Tag != "!!str" && keyNode.Tag != "!!merge" {
				var location string
				if keyPrefix == "" {
					location = "at top level"
				} else {
					location = fmt.Sprintf("in %s", keyPrefix)
				}
				return fmt.Errorf("non-string key %s: %s", location, keyNode.Value)
			}
			var childPrefix string
			if keyPrefix == "" {
				childPrefix = keyNode.Value
			} else {
				childPrefix = fmt.Sprintf("%s.%s", keyPrefix, keyNode.Value)
			}
			if err := checkNonStringKeysNode(valNode, childPrefix); err != nil {
				return err
			}
		}
	case yaml.SequenceNode:
		for idx, item := range node.Content {
			childPrefix := fmt.Sprintf("%s[%d]", keyPrefix, idx)
			if err := checkNonStringKeysNode(item, childPrefix); err != nil {
				return err
			}
		}
	}
	return nil
}

// resolveLayerNodePaths resolves all relative paths in a layer's node tree
// using the given working directory. This makes paths absolute so they survive
// merging into the main project with a different working directory.
func resolveLayerNodePaths(node *yaml.Node, workDir string) {
	root := resolveDocumentNode(node)
	if root == nil || root.Kind != yaml.MappingNode {
		return
	}

	absPath := func(p string) string {
		if filepath.IsAbs(p) || path.IsAbs(p) || p == "" {
			return p
		}
		return filepath.Join(workDir, p)
	}

	_, services := override.FindKey(root, "services")
	if services != nil && services.Kind == yaml.MappingNode {
		for i := 0; i+1 < len(services.Content); i += 2 {
			svc := services.Content[i+1]
			// Add build.context default before path resolution so it gets
			// resolved relative to the include's working directory.
			addBuildContextDefault(svc)
			resolveServiceNodePaths(svc, workDir)
		}
	}

	// configs.*.file
	_, configs := override.FindKey(root, "configs")
	if configs != nil && configs.Kind == yaml.MappingNode {
		for i := 0; i+1 < len(configs.Content); i += 2 {
			cfg := configs.Content[i+1]
			if cfg.Kind == yaml.MappingNode {
				_, file := override.FindKey(cfg, "file")
				if file != nil && file.Kind == yaml.ScalarNode {
					file.Value = absPath(file.Value)
				}
			}
		}
	}

	// secrets.*.file
	_, secrets := override.FindKey(root, "secrets")
	if secrets != nil && secrets.Kind == yaml.MappingNode {
		for i := 0; i+1 < len(secrets.Content); i += 2 {
			sec := secrets.Content[i+1]
			if sec.Kind == yaml.MappingNode {
				_, file := override.FindKey(sec, "file")
				if file != nil && file.Kind == yaml.ScalarNode {
					file.Value = absPath(file.Value)
				}
			}
		}
	}
}

// processProjectExtensions converts raw extension values to registered Go types.
// known maps extension names (x-*) to zero values of their target Go types;
// unknown extensions are left as raw decoded values.
func processProjectExtensions(project *types.Project, known map[string]any) error {
	convertExtensions := func(ext types.Extensions) error {
		for name, val := range ext {
			typ, ok := known[name]
			if !ok {
				continue
			}
			// Marshal the raw value to yaml, then unmarshal into the target type
			b, err := yaml.Marshal(val)
			if err != nil {
				return fmt.Errorf("converting extension %s: %w", name, err)
			}
			target := reflect.New(reflect.TypeOf(typ)).Interface()
			if err := yaml.Unmarshal(b, target); err != nil {
				return fmt.Errorf("converting extension %s: %w", name, err)
			}
			ext[name] = reflect.ValueOf(target).Elem().Interface()
		}
		return nil
	}

	if err := convertExtensions(project.Extensions); err != nil {
		return err
	}
	// Services/networks/volumes are map values (not pointers), so each
	// converted copy is written back into its map.
	for name, svc := range project.Services {
		if err := convertExtensions(svc.Extensions); err != nil {
			return err
		}
		project.Services[name] = svc
	}
	for name, net := range project.Networks {
		if err := convertExtensions(net.Extensions); err != nil {
			return err
		}
		project.Networks[name]
= net + } + for name, vol := range project.Volumes { + if err := convertExtensions(vol.Extensions); err != nil { + return err + } + project.Volumes[name] = vol + } + return nil +} diff --git a/loader/compose_model_test.go b/loader/compose_model_test.go new file mode 100644 index 00000000..f706ccd0 --- /dev/null +++ b/loader/compose_model_test.go @@ -0,0 +1,586 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package loader + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "testing" + + "github.com/compose-spec/compose-go/v2/types" + "gotest.tools/v3/assert" +) + +func lazyLoad(t *testing.T, yaml string, env map[string]string, options ...func(*Options)) *types.Project { + t.Helper() + if env == nil { + env = map[string]string{} + } + workingDir, err := os.Getwd() + assert.NilError(t, err) + + configDetails := types.ConfigDetails{ + WorkingDir: workingDir, + ConfigFiles: []types.ConfigFile{ + {Filename: "compose.yml", Content: []byte(yaml)}, + }, + Environment: env, + } + model, err := LoadLazyModel(context.TODO(), configDetails, options...) 
	assert.NilError(t, err)

	project, err := model.Resolve()
	assert.NilError(t, err)
	return project
}

// TestLoadLazyModel_Simple verifies that a two-service project loads with
// images and mapping-form environment intact.
func TestLoadLazyModel_Simple(t *testing.T) {
	yaml := `
name: testproject
services:
  web:
    image: nginx
    ports:
      - "8080:80"
  db:
    image: postgres
    environment:
      POSTGRES_DB: mydb
`
	project := lazyLoad(t, yaml, nil, func(o *Options) {
		o.SkipConsistencyCheck = true
	})

	assert.Equal(t, project.Name, "testproject")
	assert.Equal(t, len(project.Services), 2)

	web := project.Services["web"]
	assert.Equal(t, web.Image, "nginx")

	db := project.Services["db"]
	assert.Equal(t, db.Image, "postgres")

	val, ok := db.Environment["POSTGRES_DB"]
	assert.Assert(t, ok)
	assert.Assert(t, val != nil)
	assert.Equal(t, *val, "mydb")
}

// TestLoadLazyModel_Interpolation verifies ${VAR} substitution from the
// provided environment.
func TestLoadLazyModel_Interpolation(t *testing.T) {
	yaml := `
name: interpoltest
services:
  app:
    image: ${APP_IMAGE}
`
	project := lazyLoad(t, yaml, map[string]string{
		"APP_IMAGE": "myapp:latest",
	}, func(o *Options) {
		o.SkipConsistencyCheck = true
	})

	app := project.Services["app"]
	assert.Equal(t, app.Image, "myapp:latest")
}

// TestLoadLazyModel_BuildShortSyntax verifies that `build: ./app` expands to
// a Build section whose context includes the given directory.
func TestLoadLazyModel_BuildShortSyntax(t *testing.T) {
	yaml := `
name: buildtest
services:
  app:
    build: ./app
`
	project := lazyLoad(t, yaml, nil, func(o *Options) {
		o.SkipConsistencyCheck = true
	})

	app := project.Services["app"]
	assert.Assert(t, app.Build != nil)
	assert.Assert(t, strings.Contains(app.Build.Context, "app"),
		"expected Build.Context to contain 'app', got: %s", app.Build.Context)
}

// TestLoadLazyModel_DependsOnShortSyntax verifies the short depends_on form
// expands to condition service_started, required=true.
func TestLoadLazyModel_DependsOnShortSyntax(t *testing.T) {
	yaml := `
name: depstest
services:
  web:
    image: nginx
    depends_on:
      - db
  db:
    image: postgres
`
	project := lazyLoad(t, yaml, nil, func(o *Options) {
		o.SkipConsistencyCheck = true
	})

	web := project.Services["web"]
	dep, ok := web.DependsOn["db"]
	assert.Assert(t, ok, "expected depends_on to contain 'db'")
	assert.Equal(t, dep.Condition, "service_started")
	assert.Equal(t, dep.Required, true)
}

// TestLoadLazyModel_Parity loads the same project through the existing
// pipeline and the new lazy pipeline and compares the results field by field.
func TestLoadLazyModel_Parity(t *testing.T) {
	yaml := `
name: paritytest
services:
  web:
    image: nginx:latest
    ports:
      - "8080:80"
    environment:
      - FOO=bar
    labels:
      app: web
    depends_on:
      - db
  db:
    image: postgres:15
    environment:
      POSTGRES_DB: mydb
      POSTGRES_USER: user
volumes:
  data: {}
networks:
  frontend: {}
`
	// Identical options for both pipelines so results are comparable.
	sharedOpts := func(o *Options) {
		o.SkipConsistencyCheck = true
		o.SkipNormalization = true
		o.ResolvePaths = true
	}

	env := map[string]string{}
	workingDir, err := os.Getwd()
	assert.NilError(t, err)

	configDetails := types.ConfigDetails{
		WorkingDir: workingDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: "compose.yml", Content: []byte(yaml)},
		},
		Environment: env,
	}

	// Load with existing pipeline
	projectOld, err := LoadWithContext(context.TODO(), configDetails, sharedOpts)
	assert.NilError(t, err)

	// Load with new lazy pipeline
	model, err := LoadLazyModel(context.TODO(), configDetails, sharedOpts)
	assert.NilError(t, err)
	projectNew, err := model.Resolve()
	assert.NilError(t, err)

	// Compare project names
	assert.Equal(t, projectOld.Name, projectNew.Name)

	// Compare service names
	oldNames := projectOld.ServiceNames()
	newNames := projectNew.ServiceNames()
	sort.Strings(oldNames)
	sort.Strings(newNames)
	assert.DeepEqual(t, oldNames, newNames)

	// Compare images
	for _, name := range oldNames {
		oldSvc := projectOld.Services[name]
		newSvc := projectNew.Services[name]
		assert.Equal(t, oldSvc.Image, newSvc.Image, "image mismatch for service %s", name)
	}

	// Compare environment values
	for _, name := range oldNames {
		oldSvc := projectOld.Services[name]
		newSvc := projectNew.Services[name]
		for key, oldVal := range oldSvc.Environment {
			newVal, ok := newSvc.Environment[key]
			assert.Assert(t, ok, "env key %s missing in service %s", key, name)
			if oldVal != nil && newVal != nil {
				assert.Equal(t, *oldVal, *newVal,
					"env %s mismatch for service %s", key, name)
			}
		}
	}

	// Compare labels
	oldWeb := projectOld.Services["web"]
	newWeb := projectNew.Services["web"]
	assert.DeepEqual(t, map[string]string(oldWeb.Labels), map[string]string(newWeb.Labels))

	// Compare depends_on conditions
	for depName, oldDep := range oldWeb.DependsOn {
		newDep, ok := newWeb.DependsOn[depName]
		assert.Assert(t, ok, "depends_on %s missing", depName)
		assert.Equal(t, oldDep.Condition, newDep.Condition)
		assert.Equal(t, oldDep.Required, newDep.Required)
	}

	// Compare volumes and networks exist
	for name := range projectOld.Volumes {
		_, ok := projectNew.Volumes[name]
		assert.Assert(t, ok, "volume %s missing in new project", name)
	}
	for name := range projectOld.Networks {
		_, ok := projectNew.Networks[name]
		assert.Assert(t, ok, "network %s missing in new project", name)
	}
}

func TestLoadLazyModel_ErrorIncludesLineInfo(t *testing.T) {
	// This YAML has an invalid value for "mem_limit" — not a parsable size.
	// The error should include line and column information.
	yamlContent := `name: errtest
services:
  web:
    image: nginx
    mem_limit: not_a_size
`
	workingDir, err := os.Getwd()
	assert.NilError(t, err)

	configDetails := types.ConfigDetails{
		WorkingDir: workingDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: "docker-compose.yml", Content: []byte(yamlContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	_, err = model.Resolve()
	assert.Assert(t, err != nil, "expected an error for invalid mem_limit")

	errMsg := err.Error()
	t.Logf("error message: %s", errMsg)

	// The error message should contain the source filename and line/column info
	assert.Assert(t, strings.Contains(errMsg, "docker-compose.yml"),
		"expected error to contain source filename")
	assert.Assert(t, strings.Contains(errMsg, "line 5"),
		"expected error to contain line number")
	assert.Assert(t, strings.Contains(errMsg, "column 16"),
		"expected error to contain column number")
}

// TestLoadLazyModel_ErrorIncludesLineInfo_InvalidDuration is the same check
// for an invalid healthcheck interval duration.
func TestLoadLazyModel_ErrorIncludesLineInfo_InvalidDuration(t *testing.T) {
	yamlContent := `name: errtest
services:
  web:
    image: nginx
    healthcheck:
      test: ["CMD", "true"]
      interval: not_a_duration
`
	workingDir, err := os.Getwd()
	assert.NilError(t, err)

	configDetails := types.ConfigDetails{
		WorkingDir: workingDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: "myapp/compose.yaml", Content: []byte(yamlContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	_, err = model.Resolve()
	assert.Assert(t, err != nil, "expected an error for invalid duration")

	errMsg := err.Error()
	t.Logf("error message: %s", errMsg)

	// The error message should contain the source filename and line/column info
	assert.Assert(t, strings.Contains(errMsg, "myapp/compose.yaml"),
		"expected error to contain source filename")
	assert.Assert(t, strings.Contains(errMsg, "line 7"),
		"expected error to contain line number")
	assert.Assert(t, strings.Contains(errMsg, "column 17"),
		"expected error to contain column number")
}

func TestLoadLazyModel_ErrorMultipleFiles(t *testing.T) {
	// First file is valid, second file has an error on a specific line.
	// The error should reference the second file's name and the correct line.
	base := `name: multitest
services:
  web:
    image: nginx
`
	overrideContent := `services:
  web:
    mem_limit: not_a_size
`
	workingDir, err := os.Getwd()
	assert.NilError(t, err)

	configDetails := types.ConfigDetails{
		WorkingDir: workingDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: "docker-compose.yml", Content: []byte(base)},
			{Filename: "docker-compose.override.yml", Content: []byte(overrideContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	_, err = model.Resolve()
	assert.Assert(t, err != nil, "expected an error for invalid mem_limit in override")

	errMsg := err.Error()
	t.Logf("error message: %s", errMsg)

	// The error should reference the override file (last file merged)
	assert.Assert(t, strings.Contains(errMsg, "docker-compose.override.yml"),
		"expected error to contain override filename, got: %s", errMsg)
	// Should contain line info
	assert.Assert(t, strings.Contains(errMsg, "line "),
		"expected error to contain line number, got: %s", errMsg)
	assert.Assert(t, strings.Contains(errMsg, "column "),
		"expected error to contain column number, got: %s", errMsg)
}

func TestLoadLazyModel_ErrorExtends(t *testing.T) {
	// Create a temporary directory with a base file containing an error.
	// The main file extends from it. The error should reference the base file.
	tmpDir := t.TempDir()

	baseContent := `services:
  base:
    image: nginx
    mem_limit: not_a_size
`
	err := os.WriteFile(filepath.Join(tmpDir, "base.yml"), []byte(baseContent), 0o644)
	assert.NilError(t, err)

	mainContent := `name: extendstest
services:
  web:
    extends:
      file: base.yml
      service: base
`
	configDetails := types.ConfigDetails{
		WorkingDir: tmpDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: filepath.Join(tmpDir, "compose.yml"), Content: []byte(mainContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	_, err = model.Resolve()
	assert.Assert(t, err != nil, "expected an error for invalid mem_limit from extended service")

	errMsg := err.Error()
	t.Logf("error message: %s", errMsg)

	// The error should contain line info (from the yaml node that has the bad value)
	assert.Assert(t, strings.Contains(errMsg, "line "),
		"expected error to contain line number, got: %s", errMsg)
	assert.Assert(t, strings.Contains(errMsg, "column "),
		"expected error to contain column number, got: %s", errMsg)
}

func TestLoadLazyModel_IncludeEnvFile(t *testing.T) {
	// An included file uses a variable ${WORKER_IMAGE} that is defined only
	// in the env_file specified by the include directive. The main file does
	// NOT have this variable in its environment. The included layer must be
	// interpolated with the env_file values.
	tmpDir := t.TempDir()

	// Create the env file with a variable unknown to the main file
	err := os.WriteFile(filepath.Join(tmpDir, "worker.env"), []byte("WORKER_IMAGE=redis:7\n"), 0o644)
	assert.NilError(t, err)

	// Included compose file uses ${WORKER_IMAGE}
	includedContent := `services:
  worker:
    image: ${WORKER_IMAGE}
`
	err = os.WriteFile(filepath.Join(tmpDir, "worker.yml"), []byte(includedContent), 0o644)
	assert.NilError(t, err)

	// Main compose file includes worker.yml with env_file
	mainContent := `name: incenvtest
include:
  - path: worker.yml
    env_file: worker.env
services:
  web:
    image: nginx
`
	configDetails := types.ConfigDetails{
		WorkingDir: tmpDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: filepath.Join(tmpDir, "compose.yml"), Content: []byte(mainContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	project, err := model.Resolve()
	assert.NilError(t, err)

	// The main file's web service should be present
	web := project.Services["web"]
	assert.Equal(t, web.Image, "nginx")

	// The included worker service should have been interpolated with WORKER_IMAGE from worker.env
	worker, ok := project.Services["worker"]
	assert.Assert(t, ok, "expected 'worker' service from included file")
	assert.Equal(t, worker.Image, "redis:7",
		"expected included service to be interpolated with env_file variable, got: %s", worker.Image)
}

func TestLoadLazyModel_ErrorInclude(t *testing.T) {
	// Create a temporary directory with an included file containing an error.
	// The error should reference the included file's name.
	tmpDir := t.TempDir()

	includedContent := `name: included
services:
  worker:
    image: redis
    mem_limit: not_a_size
`
	err := os.WriteFile(filepath.Join(tmpDir, "worker.yml"), []byte(includedContent), 0o644)
	assert.NilError(t, err)

	mainContent := `name: includetest
include:
  - worker.yml
services:
  web:
    image: nginx
`
	configDetails := types.ConfigDetails{
		WorkingDir: tmpDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: filepath.Join(tmpDir, "compose.yml"), Content: []byte(mainContent)},
		},
		Environment: map[string]string{},
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	_, err = model.Resolve()
	assert.Assert(t, err != nil, "expected an error for invalid mem_limit in included file")

	errMsg := err.Error()
	t.Logf("error message: %s", errMsg)

	// The error should reference the included file
	assert.Assert(t, strings.Contains(errMsg, "worker.yml"),
		"expected error to contain included filename, got: %s", errMsg)
	// Should contain line info
	assert.Assert(t, strings.Contains(errMsg, "line "),
		"expected error to contain line number, got: %s", errMsg)
	assert.Assert(t, strings.Contains(errMsg, "column "),
		"expected error to contain column number, got: %s", errMsg)
}

func TestLoadLazyModel_MergeEnvironmentMappingAndSequence(t *testing.T) {
	// First file declares environment as a mapping with variable references.
	// Second file declares environment as a sequence with variable references.
	// After merge (convertNodeToSequence creates new scalars from the mapping),
	// all values must be correctly interpolated.
	base := `name: mergeenvtest
services:
  app:
    image: nginx
    environment:
      FROM_MAP: ${MAP_VAR}
      PLAIN_MAP: hello
`
	overrideContent := `services:
  app:
    environment:
      - FROM_SEQ=${SEQ_VAR}
      - PLAIN_SEQ=world
`
	env := map[string]string{
		"MAP_VAR": "map_value",
		"SEQ_VAR": "seq_value",
	}
	workingDir, err := os.Getwd()
	assert.NilError(t, err)

	configDetails := types.ConfigDetails{
		WorkingDir: workingDir,
		ConfigFiles: []types.ConfigFile{
			{Filename: "compose.yml", Content: []byte(base)},
			{Filename: "compose.override.yml", Content: []byte(overrideContent)},
		},
		Environment: env,
	}
	model, err := LoadLazyModel(context.TODO(), configDetails, func(o *Options) {
		o.SkipConsistencyCheck = true
	})
	assert.NilError(t, err)

	project, err := model.Resolve()
	assert.NilError(t, err)

	app := project.Services["app"]

	// FROM_MAP came from a mapping, was converted to "FROM_MAP=${MAP_VAR}" by merge,
	// then interpolated → "map_value"
	val, ok := app.Environment["FROM_MAP"]
	assert.Assert(t, ok, "FROM_MAP missing")
	assert.Assert(t, val != nil, "FROM_MAP is nil")
	assert.Equal(t, *val, "map_value", "FROM_MAP not interpolated")

	// FROM_SEQ came from a sequence, kept as "FROM_SEQ=${SEQ_VAR}",
	// then interpolated → "seq_value"
	val, ok = app.Environment["FROM_SEQ"]
	assert.Assert(t, ok, "FROM_SEQ missing")
	assert.Assert(t, val != nil, "FROM_SEQ is nil")
	assert.Equal(t, *val, "seq_value", "FROM_SEQ not interpolated")

	// Plain values should pass through unchanged
	val, ok = app.Environment["PLAIN_MAP"]
	assert.Assert(t, ok, "PLAIN_MAP missing")
	assert.Assert(t, val != nil)
	assert.Equal(t, *val, "hello")

	val, ok = app.Environment["PLAIN_SEQ"]
	assert.Assert(t, ok, "PLAIN_SEQ missing")
	assert.Assert(t, val != nil)
	assert.Equal(t, *val, "world")
}
-/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package loader - -import ( - "context" - "fmt" - "path/filepath" - - "github.com/compose-spec/compose-go/v2/consts" - "github.com/compose-spec/compose-go/v2/override" - "github.com/compose-spec/compose-go/v2/paths" - "github.com/compose-spec/compose-go/v2/types" -) - -func ApplyExtends(ctx context.Context, dict map[string]any, opts *Options, tracker *cycleTracker, post PostProcessor) error { - a, ok := dict["services"] - if !ok { - return nil - } - services, ok := a.(map[string]any) - if !ok { - return fmt.Errorf("services must be a mapping") - } - for name := range services { - merged, err := applyServiceExtends(ctx, name, services, opts, tracker, post) - if err != nil { - return err - } - services[name] = merged - } - dict["services"] = services - return nil -} - -func applyServiceExtends(ctx context.Context, name string, services map[string]any, opts *Options, tracker *cycleTracker, post PostProcessor) (any, error) { - s := services[name] - if s == nil { - return nil, nil - } - service, ok := s.(map[string]any) - if !ok { - return nil, fmt.Errorf("services.%s must be a mapping", name) - } - extends, ok := service["extends"] - if !ok { - return s, nil - } - filename := ctx.Value(consts.ComposeFileKey{}).(string) - var ( - err error - ref string - file any - ) - switch v := extends.(type) { - case map[string]any: - ref, ok = v["service"].(string) - if !ok { - return nil, 
fmt.Errorf("extends.%s.service is required", name) - } - file = v["file"] - opts.ProcessEvent("extends", v) - case string: - ref = v - opts.ProcessEvent("extends", map[string]any{"service": ref}) - } - - var ( - base any - processor = post - ) - - if file != nil { - refFilename := file.(string) - services, processor, err = getExtendsBaseFromFile(ctx, name, ref, filename, refFilename, opts, tracker) - if err != nil { - return nil, err - } - filename = refFilename - } else { - _, ok := services[ref] - if !ok { - return nil, fmt.Errorf("cannot extend service %q in %s: service %q not found", name, filename, ref) - } - } - - tracker, err = tracker.Add(filename, name) - if err != nil { - return nil, err - } - - // recursively apply `extends` - base, err = applyServiceExtends(ctx, ref, services, opts, tracker, processor) - if err != nil { - return nil, err - } - - if base == nil { - return service, nil - } - source := deepClone(base).(map[string]any) - - err = post.Apply(map[string]any{ - "services": map[string]any{ - name: source, - }, - }) - if err != nil { - return nil, err - } - - merged, err := override.ExtendService(source, service) - if err != nil { - return nil, err - } - - delete(merged, "extends") - services[name] = merged - return merged, nil -} - -func getExtendsBaseFromFile( - ctx context.Context, - name, ref string, - path, refPath string, - opts *Options, - ct *cycleTracker, -) (map[string]any, PostProcessor, error) { - for _, loader := range opts.ResourceLoaders { - if !loader.Accept(refPath) { - continue - } - local, err := loader.Load(ctx, refPath) - if err != nil { - return nil, nil, err - } - localdir := filepath.Dir(local) - relworkingdir := loader.Dir(refPath) - - extendsOpts := opts.clone() - // replace localResourceLoader with a new flavour, using extended file base path - extendsOpts.ResourceLoaders = append(opts.RemoteResourceLoaders(), localResourceLoader{ - WorkingDir: localdir, - }) - extendsOpts.ResolvePaths = false // we do relative path 
resolution after file has been loaded - extendsOpts.SkipNormalization = true - extendsOpts.SkipConsistencyCheck = true - extendsOpts.SkipInclude = true - extendsOpts.SkipExtends = true // we manage extends recursively based on raw service definition - extendsOpts.SkipValidation = true // we validate the merge result - extendsOpts.SkipDefaultValues = true - source, processor, err := loadYamlFile(ctx, types.ConfigFile{Filename: local}, - extendsOpts, relworkingdir, nil, ct, map[string]any{}, nil) - if err != nil { - return nil, nil, err - } - m, ok := source["services"] - if !ok { - return nil, nil, fmt.Errorf("cannot extend service %q in %s: no services section", name, local) - } - services, ok := m.(map[string]any) - if !ok { - return nil, nil, fmt.Errorf("cannot extend service %q in %s: services must be a mapping", name, local) - } - _, ok = services[ref] - if !ok { - return nil, nil, fmt.Errorf( - "cannot extend service %q in %s: service %q not found in %s", - name, - path, - ref, - refPath, - ) - } - - var remotes []paths.RemoteResource - for _, loader := range opts.RemoteResourceLoaders() { - remotes = append(remotes, loader.Accept) - } - err = paths.ResolveRelativePaths(source, relworkingdir, remotes) - if err != nil { - return nil, nil, err - } - - return services, processor, nil - } - return nil, nil, fmt.Errorf("cannot read %s", refPath) -} - -func deepClone(value any) any { - switch v := value.(type) { - case []any: - cp := make([]any, len(v)) - for i, e := range v { - cp[i] = deepClone(e) - } - return cp - case map[string]any: - cp := make(map[string]any, len(v)) - for k, e := range v { - cp[k] = deepClone(e) - } - return cp - default: - return value - } -} diff --git a/loader/fix.go b/loader/fix.go deleted file mode 100644 index 7a6e88d8..00000000 --- a/loader/fix.go +++ /dev/null @@ -1,36 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. 
- - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package loader - -// fixEmptyNotNull is a workaround for https://github.com/xeipuuv/gojsonschema/issues/141 -// as go-yaml `[]` will load as a `[]any(nil)`, which is not the same as an empty array -func fixEmptyNotNull(value any) interface{} { - switch v := value.(type) { - case []any: - if v == nil { - return []any{} - } - for i, e := range v { - v[i] = fixEmptyNotNull(e) - } - case map[string]any: - for k, e := range v { - v[k] = fixEmptyNotNull(e) - } - } - return value -} diff --git a/loader/include.go b/loader/include.go deleted file mode 100644 index ff310447..00000000 --- a/loader/include.go +++ /dev/null @@ -1,223 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package loader - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/compose-spec/compose-go/v2/dotenv" - interp "github.com/compose-spec/compose-go/v2/interpolation" - "github.com/compose-spec/compose-go/v2/override" - "github.com/compose-spec/compose-go/v2/tree" - "github.com/compose-spec/compose-go/v2/types" -) - -// loadIncludeConfig parse the required config from raw yaml -func loadIncludeConfig(source any) ([]types.IncludeConfig, error) { - if source == nil { - return nil, nil - } - configs, ok := source.([]any) - if !ok { - return nil, fmt.Errorf("`include` must be a list, got %s", source) - } - for i, config := range configs { - if v, ok := config.(string); ok { - configs[i] = map[string]any{ - "path": v, - } - } - } - var requires []types.IncludeConfig - err := Transform(source, &requires) - return requires, err -} - -func ApplyInclude(ctx context.Context, workingDir string, environment types.Mapping, model map[string]any, options *Options, included []string, processor PostProcessor) error { - includeConfig, err := loadIncludeConfig(model["include"]) - if err != nil { - return err - } - - for _, r := range includeConfig { - for _, listener := range options.Listeners { - listener("include", map[string]any{ - "path": r.Path, - "workingdir": workingDir, - }) - } - - var relworkingdir string - for i, p := range r.Path { - for _, loader := range options.ResourceLoaders { - if !loader.Accept(p) { - continue - } - path, err := loader.Load(ctx, p) - if err != nil { - return err - } - p = path - - if i == 0 { // This is the "main" file, used to define project-directory. 
Others are overrides - - switch { - case r.ProjectDirectory == "": - relworkingdir = loader.Dir(path) - r.ProjectDirectory = filepath.Dir(path) - case !filepath.IsAbs(r.ProjectDirectory): - relworkingdir = loader.Dir(r.ProjectDirectory) - r.ProjectDirectory = filepath.Join(workingDir, r.ProjectDirectory) - - default: - relworkingdir = r.ProjectDirectory - - } - for _, f := range included { - if f == path { - included = append(included, path) - return fmt.Errorf("include cycle detected:\n%s\n include %s", included[0], strings.Join(included[1:], "\n include ")) - } - } - } - } - r.Path[i] = p - } - - loadOptions := options.clone() - loadOptions.ResolvePaths = true - loadOptions.SkipNormalization = true - loadOptions.SkipConsistencyCheck = true - loadOptions.ResourceLoaders = append(loadOptions.RemoteResourceLoaders(), localResourceLoader{ - WorkingDir: r.ProjectDirectory, - }) - - if len(r.EnvFile) == 0 { - f := filepath.Join(r.ProjectDirectory, ".env") - if s, err := os.Stat(f); err == nil && !s.IsDir() { - r.EnvFile = types.StringList{f} - } - } else { - envFile := []string{} - for _, f := range r.EnvFile { - if f == "/dev/null" { - continue - } - if !filepath.IsAbs(f) { - f = filepath.Join(workingDir, f) - s, err := os.Stat(f) - if err != nil { - return err - } - if s.IsDir() { - return fmt.Errorf("%s is not a file", f) - } - } - envFile = append(envFile, f) - } - r.EnvFile = envFile - } - - envFromFile, err := dotenv.GetEnvFromFile(environment, r.EnvFile) - if err != nil { - return err - } - - config := types.ConfigDetails{ - WorkingDir: relworkingdir, - ConfigFiles: types.ToConfigFiles(r.Path), - Environment: environment.Clone().Merge(envFromFile), - } - loadOptions.Interpolate = &interp.Options{ - Substitute: options.Interpolate.Substitute, - LookupValue: config.LookupEnv, - TypeCastMapping: options.Interpolate.TypeCastMapping, - } - imported, err := loadYamlModel(ctx, config, loadOptions, &cycleTracker{}, included) - if err != nil { - return err - } - err = 
importResources(imported, model, processor) - if err != nil { - return err - } - } - delete(model, "include") - return nil -} - -// importResources import into model all resources defined by imported, and report error on conflict -func importResources(source map[string]any, target map[string]any, processor PostProcessor) error { - if err := importResource(source, target, "services", processor); err != nil { - return err - } - if err := importResource(source, target, "volumes", processor); err != nil { - return err - } - if err := importResource(source, target, "networks", processor); err != nil { - return err - } - if err := importResource(source, target, "secrets", processor); err != nil { - return err - } - if err := importResource(source, target, "configs", processor); err != nil { - return err - } - if err := importResource(source, target, "models", processor); err != nil { - return err - } - return nil -} - -func importResource(source map[string]any, target map[string]any, key string, processor PostProcessor) error { - from := source[key] - if from != nil { - var to map[string]any - if v, ok := target[key]; ok { - to = v.(map[string]any) - } else { - to = map[string]any{} - } - for name, a := range from.(map[string]any) { - conflict, ok := to[name] - if !ok { - to[name] = a - continue - } - err := processor.Apply(map[string]any{ - key: map[string]any{ - name: a, - }, - }) - if err != nil { - return err - } - - merged, err := override.MergeYaml(a, conflict, tree.NewPath(key, name)) - if err != nil { - return err - } - to[name] = merged - } - target[key] = to - } - return nil -} diff --git a/loader/loader.go b/loader/loader.go index f73ad92e..479d3b61 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -24,24 +24,15 @@ import ( "io" "os" "path/filepath" - "reflect" "regexp" "slices" - "strconv" "strings" "github.com/compose-spec/compose-go/v2/consts" "github.com/compose-spec/compose-go/v2/errdefs" interp "github.com/compose-spec/compose-go/v2/interpolation" 
- "github.com/compose-spec/compose-go/v2/override" - "github.com/compose-spec/compose-go/v2/paths" - "github.com/compose-spec/compose-go/v2/schema" "github.com/compose-spec/compose-go/v2/template" - "github.com/compose-spec/compose-go/v2/transform" - "github.com/compose-spec/compose-go/v2/tree" "github.com/compose-spec/compose-go/v2/types" - "github.com/compose-spec/compose-go/v2/validation" - "github.com/go-viper/mapstructure/v2" "github.com/sirupsen/logrus" "go.yaml.in/yaml/v4" ) @@ -169,27 +160,6 @@ func (l localResourceLoader) isDir(path string) bool { return fileInfo.IsDir() } -func (o *Options) clone() *Options { - return &Options{ - SkipValidation: o.SkipValidation, - SkipInterpolation: o.SkipInterpolation, - SkipNormalization: o.SkipNormalization, - ResolvePaths: o.ResolvePaths, - ConvertWindowsPaths: o.ConvertWindowsPaths, - SkipConsistencyCheck: o.SkipConsistencyCheck, - SkipExtends: o.SkipExtends, - SkipInclude: o.SkipInclude, - Interpolate: o.Interpolate, - discardEnvFiles: o.discardEnvFiles, - projectName: o.projectName, - projectNameImperativelySet: o.projectNameImperativelySet, - Profiles: o.Profiles, - ResourceLoaders: o.ResourceLoaders, - KnownExtensions: o.KnownExtensions, - Listeners: o.Listeners, - } -} - func (o *Options) SetProjectName(name string, imperativelySet bool) { o.projectName = name o.projectNameImperativelySet = imperativelySet @@ -199,46 +169,6 @@ func (o Options) GetProjectName() (string, bool) { return o.projectName, o.projectNameImperativelySet } -// serviceRef identifies a reference to a service. It's used to detect cyclic -// references in "extends". 
-type serviceRef struct { - filename string - service string -} - -type cycleTracker struct { - loaded []serviceRef -} - -func (ct *cycleTracker) Add(filename, service string) (*cycleTracker, error) { - toAdd := serviceRef{filename: filename, service: service} - for _, loaded := range ct.loaded { - if toAdd == loaded { - // Create an error message of the form: - // Circular reference: - // service-a in docker-compose.yml - // extends service-b in docker-compose.yml - // extends service-a in docker-compose.yml - errLines := []string{ - "Circular reference:", - fmt.Sprintf(" %s in %s", ct.loaded[0].service, ct.loaded[0].filename), - } - for _, service := range append(ct.loaded[1:], toAdd) { - errLines = append(errLines, fmt.Sprintf(" extends %s in %s", service.service, service.filename)) - } - - return nil, errors.New(strings.Join(errLines, "\n")) - } - } - - var branch []serviceRef - branch = append(branch, ct.loaded...) - branch = append(branch, toAdd) - return &cycleTracker{ - loaded: branch, - }, nil -} - // WithDiscardEnvFiles sets the Options to discard the `env_file` section after resolving to // the `environment` section func WithDiscardEnvFiles(opts *Options) { @@ -257,17 +187,6 @@ func WithProfiles(profiles []string) func(*Options) { } } -// PostProcessor is used to tweak compose model based on metadata extracted during yaml Unmarshal phase -// that hardly can be implemented using go-yaml and mapstructure -type PostProcessor interface { - // Apply changes to compose model based on recorder metadata - Apply(interface{}) error -} - -type NoopPostProcessor struct{} - -func (NoopPostProcessor) Apply(interface{}) error { return nil } - // LoadConfigFiles ingests config files with ResourceLoader and returns config details with paths to local copies func LoadConfigFiles(ctx context.Context, configFiles []string, workingDir string, options ...func(*Options)) (*types.ConfigDetails, error) { if len(configFiles) < 1 { @@ -322,32 +241,30 @@ func LoadConfigFiles(ctx 
context.Context, configFiles []string, workingDir strin // LoadWithContext reads a ConfigDetails and returns a fully loaded configuration as a compose-go Project func LoadWithContext(ctx context.Context, configDetails types.ConfigDetails, options ...func(*Options)) (*types.Project, error) { - opts := ToOptions(&configDetails, options) - dict, err := loadModelWithContext(ctx, &configDetails, opts) + model, err := LoadLazyModel(ctx, configDetails, options...) if err != nil { return nil, err } - return ModelToProject(dict, opts, configDetails) + return model.Resolve() } // LoadModelWithContext reads a ConfigDetails and returns a fully loaded configuration as a yaml dictionary func LoadModelWithContext(ctx context.Context, configDetails types.ConfigDetails, options ...func(*Options)) (map[string]any, error) { - opts := ToOptions(&configDetails, options) - return loadModelWithContext(ctx, &configDetails, opts) -} - -// LoadModelWithContext reads a ConfigDetails and returns a fully loaded configuration as a yaml dictionary -func loadModelWithContext(ctx context.Context, configDetails *types.ConfigDetails, opts *Options) (map[string]any, error) { - if len(configDetails.ConfigFiles) < 1 { - return nil, errors.New("no compose file specified") + project, err := LoadWithContext(ctx, configDetails, options...) 
+ if err != nil { + return nil, err } - - err := projectName(configDetails, opts) + // Marshal the typed project back to a yaml dictionary for backward compatibility + b, err := yaml.Marshal(project) if err != nil { return nil, err } - - return load(ctx, *configDetails, opts, nil) + var dict map[string]any + if err := yaml.Unmarshal(b, &dict); err != nil { + return nil, err + } + dict["name"] = project.Name + return dict, nil } func ToOptions(configDetails *types.ConfigDetails, options []func(*Options)) *Options { @@ -367,251 +284,6 @@ func ToOptions(configDetails *types.ConfigDetails, options []func(*Options)) *Op return opts } -func loadYamlModel(ctx context.Context, config types.ConfigDetails, opts *Options, ct *cycleTracker, included []string) (map[string]interface{}, error) { - var ( - dict = map[string]interface{}{} - err error - ) - workingDir, environment := config.WorkingDir, config.Environment - - for _, file := range config.ConfigFiles { - dict, _, err = loadYamlFile(ctx, file, opts, workingDir, environment, ct, dict, included) - if err != nil { - return nil, err - } - } - - if !opts.SkipDefaultValues { - dict, err = transform.SetDefaultValues(dict) - if err != nil { - return nil, err - } - } - - if !opts.SkipValidation { - if err := validation.Validate(dict); err != nil { - return nil, err - } - } - - if opts.ResolvePaths { - var remotes []paths.RemoteResource - for _, loader := range opts.RemoteResourceLoaders() { - remotes = append(remotes, loader.Accept) - } - err = paths.ResolveRelativePaths(dict, config.WorkingDir, remotes) - if err != nil { - return nil, err - } - } - ResolveEnvironment(dict, config.Environment) - - return dict, nil -} - -func loadYamlFile(ctx context.Context, - file types.ConfigFile, - opts *Options, - workingDir string, - environment types.Mapping, - ct *cycleTracker, - dict map[string]interface{}, - included []string, -) (map[string]interface{}, PostProcessor, error) { - ctx = context.WithValue(ctx, consts.ComposeFileKey{}, 
file.Filename) - if file.Content == nil && file.Config == nil { - content, err := os.ReadFile(file.Filename) - if err != nil { - return nil, nil, err - } - file.Content = content - } - - processRawYaml := func(raw interface{}, processor PostProcessor) error { - converted, err := convertToStringKeysRecursive(raw, "") - if err != nil { - return err - } - cfg, ok := converted.(map[string]interface{}) - if !ok { - return errors.New("top-level object must be a mapping") - } - - if opts.Interpolate != nil && !opts.SkipInterpolation { - cfg, err = interp.Interpolate(cfg, *opts.Interpolate) - if err != nil { - return err - } - } - - fixEmptyNotNull(cfg) - - if !opts.SkipExtends { - err = ApplyExtends(ctx, cfg, opts, ct, processor) - if err != nil { - return err - } - } - - if err := processor.Apply(dict); err != nil { - return err - } - - if !opts.SkipInclude { - included = append(included, file.Filename) - err = ApplyInclude(ctx, workingDir, environment, cfg, opts, included, processor) - if err != nil { - return err - } - } - - dict, err = override.Merge(dict, cfg) - if err != nil { - return err - } - - dict, err = override.EnforceUnicity(dict) - if err != nil { - return err - } - - if !opts.SkipValidation { - if err := schema.Validate(dict); err != nil { - return fmt.Errorf("validating %s: %w", file.Filename, err) - } - if _, ok := dict["version"]; ok { - opts.warnObsoleteVersion(file.Filename) - delete(dict, "version") - } - } - - dict, err = transform.Canonical(dict, opts.SkipInterpolation) - if err != nil { - return err - } - - dict = OmitEmpty(dict) - - // Canonical transformation can reveal duplicates, typically as ports can be a range and conflict with an override - dict, err = override.EnforceUnicity(dict) - return err - } - - var processor PostProcessor - if file.Config == nil { - r := bytes.NewReader(file.Content) - decoder := yaml.NewDecoder(r) - for { - var raw interface{} - reset := &ResetProcessor{target: &raw} - err := decoder.Decode(reset) - if err != nil 
&& errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, nil, fmt.Errorf("failed to parse %s: %w", file.Filename, err) - } - processor = reset - if err := processRawYaml(raw, processor); err != nil { - return nil, nil, err - } - } - } else { - if err := processRawYaml(file.Config, NoopPostProcessor{}); err != nil { - return nil, nil, err - } - } - return dict, processor, nil -} - -func load(ctx context.Context, configDetails types.ConfigDetails, opts *Options, loaded []string) (map[string]interface{}, error) { - mainFile := configDetails.ConfigFiles[0].Filename - for _, f := range loaded { - if f == mainFile { - loaded = append(loaded, mainFile) - return nil, fmt.Errorf("include cycle detected:\n%s\n include %s", loaded[0], strings.Join(loaded[1:], "\n include ")) - } - } - - dict, err := loadYamlModel(ctx, configDetails, opts, &cycleTracker{}, nil) - if err != nil { - return nil, err - } - - if len(dict) == 0 { - return nil, errors.New("empty compose file") - } - - if !opts.SkipValidation && opts.projectName == "" { - return nil, errors.New("project name must not be empty") - } - - if !opts.SkipNormalization { - dict["name"] = opts.projectName - dict, err = Normalize(dict, configDetails.Environment) - if err != nil { - return nil, err - } - } - - return dict, nil -} - -// ModelToProject binds a canonical yaml dict into compose-go structs -func ModelToProject(dict map[string]interface{}, opts *Options, configDetails types.ConfigDetails) (*types.Project, error) { - project := &types.Project{ - Name: opts.projectName, - WorkingDir: configDetails.WorkingDir, - Environment: configDetails.Environment, - } - delete(dict, "name") // project name set by yaml must be identified by caller as opts.projectName - - var err error - dict, err = processExtensions(dict, tree.NewPath(), opts.KnownExtensions) - if err != nil { - return nil, err - } - - err = Transform(dict, project) - if err != nil { - return nil, err - } - - if opts.ConvertWindowsPaths { - for i, 
service := range project.Services { - for j, volume := range service.Volumes { - service.Volumes[j] = convertVolumePath(volume) - } - project.Services[i] = service - } - } - - if project, err = project.WithProfiles(opts.Profiles); err != nil { - return nil, err - } - - if !opts.SkipConsistencyCheck { - err := checkConsistency(project) - if err != nil { - return nil, err - } - } - - if !opts.SkipResolveEnvironment { - project, err = project.WithServicesEnvironmentResolved(opts.discardEnvFiles) - if err != nil { - return nil, err - } - } - - project, err = project.WithServicesLabelsResolved(opts.discardEnvFiles) - if err != nil { - return nil, err - } - - return project, nil -} - func InvalidProjectNameErr(v string) error { return fmt.Errorf( "invalid project name %q: must consist only of lowercase alphanumeric characters, hyphens, and underscores as well as start with a letter or number", @@ -703,186 +375,6 @@ func NormalizeProjectName(s string) string { return strings.TrimLeft(s, "_-") } -var userDefinedKeys = []tree.Path{ - "services", - "services.*.depends_on", - "volumes", - "networks", - "secrets", - "configs", -} - -func processExtensions(dict map[string]any, p tree.Path, extensions map[string]any) (map[string]interface{}, error) { - extras := map[string]any{} - var err error - for key, value := range dict { - skip := false - for _, uk := range userDefinedKeys { - if p.Matches(uk) { - skip = true - break - } - } - if !skip && strings.HasPrefix(key, "x-") { - extras[key] = value - delete(dict, key) - continue - } - switch v := value.(type) { - case map[string]interface{}: - dict[key], err = processExtensions(v, p.Next(key), extensions) - if err != nil { - return nil, err - } - case []interface{}: - for i, e := range v { - if m, ok := e.(map[string]interface{}); ok { - v[i], err = processExtensions(m, p.Next(strconv.Itoa(i)), extensions) - if err != nil { - return nil, err - } - } - } - } - } - for name, val := range extras { - if typ, ok := extensions[name]; ok 
{ - target := reflect.New(reflect.TypeOf(typ)).Elem().Interface() - err = Transform(val, &target) - if err != nil { - return nil, err - } - extras[name] = target - } - } - if len(extras) > 0 { - dict[consts.Extensions] = extras - } - return dict, nil -} - -// Transform converts the source into the target struct with compose types transformer -// and the specified transformers if any. -func Transform(source interface{}, target interface{}) error { - data := mapstructure.Metadata{} - config := &mapstructure.DecoderConfig{ - DecodeHook: mapstructure.ComposeDecodeHookFunc( - nameServices, - decoderHook, - cast, - secretConfigDecoderHook, - ), - Result: target, - TagName: "yaml", - Metadata: &data, - } - decoder, err := mapstructure.NewDecoder(config) - if err != nil { - return err - } - return decoder.Decode(source) -} - -// nameServices create implicit `name` key for convenience accessing service -func nameServices(from reflect.Value, to reflect.Value) (interface{}, error) { - if to.Type() == reflect.TypeOf(types.Services{}) { - nameK := reflect.ValueOf("name") - iter := from.MapRange() - for iter.Next() { - name := iter.Key() - elem := iter.Value() - elem.Elem().SetMapIndex(nameK, name) - } - } - return from.Interface(), nil -} - -func secretConfigDecoderHook(from, to reflect.Type, data interface{}) (interface{}, error) { - // Check if the input is a map and we're decoding into a SecretConfig - if from.Kind() == reflect.Map && to == reflect.TypeOf(types.SecretConfig{}) { - if v, ok := data.(map[string]interface{}); ok { - if ext, ok := v[consts.Extensions].(map[string]interface{}); ok { - if val, ok := ext[types.SecretConfigXValue].(string); ok { - // Return a map with the Content field populated - v["Content"] = val - delete(ext, types.SecretConfigXValue) - - if len(ext) == 0 { - delete(v, consts.Extensions) - } - } - } - } - } - - // Return the original data so the rest is handled by default mapstructure logic - return data, nil -} - -// keys need to be converted 
to strings for jsonschema -func convertToStringKeysRecursive(value interface{}, keyPrefix string) (interface{}, error) { - if mapping, ok := value.(map[string]interface{}); ok { - for key, entry := range mapping { - var newKeyPrefix string - if keyPrefix == "" { - newKeyPrefix = key - } else { - newKeyPrefix = fmt.Sprintf("%s.%s", keyPrefix, key) - } - convertedEntry, err := convertToStringKeysRecursive(entry, newKeyPrefix) - if err != nil { - return nil, err - } - mapping[key] = convertedEntry - } - return mapping, nil - } - if mapping, ok := value.(map[interface{}]interface{}); ok { - dict := make(map[string]interface{}) - for key, entry := range mapping { - str, ok := key.(string) - if !ok { - return nil, formatInvalidKeyError(keyPrefix, key) - } - var newKeyPrefix string - if keyPrefix == "" { - newKeyPrefix = str - } else { - newKeyPrefix = fmt.Sprintf("%s.%s", keyPrefix, str) - } - convertedEntry, err := convertToStringKeysRecursive(entry, newKeyPrefix) - if err != nil { - return nil, err - } - dict[str] = convertedEntry - } - return dict, nil - } - if list, ok := value.([]interface{}); ok { - var convertedList []interface{} - for index, entry := range list { - newKeyPrefix := fmt.Sprintf("%s[%d]", keyPrefix, index) - convertedEntry, err := convertToStringKeysRecursive(entry, newKeyPrefix) - if err != nil { - return nil, err - } - convertedList = append(convertedList, convertedEntry) - } - return convertedList, nil - } - return value, nil -} - -func formatInvalidKeyError(keyPrefix string, key interface{}) error { - var location string - if keyPrefix == "" { - location = "at top level" - } else { - location = fmt.Sprintf("in %s", keyPrefix) - } - return fmt.Errorf("non-string key %s: %#v", location, key) -} - // Windows path, c:\\my\\path\\shiny, need to be changed to be compatible with // the Engine. 
Volume path are expected to be linux style /c/my/path/shiny/ func convertVolumePath(volume types.ServiceVolumeConfig) types.ServiceVolumeConfig { diff --git a/loader/loader_test.go b/loader/loader_test.go index 20fd147f..42fc33aa 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -109,7 +109,7 @@ networks: - subnet: 172.28.0.0/16 ` -var samplePortsConfig = []types.ServicePortConfig{ +var samplePortsConfig = types.ServicePorts{ { Mode: "ingress", Target: 8080, @@ -2281,7 +2281,7 @@ services: capabilities: ["directX"] `) assert.NilError(t, err) - assert.DeepEqual(t, p.Services["test"].Gpus, []types.DeviceRequest{ + assert.DeepEqual(t, p.Services["test"].Gpus, types.GpuDevices{ { Driver: "nvidia", Count: -1, @@ -3019,7 +3019,7 @@ services: customLoader{prefix: "remote"}, } }) - assert.ErrorContains(t, err, "Circular reference") + assert.ErrorContains(t, err, "circular reference") } func TestLoadMulmtiDocumentYaml(t *testing.T) { @@ -3889,10 +3889,10 @@ models: ContextSize: 1024, RuntimeFlags: []string{"--some-flag"}, }) - assert.DeepEqual(t, p.Services["test_array"].Models, map[string]*types.ServiceModelConfig{ + assert.DeepEqual(t, p.Services["test_array"].Models, types.ServiceModels{ "foo": nil, }) - assert.DeepEqual(t, p.Services["test_mapping"].Models, map[string]*types.ServiceModelConfig{ + assert.DeepEqual(t, p.Services["test_mapping"].Models, types.ServiceModels{ "foo": { EndpointVariable: "MODEL_URL", ModelVariable: "MODEL", diff --git a/loader/loader_yaml_test.go b/loader/loader_yaml_test.go deleted file mode 100644 index a4318ac9..00000000 --- a/loader/loader_yaml_test.go +++ /dev/null @@ -1,124 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package loader - -import ( - "context" - "testing" - - "github.com/compose-spec/compose-go/v2/types" - "gotest.tools/v3/assert" -) - -func TestParseYAMLFiles(t *testing.T) { - model, err := loadYamlModel(context.TODO(), types.ConfigDetails{ - ConfigFiles: []types.ConfigFile{ - { - Filename: "test.yaml", - Content: []byte(` -x-extension: - test1: first - -services: - test: - image: foo - command: echo hello - init: true -`), - }, - { - Filename: "override.yaml", - Content: []byte(` -x-extension: - test2: second - -services: - test: - image: bar - command: echo world - init: false -`), - }, - }, - }, &Options{}, &cycleTracker{}, nil) - assert.NilError(t, err) - assert.DeepEqual(t, model, map[string]interface{}{ - "services": map[string]interface{}{ - "test": map[string]interface{}{ - "image": "bar", - "command": "echo world", - "init": false, - }, - }, - "x-extension": map[string]interface{}{ - "test1": "first", - "test2": "second", - }, - }) -} - -func TestParseYAMLFilesMergeOverride(t *testing.T) { - model, err := loadYamlModel(context.TODO(), types.ConfigDetails{ - ConfigFiles: []types.ConfigFile{ - { - Filename: "override.yaml", - Content: []byte(` -services: - base: - configs: - - source: credentials - target: /credentials/file1 - x: &x - extends: - base - configs: !override - - source: credentials - target: /literally-anywhere-else - - y: - <<: *x - -configs: - credentials: - content: | - dummy value -`), - }, - }, - }, &Options{}, &cycleTracker{}, nil) - assert.NilError(t, err) - assert.DeepEqual(t, model, map[string]interface{}{ - "configs": 
map[string]interface{}{"credentials": map[string]interface{}{"content": string("dummy value\n")}}, - "services": map[string]interface{}{ - "base": map[string]interface{}{ - "configs": []interface{}{ - map[string]interface{}{"source": string("credentials"), "target": string("/credentials/file1")}, - }, - }, - "x": map[string]interface{}{ - "configs": []interface{}{ - map[string]interface{}{"source": string("credentials"), "target": string("/literally-anywhere-else")}, - }, - }, - "y": map[string]interface{}{ - "configs": []interface{}{ - map[string]interface{}{"source": string("credentials"), "target": string("/literally-anywhere-else")}, - }, - }, - }, - }) -} diff --git a/loader/mapstructure.go b/loader/mapstructure.go deleted file mode 100644 index e5b902ab..00000000 --- a/loader/mapstructure.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package loader - -import ( - "reflect" - "strconv" -) - -// comparable to yaml.Unmarshaler, decoder allow a type to define it's own custom logic to convert value -// see https://github.com/mitchellh/mapstructure/pull/294 -type decoder interface { - DecodeMapstructure(interface{}) error -} - -// see https://github.com/mitchellh/mapstructure/issues/115#issuecomment-735287466 -// adapted to support types derived from built-in types, as DecodeMapstructure would not be able to mutate internal -// value, so need to invoke DecodeMapstructure defined by pointer to type -func decoderHook(from reflect.Value, to reflect.Value) (interface{}, error) { - // If the destination implements the decoder interface - u, ok := to.Interface().(decoder) - if !ok { - // for non-struct types we need to invoke func (*type) DecodeMapstructure() - if to.CanAddr() { - pto := to.Addr() - u, ok = pto.Interface().(decoder) - } - if !ok { - return from.Interface(), nil - } - } - // If it is nil and a pointer, create and assign the target value first - if to.Type().Kind() == reflect.Ptr && to.IsNil() { - to.Set(reflect.New(to.Type().Elem())) - u = to.Interface().(decoder) - } - // Call the custom DecodeMapstructure method - if err := u.DecodeMapstructure(from.Interface()); err != nil { - return to.Interface(), err - } - return to.Interface(), nil -} - -func cast(from reflect.Value, to reflect.Value) (interface{}, error) { - switch from.Type().Kind() { - case reflect.String: - switch to.Kind() { - case reflect.Bool: - return toBoolean(from.String()) - case reflect.Int: - return toInt(from.String()) - case reflect.Int64: - return toInt64(from.String()) - case reflect.Float32: - return toFloat32(from.String()) - case reflect.Float64: - return toFloat(from.String()) - } - case reflect.Int: - if to.Kind() == reflect.String { - return strconv.FormatInt(from.Int(), 10), nil - } - } - return from.Interface(), nil -} diff --git a/loader/mapstructure_test.go b/loader/mapstructure_test.go deleted file 
mode 100644 index 4638ae08..00000000 --- a/loader/mapstructure_test.go +++ /dev/null @@ -1,65 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package loader - -import ( - "testing" - - "github.com/compose-spec/compose-go/v2/types" - "github.com/go-viper/mapstructure/v2" - "gotest.tools/v3/assert" -) - -func TestDecodeMapStructure(t *testing.T) { - var target types.ServiceConfig - data := mapstructure.Metadata{} - config := &mapstructure.DecoderConfig{ - Result: &target, - TagName: "yaml", - Metadata: &data, - DecodeHook: mapstructure.ComposeDecodeHookFunc(decoderHook), - } - decoder, err := mapstructure.NewDecoder(config) - assert.NilError(t, err) - err = decoder.Decode(map[string]interface{}{ - "mem_limit": "640k", - "command": "echo hello", - "stop_grace_period": "60s", - "labels": []interface{}{ - "FOO=BAR", - }, - "deploy": map[string]interface{}{ - "labels": map[string]interface{}{ - "FOO": "BAR", - "BAZ": nil, - "QIX": 2, - "ZOT": true, - }, - }, - }) - assert.NilError(t, err) - assert.Equal(t, target.MemLimit, types.UnitBytes(640*1024)) - assert.DeepEqual(t, target.Command, types.ShellCommand{"echo", "hello"}) - assert.Equal(t, *target.StopGracePeriod, types.Duration(60_000_000_000)) - assert.DeepEqual(t, target.Labels, types.Labels{"FOO": "BAR"}) - assert.DeepEqual(t, target.Deploy.Labels, types.Labels{ - "FOO": "BAR", - "BAZ": "", - "QIX": "2", - "ZOT": "true", - }) -} diff --git 
a/loader/normalize_project.go b/loader/normalize_project.go new file mode 100644 index 00000000..c9f8c4e1 --- /dev/null +++ b/loader/normalize_project.go @@ -0,0 +1,371 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package loader + +import ( + "fmt" + "path" + "path/filepath" + "strings" + + "github.com/compose-spec/compose-go/v2/paths" + "github.com/compose-spec/compose-go/v2/types" +) + +// normalizeProject applies post-decode normalization on a typed Project. +// This replaces the old Normalize() that operated on map[string]any. +func normalizeProject(project *types.Project) { + // 1. Set service names from map keys + for name, svc := range project.Services { + svc.Name = name + project.Services[name] = svc + } + + // 2. Normalize networks (inject implicit default network) + normalizeProjectNetworks(project) + + // 3. Set resource names from map keys with project name prefix + setResourceNames(project) + + for name, svc := range project.Services { + // 4. Normalize pull_policy + if svc.PullPolicy == types.PullPolicyIfNotPresent { + svc.PullPolicy = types.PullPolicyMissing + } + + // 5. Build defaults + if svc.Build != nil { + if svc.Build.Context == "" { + svc.Build.Context = "." + } + if svc.Build.Dockerfile == "" && svc.Build.DockerfileInline == "" { + svc.Build.Dockerfile = "Dockerfile" + } + } + + // 6. Infer depends_on from links, network_mode, volumes_from + inferDependsOn(&svc) + + // 7. 
Clean volume paths + for i, vol := range svc.Volumes { + svc.Volumes[i].Target = path.Clean(vol.Target) + if vol.Source != "" { + if path.IsAbs(vol.Source) { + // Preserve Unix-style absolute paths (e.g. /opt/data) + // which filepath.Clean would convert to \opt\data on Windows + svc.Volumes[i].Source = path.Clean(vol.Source) + } else { + svc.Volumes[i].Source = filepath.Clean(vol.Source) + } + } + } + + // 8. Default values for ports + for i, p := range svc.Ports { + if p.Protocol == "" { + svc.Ports[i].Protocol = "tcp" + } + if p.Mode == "" { + svc.Ports[i].Mode = "ingress" + } + } + + // 9. Default values for secrets/configs mounts + for i, s := range svc.Secrets { + if s.Target == "" { + svc.Secrets[i].Target = fmt.Sprintf("/run/secrets/%s", s.Source) + } + } + + // 10. Default values for volume bind + for i, vol := range svc.Volumes { + if vol.Type == types.VolumeTypeBind && vol.Bind != nil && !vol.Bind.CreateHostPath { + // The old pipeline sets create_host_path=true by default. + // In the new pipeline, the zero value is false. We set it to true + // only when Bind section exists but create_host_path was not explicitly set. + svc.Volumes[i].Bind.CreateHostPath = true + } + } + + // 11. Default values for device requests + if svc.Deploy != nil && svc.Deploy.Resources.Reservations != nil { + for i, dev := range svc.Deploy.Resources.Reservations.Devices { + if dev.Count == 0 && len(dev.IDs) == 0 { + all := types.DeviceCount(-1) // "all" + svc.Deploy.Resources.Reservations.Devices[i].Count = all + } + } + } + for i, gpu := range svc.Gpus { + if gpu.Count == 0 && len(gpu.IDs) == 0 { + all := types.DeviceCount(-1) + svc.Gpus[i].Count = all + } + } + + project.Services[name] = svc + } + + // 12. 
Resolve secrets/configs environment references + resolveSecretConfigEnvironment(project) +} + +func normalizeProjectNetworks(project *types.Project) { + usesDefaultNetwork := false + + for name, svc := range project.Services { + if svc.Provider != nil { + continue + } + if svc.NetworkMode != "" { + continue + } + if len(svc.Networks) == 0 { + svc.Networks = types.ServiceNetworks{ + "default": nil, + } + usesDefaultNetwork = true + } else if _, ok := svc.Networks["default"]; ok { + usesDefaultNetwork = true + } + project.Services[name] = svc + } + + if usesDefaultNetwork { + if project.Networks == nil { + project.Networks = types.Networks{} + } + if _, ok := project.Networks["default"]; !ok { + project.Networks["default"] = types.NetworkConfig{} + } + } +} + +func setResourceNames(project *types.Project) { + setNames := func(name string, externalName string, external types.External, projectName string) string { + if name != "" { + return name + } + if bool(external) { + return externalName + } + return fmt.Sprintf("%s_%s", projectName, externalName) + } + + for key, net := range project.Networks { + net.Name = setNames(net.Name, key, net.External, project.Name) + project.Networks[key] = net + } + for key, vol := range project.Volumes { + vol.Name = setNames(vol.Name, key, vol.External, project.Name) + project.Volumes[key] = vol + } + for key, cfg := range project.Configs { + cfg.Name = setNames(cfg.Name, key, cfg.External, project.Name) + project.Configs[key] = cfg + } + for key, sec := range project.Secrets { + sec.Name = setNames(sec.Name, key, sec.External, project.Name) + project.Secrets[key] = sec + } +} + +func inferDependsOn(svc *types.ServiceConfig) { + if svc.DependsOn == nil { + svc.DependsOn = types.DependsOnConfig{} + } + + addDep := func(name string, restart bool) { + if _, ok := svc.DependsOn[name]; !ok { + svc.DependsOn[name] = types.ServiceDependency{ + Condition: types.ServiceConditionStarted, + Restart: restart, + Required: true, + } + } + } + + // 
From links + for _, link := range svc.Links { + parts := strings.Split(link, ":") + addDep(parts[0], true) + } + + // From namespace references (network_mode, ipc, pid, uts, cgroup) + for _, ref := range []string{svc.NetworkMode, svc.Ipc, svc.Pid, svc.Uts, svc.Cgroup} { + if strings.HasPrefix(ref, types.ServicePrefix) { + addDep(ref[len(types.ServicePrefix):], true) + } + } + + // From volumes_from + for _, vol := range svc.VolumesFrom { + if !strings.HasPrefix(vol, types.ContainerPrefix) { + spec := strings.Split(vol, ":") + addDep(spec[0], false) + } + } + + // Remove empty depends_on to match old behavior + if len(svc.DependsOn) == 0 { + svc.DependsOn = nil + } +} + +func resolveSecretConfigEnvironment(project *types.Project) { + for name, secret := range project.Secrets { + if secret.Environment != "" { + if val, ok := project.Environment[secret.Environment]; ok { + secret.Content = val + } + project.Secrets[name] = secret + } + } + for name, config := range project.Configs { + if config.Environment != "" { + if val, ok := project.Environment[config.Environment]; ok { + config.Content = val + } + project.Configs[name] = config + } + } +} + +// isRemoteContext checks if a build context value is a remote reference (Git, HTTP, etc.) +func isRemoteContext(v string) bool { + for _, prefix := range []string{"https://", "http://", "git://", "ssh://", "github.com/", "git@"} { + if strings.HasPrefix(v, prefix) { + return true + } + } + return false +} + +// resolveProjectPaths resolves relative paths in a typed Project. +// This replaces the old paths.ResolveRelativePaths that operated on map[string]any. 
+func resolveProjectPaths(project *types.Project, opts *Options) error { //nolint:gocyclo + workDir := project.WorkingDir + + var remoteCheck []paths.RemoteResource + for _, loader := range opts.RemoteResourceLoaders() { + remoteCheck = append(remoteCheck, loader.Accept) + } + isRemote := func(p string) bool { + for _, check := range remoteCheck { + if check(p) { + return true + } + } + return false + } + + absPath := func(p string) string { + p = paths.ExpandUser(p) + if filepath.IsAbs(p) || path.IsAbs(p) || p == "" { + return p + } + return filepath.Join(workDir, p) + } + + for name, svc := range project.Services { + // Build context + if svc.Build != nil { + ctx := svc.Build.Context + if ctx != "" && !strings.Contains(ctx, "://") && + !strings.HasPrefix(ctx, types.ServicePrefix) && !isRemote(ctx) { + svc.Build.Context = absPath(ctx) + } + for k, v := range svc.Build.AdditionalContexts { + if !strings.Contains(v, "://") && !isRemote(v) && !isRemoteContext(v) { + svc.Build.AdditionalContexts[k] = absPath(v) + } + } + for i, key := range svc.Build.SSH { + if key.Path != "" { + svc.Build.SSH[i].Path = absPath(key.Path) + } + } + } + + // Env files + for i, ef := range svc.EnvFiles { + svc.EnvFiles[i].Path = absPath(ef.Path) + } + + // Label files + for i, lf := range svc.LabelFiles { + svc.LabelFiles[i] = absPath(lf) + } + + // Volumes (bind mounts) + for i, vol := range svc.Volumes { + if vol.Type == types.VolumeTypeBind { + if vol.Source == "" { + return fmt.Errorf(`invalid mount config for type "bind": field Source must not be empty`) + } + src := paths.ExpandUser(vol.Source) + if !filepath.IsAbs(src) && !path.IsAbs(src) && !paths.IsWindowsAbs(src) { + svc.Volumes[i].Source = filepath.Join(workDir, src) + } else { + svc.Volumes[i].Source = src + } + } + } + + // Extends file + if svc.Extends != nil && svc.Extends.File != "" && !isRemote(svc.Extends.File) { + svc.Extends.File = absPath(svc.Extends.File) + } + + // Develop watch paths + if svc.Develop != nil { + 
for i, w := range svc.Develop.Watch { + svc.Develop.Watch[i].Path = absPath(w.Path) + } + } + + project.Services[name] = svc + } + + // Configs + for name, cfg := range project.Configs { + if cfg.File != "" { + cfg.File = absPath(cfg.File) + project.Configs[name] = cfg + } + } + + // Secrets + for name, sec := range project.Secrets { + if sec.File != "" { + sec.File = absPath(sec.File) + project.Secrets[name] = sec + } + } + + // Volumes with local driver + bind + for name, vol := range project.Volumes { + if vol.Driver == "local" && vol.DriverOpts != nil { + if dev, ok := vol.DriverOpts["device"]; ok && vol.DriverOpts["o"] == "bind" { + vol.DriverOpts["device"] = absPath(dev) + project.Volumes[name] = vol + } + } + } + + return nil +} diff --git a/loader/omitEmpty.go b/loader/omitEmpty.go deleted file mode 100644 index eef6be8c..00000000 --- a/loader/omitEmpty.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - Copyright 2020 The Compose Specification Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package loader - -import "github.com/compose-spec/compose-go/v2/tree" - -var omitempty = []tree.Path{ - "services.*.dns", -} - -// OmitEmpty removes empty attributes which are irrelevant when unset -func OmitEmpty(yaml map[string]any) map[string]any { - cleaned := omitEmpty(yaml, tree.NewPath()) - return cleaned.(map[string]any) -} - -func omitEmpty(data any, p tree.Path) any { - switch v := data.(type) { - case map[string]any: - for k, e := range v { - if isEmpty(e) && mustOmit(p) { - delete(v, k) - continue - } - - v[k] = omitEmpty(e, p.Next(k)) - } - return v - case []any: - var c []any - for _, e := range v { - if isEmpty(e) && mustOmit(p) { - continue - } - - c = append(c, omitEmpty(e, p.Next("[]"))) - } - return c - default: - return data - } -} - -func mustOmit(p tree.Path) bool { - for _, pattern := range omitempty { - if p.Matches(pattern) { - return true - } - } - return false -} - -func isEmpty(e any) bool { - if e == nil { - return true - } - if v, ok := e.(string); ok && v == "" { - return true - } - return false -} diff --git a/loader/override_test.go b/loader/override_test.go index 7a4b62e3..4fb7a6ac 100644 --- a/loader/override_test.go +++ b/loader/override_test.go @@ -270,7 +270,7 @@ networks: Bind: &types.ServiceVolumeBind{CreateHostPath: true}, }, }) - assert.DeepEqual(t, test.Networks, map[string]*types.ServiceNetworkConfig{ + assert.DeepEqual(t, test.Networks, types.ServiceNetworks{ "zot": nil, }) } diff --git a/loader/reset_test.go b/loader/reset_test.go index 879b1288..8e138ae7 100644 --- a/loader/reset_test.go +++ b/loader/reset_test.go @@ -163,7 +163,7 @@ x-healthcheck: &healthcheck <<: *healthcheck `, expectError: true, - errorMsg: "cycle detected: node at path x-healthcheck.egress-service.egress-service references node at path x-healthcheck.egress-service", + errorMsg: "anchor 'healthcheck' value contains itself", }, } diff --git a/override/merge_node.go b/override/merge_node.go new file mode 100644 index 00000000..1f54cb1d --- 
/dev/null +++ b/override/merge_node.go @@ -0,0 +1,654 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package override + +import ( + "fmt" + "slices" + "strings" + + "github.com/compose-spec/compose-go/v2/format" + "github.com/compose-spec/compose-go/v2/tree" + "go.yaml.in/yaml/v4" +) + +// KeyValue is a key-value pair for building yaml.Node mappings. +type KeyValue struct { + Key string + Value *yaml.Node +} + +// FindKey finds a key in a MappingNode, returns (key node, value node) or (nil, nil). +func FindKey(node *yaml.Node, key string) (*yaml.Node, *yaml.Node) { + if node == nil || node.Kind != yaml.MappingNode { + return nil, nil + } + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value == key { + return node.Content[i], node.Content[i+1] + } + } + return nil, nil +} + +// SetKey adds or replaces a key in a MappingNode. +func SetKey(node *yaml.Node, key string, value *yaml.Node) { + if node == nil || node.Kind != yaml.MappingNode { + return + } + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value == key { + node.Content[i+1] = value + return + } + } + node.Content = append(node.Content, NewScalar(key), value) +} + +// DeleteKey removes a key from a MappingNode. 
+func DeleteKey(node *yaml.Node, key string) { + if node == nil || node.Kind != yaml.MappingNode { + return + } + content := make([]*yaml.Node, 0, len(node.Content)) + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value != key { + content = append(content, node.Content[i], node.Content[i+1]) + } + } + node.Content = content +} + +// NewScalar creates a new ScalarNode with the given value. +func NewScalar(value string) *yaml.Node { + return &yaml.Node{Kind: yaml.ScalarNode, Value: value, Tag: "!!str"} +} + +// NewMapping creates a new MappingNode from key-value pairs. +func NewMapping(pairs ...KeyValue) *yaml.Node { + n := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + for _, p := range pairs { + n.Content = append(n.Content, NewScalar(p.Key), p.Value) + } + return n +} + +// NewSequence creates a new SequenceNode from items. +func NewSequence(items ...*yaml.Node) *yaml.Node { + return &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq", Content: items} +} + +type nodeMerger func(*yaml.Node, *yaml.Node, tree.Path) (*yaml.Node, error) + +var mergeNodeSpecials map[tree.Path]nodeMerger + +func init() { + mergeNodeSpecials = map[tree.Path]nodeMerger{ + "networks.*.ipam.config": mergeIPAMConfigNode, + "networks.*.labels": mergeToSequenceNode, + "volumes.*.labels": mergeToSequenceNode, + "services.*.annotations": mergeToSequenceNode, + "services.*.build": mergeBuildNode, + "services.*.build.args": mergeToSequenceNode, + "services.*.build.additional_contexts": mergeToSequenceNode, + "services.*.build.extra_hosts": mergeExtraHostsNode, + "services.*.build.labels": mergeToSequenceNode, + "services.*.command": overrideNode, + "services.*.depends_on": mergeDependsOnNode, + "services.*.deploy.labels": mergeToSequenceNode, + "services.*.dns": mergeToSequenceNode, + "services.*.dns_opt": mergeToSequenceNode, + "services.*.dns_search": mergeToSequenceNode, + "services.*.entrypoint": overrideNode, + "services.*.env_file": mergeToSequenceNode, + 
"services.*.label_file": mergeToSequenceNode, + "services.*.environment": mergeToSequenceNode, + "services.*.extra_hosts": mergeExtraHostsNode, + "services.*.healthcheck.test": overrideNode, + "services.*.labels": mergeToSequenceNode, + "services.*.volumes.*.volume.labels": mergeToSequenceNode, + "services.*.logging": mergeLoggingNode, + "services.*.models": mergeModelsNode, + "services.*.networks": mergeNetworksNode, + "services.*.sysctls": mergeToSequenceNode, + "services.*.tmpfs": mergeToSequenceNode, + "services.*.ulimits.*": mergeUlimitNode, + } +} + +// MergeNodes merges two yaml.Node trees following the same rules as MergeYaml. +func MergeNodes(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + // Handle !reset: delete the key from the merged result + if override != nil && override.Tag == "!reset" { + return nil, nil + } + // Handle !override: skip merge, use override as-is + if override != nil && override.Tag == "!override" { + override.Tag = "" + return override, nil + } + + for pattern, merger := range mergeNodeSpecials { + if path.Matches(pattern) { + return merger(base, override, path) + } + } + if override == nil { + return base, nil + } + if base == nil { + return override, nil + } + // Treat !!null base as absent + if base.Tag == "!!null" { + return override, nil + } + switch { + case base.Kind == yaml.MappingNode && override.Kind == yaml.MappingNode: + return mergeNodeMappings(base, override, path) + case base.Kind == yaml.SequenceNode && override.Kind == yaml.SequenceNode: + result := &yaml.Node{ + Kind: yaml.SequenceNode, + Tag: base.Tag, + } + result.Content = append(result.Content, base.Content...) + result.Content = append(result.Content, override.Content...) 
+ return result, nil + case base.Kind == yaml.MappingNode || override.Kind == yaml.MappingNode: + return nil, fmt.Errorf("cannot override %s", path) + default: + // scalar: override wins + return override, nil + } +} + +func mergeNodeMappings(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + result := &yaml.Node{ + Kind: yaml.MappingNode, + Tag: base.Tag, + } + // Copy all base pairs + result.Content = append(result.Content, base.Content...) + + // Merge override pairs + for i := 0; i+1 < len(override.Content); i += 2 { + key := override.Content[i].Value + val := override.Content[i+1] + + // !reset on a non-existing key: nothing to do + if val.Tag == "!reset" { + DeleteKey(result, key) + continue + } + // !override on a new or existing key: use as-is + if val.Tag == "!override" { + val.Tag = "" + SetKey(result, key, val) + continue + } + + _, existing := FindKey(result, key) + if existing == nil { + result.Content = append(result.Content, override.Content[i], val) + continue + } + next := path.Next(key) + merged, err := MergeNodes(existing, val, next) + if err != nil { + return nil, err + } + if merged == nil { + // !reset: remove the key from result + DeleteKey(result, key) + continue + } + SetKey(result, key, merged) + } + return result, nil +} + +func overrideNode(_, override *yaml.Node, _ tree.Path) (*yaml.Node, error) { + return override, nil +} + +// convertNodeToSequence converts a MappingNode into a SequenceNode of "key=value" scalars, +// mirroring convertIntoSequence from merge.go. 
+func convertNodeToSequence(node *yaml.Node) *yaml.Node { + if node == nil { + return NewSequence() + } + switch node.Kind { + case yaml.MappingNode: + var entries []string + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + if v.Tag == "!reset" { + continue // skip reset entries + } + switch { + case v.Tag == "!!null" || (v.Kind == yaml.ScalarNode && v.Value == "") && v.Tag == "": + entries = append(entries, k) + case v.Kind == yaml.SequenceNode: + for _, item := range v.Content { + entries = append(entries, fmt.Sprintf("%s=%s", k, item.Value)) + } + default: + entries = append(entries, fmt.Sprintf("%s=%s", k, v.Value)) + } + } + slices.Sort(entries) + items := make([]*yaml.Node, len(entries)) + for i, e := range entries { + items[i] = NewScalar(e) + } + return NewSequence(items...) + case yaml.SequenceNode: + return node + case yaml.ScalarNode: + return NewSequence(node) + } + return NewSequence() +} + +func mergeToSequenceNode(base, override *yaml.Node, _ tree.Path) (*yaml.Node, error) { + // Handle !reset items: collect reset keys from override mapping and filter them + resetKeys := map[string]bool{} + if override != nil && override.Kind == yaml.MappingNode { + for i := 0; i+1 < len(override.Content); i += 2 { + if override.Content[i+1].Tag == "!reset" { + resetKeys[override.Content[i].Value] = true + } + } + } + + right := convertNodeToSequence(base) + left := convertNodeToSequence(override) + + result := NewSequence() + // Add base items, filtering out reset keys + for _, item := range right.Content { + key := strings.SplitN(item.Value, "=", 2)[0] + if !resetKeys[key] { + result.Content = append(result.Content, item) + } + } + // Add override items, filtering out reset entries + for _, item := range left.Content { + key := strings.SplitN(item.Value, "=", 2)[0] + if !resetKeys[key] { + result.Content = append(result.Content, item) + } + } + return result, nil +} + +func mergeExtraHostsNode(base, override 
*yaml.Node, _ tree.Path) (*yaml.Node, error) { + right := convertNodeToSequence(base) + left := convertNodeToSequence(override) + + // Deduplicate: remove from left any entries already in right + rightValues := make(map[string]bool, len(right.Content)) + for _, n := range right.Content { + rightValues[n.Value] = true + } + var deduped []*yaml.Node + for _, n := range left.Content { + if !rightValues[n.Value] { + deduped = append(deduped, n) + } + } + result := NewSequence() + result.Content = append(result.Content, right.Content...) + result.Content = append(result.Content, deduped...) + return result, nil +} + +func mergeBuildNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + toBuildMapping := func(n *yaml.Node) *yaml.Node { + if n == nil { + return NewMapping() + } + if n.Kind == yaml.ScalarNode { + return NewMapping(KeyValue{Key: "context", Value: n}) + } + return n + } + return mergeNodeMappings(toBuildMapping(base), toBuildMapping(override), path) +} + +func mergeDependsOnNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + defaultVal := func() *yaml.Node { + return NewMapping( + KeyValue{Key: "condition", Value: NewScalar("service_started")}, + KeyValue{Key: "required", Value: &yaml.Node{Kind: yaml.ScalarNode, Value: "true", Tag: "!!bool"}}, + ) + } + right := convertNodeToMapping(base, defaultVal) + left := convertNodeToMapping(override, defaultVal) + return mergeNodeMappings(right, left, path) +} + +func mergeNetworksNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + right := convertNodeToMapping(base, nil) + left := convertNodeToMapping(override, nil) + return mergeNodeMappings(right, left, path) +} + +func mergeModelsNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) { + right := convertNodeToMapping(base, nil) + left := convertNodeToMapping(override, nil) + return mergeNodeMappings(right, left, path) +} + +// convertNodeToMapping converts a SequenceNode to a MappingNode. 
// If defaultValue is non-nil, each key gets a copy of the default; otherwise keys map to null.
// A node that is already a MappingNode is returned unchanged; nil or any
// other kind yields an empty mapping.
func convertNodeToMapping(node *yaml.Node, defaultValue func() *yaml.Node) *yaml.Node {
	if node == nil {
		return NewMapping()
	}
	switch node.Kind {
	case yaml.MappingNode:
		return node
	case yaml.SequenceNode:
		result := NewMapping()
		for _, item := range node.Content {
			var val *yaml.Node
			if defaultValue != nil {
				// Fresh copy per key so later mutation of one entry cannot alias another.
				val = defaultValue()
			} else {
				val = &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!null"}
			}
			result.Content = append(result.Content, NewScalar(item.Value), val)
		}
		return result
	}
	return NewMapping()
}

// mergeLoggingNode merges two logging sections. If either side omits the
// driver, or both declare the same driver, the mappings (driver + options)
// are merged; when the drivers differ the override replaces the base
// entirely, since options for different drivers are not compatible.
func mergeLoggingNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) {
	if base == nil || base.Kind != yaml.MappingNode {
		return override, nil
	}
	if override == nil || override.Kind != yaml.MappingNode {
		return base, nil
	}
	_, baseDriver := FindKey(base, "driver")
	_, overDriver := FindKey(override, "driver")

	bothSet := baseDriver != nil && overDriver != nil
	sameDriver := bothSet && baseDriver.Value == overDriver.Value

	if !bothSet || sameDriver {
		return mergeNodeMappings(base, override, path)
	}
	return override, nil
}

// mergeUlimitNode merges two ulimit entries: mappings (soft/hard) are merged
// key by key; in any other combination (e.g. the scalar single-value form)
// the override wins.
func mergeUlimitNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) {
	if base != nil && base.Kind == yaml.MappingNode && override != nil && override.Kind == yaml.MappingNode {
		return mergeNodeMappings(base, override, path)
	}
	return override, nil
}

// mergeIPAMConfigNode merges two ipam.config sequences, pairing entries by
// their "subnet" value: entries with the same subnet are deep-merged, and
// override entries with a new subnet are appended.
//
// NOTE(review): a base entry is only carried over via the merge branch when
// some override entry shares its subnet, and the inner loop never runs when
// the override sequence is empty — so base-only entries appear to be
// dropped. This mirrors the structure of the map-based merge this function
// is meant to stay in parity with; confirm against mergeIPAMConfig before
// changing.
func mergeIPAMConfigNode(base, override *yaml.Node, path tree.Path) (*yaml.Node, error) {
	if base == nil || base.Kind != yaml.SequenceNode {
		return override, fmt.Errorf("%s: unexpected node kind", path)
	}
	if override == nil || override.Kind != yaml.SequenceNode {
		return override, fmt.Errorf("%s: unexpected node kind", path)
	}

	var ipamConfigs []*yaml.Node

	for _, original := range base.Content {
		right := convertNodeToMapping(original, nil)
		for _, over := range override.Content {
			left := convertNodeToMapping(over, nil)
			_, rightSubnet := FindKey(right, "subnet")
			_, leftSubnet := FindKey(left, "subnet")

			// A missing subnet key is treated as the empty subnet "".
			rightVal := ""
			if rightSubnet != nil {
				rightVal = rightSubnet.Value
			}
			leftVal := ""
			if leftSubnet != nil {
				leftVal = leftSubnet.Value
			}

			if leftVal != rightVal {
				// Different subnet: add the override entry if not already present.
				if !slices.ContainsFunc(ipamConfigs, func(n *yaml.Node) bool {
					_, s := FindKey(n, "subnet")
					return s != nil && s.Value == leftVal
				}) {
					ipamConfigs = append(ipamConfigs, left)
				}
				continue
			}
			// Same subnet: deep-merge and replace any earlier result for it.
			merged, err := mergeNodeMappings(right, left, path)
			if err != nil {
				return nil, err
			}
			_, mergedSubnet := FindKey(merged, "subnet")
			mergedVal := ""
			if mergedSubnet != nil {
				mergedVal = mergedSubnet.Value
			}
			idx := slices.IndexFunc(ipamConfigs, func(n *yaml.Node) bool {
				_, s := FindKey(n, "subnet")
				return s != nil && s.Value == mergedVal
			})
			if idx >= 0 {
				ipamConfigs[idx] = merged
			} else {
				ipamConfigs = append(ipamConfigs, merged)
			}
		}
	}
	return NewSequence(ipamConfigs...), nil
}

// ExtendServiceNode merges a base service node with an override service node,
// using the same merge rules as ExtendService but operating on yaml.Node trees.
// The synthetic "services.x" path makes the service-scoped merge rules apply.
func ExtendServiceNode(base, override *yaml.Node) (*yaml.Node, error) {
	return MergeNodes(base, override, tree.NewPath("services.x"))
}

// nodeIndexer extracts a deduplication key from a yaml.Node sequence element.
// An empty string returned by a nodeIndexer means "no key": the element is
// kept as-is and never considered a duplicate.
type nodeIndexer func(*yaml.Node) string

// unicityNodePatterns maps path patterns to the indexer that identifies
// duplicate elements at that path. Populated once in init below.
var unicityNodePatterns map[tree.Path]nodeIndexer

func init() {
	// kv: key is the part before "=" for "key=value" scalars, else the
	// whole scalar value.
	kv := func(n *yaml.Node) string {
		if n.Kind == yaml.ScalarNode {
			key, _, found := strings.Cut(n.Value, "=")
			if found {
				return key
			}
			return n.Value
		}
		return ""
	}
	// target: the scalar value, or the "target" field of the long (mapping)
	// form — used for configs and secrets.
	target := func(n *yaml.Node) string {
		if n.Kind == yaml.ScalarNode {
			return n.Value
		}
		if n.Kind == yaml.MappingNode {
			_, t := FindKey(n, "target")
			if t != nil {
				return t.Value
			}
		}
		return ""
	}
	// volume: the parsed mount target of the short scalar syntax (falling
	// back to the raw value on parse error), or the "target" field of the
	// long form.
	volume := func(n *yaml.Node) string {
		if n.Kind == yaml.ScalarNode {
			v, err := format.ParseVolume(n.Value)
			if err != nil {
				return n.Value
			}
			return v.Target
		}
		if n.Kind == yaml.MappingNode {
			_, t := FindKey(n, "target")
			if t != nil {
				return t.Value
			}
		}
		return ""
	}
	// port: the raw scalar, or the long form's identifying fields joined
	// with ":" (absent fields are skipped).
	port := func(n *yaml.Node) string {
		if n.Kind == yaml.ScalarNode {
			return n.Value
		}
		if n.Kind == yaml.MappingNode {
			parts := []string{}
			for _, key := range []string{"host_ip", "published", "target", "protocol"} {
				_, v := FindKey(n, key)
				if v != nil {
					parts = append(parts, v.Value)
				}
			}
			return strings.Join(parts, ":")
		}
		return ""
	}
	// envFile: the scalar path, or the "path" field of the long form.
	envFile := func(n *yaml.Node) string {
		if n.Kind == yaml.ScalarNode {
			return n.Value
		}
		if n.Kind == yaml.MappingNode {
			_, p := FindKey(n, "path")
			if p != nil {
				return p.Value
			}
		}
		return ""
	}

	unicityNodePatterns = map[tree.Path]nodeIndexer{
		"networks.*.labels":                    kv,
		"services.*.annotations":               kv,
		"services.*.build.args":                kv,
		"services.*.build.additional_contexts": kv,
		"services.*.build.labels":              kv,
		"services.*.build.tags":                kv,
		"services.*.cap_add":                   kv,
		"services.*.cap_drop":                  kv,
		"services.*.configs":                   target,
		"services.*.deploy.labels":             kv,
		"services.*.dns":                       kv,
		"services.*.dns_opt":                   kv,
		"services.*.dns_search":                kv,
		"services.*.environment":               kv,
		"services.*.env_file":                  envFile,
		"services.*.expose":                    kv,
		"services.*.labels":                    kv,
		"services.*.links":                     kv,
		"services.*.networks.*.aliases":        kv,
		"services.*.networks.*.link_local_ips": kv,
		"services.*.ports":                     port,
		"services.*.profiles":                  kv,
		"services.*.secrets":                   target,
		"services.*.sysctls":                   kv,
		"services.*.tmpfs":                     kv,
		"services.*.volumes":                   volume,
	}
}

// EnforceUnicityNode removes duplicate elements in sequences following the
// same rules as EnforceUnicity but operating on yaml.Node trees.
// For a duplicated key the LAST occurrence wins but it keeps the position of
// the first; elements whose indexer returns "" are always kept. Sequences
// that match no pattern are recursed into with a "[]" path segment.
func EnforceUnicityNode(node *yaml.Node, path tree.Path) {
	if node == nil {
		return
	}
	switch node.Kind {
	case yaml.DocumentNode:
		for _, child := range node.Content {
			EnforceUnicityNode(child, path)
		}
	case yaml.MappingNode:
		// Content holds alternating key/value nodes; recurse into values.
		for i := 0; i+1 < len(node.Content); i += 2 {
			EnforceUnicityNode(node.Content[i+1], path.Next(node.Content[i].Value))
		}
	case yaml.SequenceNode:
		for pattern, indexer := range unicityNodePatterns {
			if !path.Matches(pattern) {
				continue
			}
			seen := map[string]int{}
			var result []*yaml.Node
			for _, item := range node.Content {
				key := indexer(item)
				if key == "" {
					// No identity for this element: keep unconditionally.
					result = append(result, item)
					continue
				}
				if j, ok := seen[key]; ok {
					// Duplicate: replace the earlier element in place.
					result[j] = item
				} else {
					result = append(result, item)
					seen[key] = len(result) - 1
				}
			}
			node.Content = result
			return
		}
		for _, item := range node.Content {
			EnforceUnicityNode(item, path.Next("[]"))
		}
	}
}

// StripResetTags removes any remaining !reset tagged nodes from a tree.
// This is called after all merging is complete to clean up unprocessed reset markers.
+func StripResetTags(node *yaml.Node) { + if node == nil { + return + } + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + StripResetTags(child) + } + case yaml.MappingNode: + var content []*yaml.Node + for i := 0; i+1 < len(node.Content); i += 2 { + val := node.Content[i+1] + if val.Tag == "!reset" { + continue + } + StripResetTags(val) + content = append(content, node.Content[i], val) + } + node.Content = content + case yaml.SequenceNode: + var content []*yaml.Node + for _, item := range node.Content { + if item.Tag == "!reset" { + continue + } + StripResetTags(item) + content = append(content, item) + } + node.Content = content + } +} diff --git a/override/merge_node_test.go b/override/merge_node_test.go new file mode 100644 index 00000000..c1443466 --- /dev/null +++ b/override/merge_node_test.go @@ -0,0 +1,279 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package override + +import ( + "encoding/json" + "testing" + + "github.com/compose-spec/compose-go/v2/tree" + "go.yaml.in/yaml/v4" + "gotest.tools/v3/assert" +) + +func TestFindKey(t *testing.T) { + node := NewMapping( + KeyValue{Key: "image", Value: NewScalar("nginx")}, + KeyValue{Key: "command", Value: NewScalar("echo hello")}, + ) + + // Find existing key + keyNode, valNode := FindKey(node, "image") + assert.Assert(t, keyNode != nil) + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "nginx") + + // Find another existing key + keyNode, valNode = FindKey(node, "command") + assert.Assert(t, keyNode != nil) + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "echo hello") + + // Missing key returns nil + keyNode, valNode = FindKey(node, "missing") + assert.Assert(t, keyNode == nil) + assert.Assert(t, valNode == nil) + + // Nil node returns nil + keyNode, valNode = FindKey(nil, "image") + assert.Assert(t, keyNode == nil) + assert.Assert(t, valNode == nil) +} + +func TestSetKey(t *testing.T) { + node := NewMapping( + KeyValue{Key: "image", Value: NewScalar("nginx")}, + ) + + // Set a new key + SetKey(node, "command", NewScalar("echo hello")) + _, valNode := FindKey(node, "command") + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "echo hello") + + // Replace an existing key + SetKey(node, "image", NewScalar("alpine")) + _, valNode = FindKey(node, "image") + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "alpine") + + // Verify total content length (2 key-value pairs = 4 nodes) + assert.Equal(t, len(node.Content), 4) +} + +func TestDeleteKey(t *testing.T) { + node := NewMapping( + KeyValue{Key: "image", Value: NewScalar("nginx")}, + KeyValue{Key: "command", Value: NewScalar("echo hello")}, + KeyValue{Key: "ports", Value: NewScalar("8080")}, + ) + + // Delete a key + DeleteKey(node, "command") + + // Verify it's gone + _, valNode := FindKey(node, "command") + assert.Assert(t, valNode == nil) 
+ + // Verify other keys still present + _, valNode = FindKey(node, "image") + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "nginx") + + _, valNode = FindKey(node, "ports") + assert.Assert(t, valNode != nil) + assert.Equal(t, valNode.Value, "8080") + + // Content length should be 4 (2 remaining pairs) + assert.Equal(t, len(node.Content), 4) +} + +func TestNewScalar(t *testing.T) { + n := NewScalar("hello") + assert.Equal(t, n.Kind, yaml.ScalarNode) + assert.Equal(t, n.Value, "hello") + assert.Equal(t, n.Tag, "!!str") +} + +func TestNewMapping(t *testing.T) { + n := NewMapping( + KeyValue{Key: "a", Value: NewScalar("1")}, + KeyValue{Key: "b", Value: NewScalar("2")}, + ) + assert.Equal(t, n.Kind, yaml.MappingNode) + assert.Equal(t, len(n.Content), 4) // 2 key-value pairs +} + +func TestNewSequence(t *testing.T) { + n := NewSequence(NewScalar("a"), NewScalar("b"), NewScalar("c")) + assert.Equal(t, n.Kind, yaml.SequenceNode) + assert.Equal(t, len(n.Content), 3) +} + +// testMergeParity verifies that MergeNodes produces the same result as MergeYaml. 
+func testMergeParity(t *testing.T, baseYAML, overrideYAML string) { + t.Helper() + + // Parse as map[string]any + var baseMap, overrideMap map[string]any + assert.NilError(t, yaml.Unmarshal([]byte(baseYAML), &baseMap)) + assert.NilError(t, yaml.Unmarshal([]byte(overrideYAML), &overrideMap)) + + // Parse as yaml.Node + var baseNode, overrideNode yaml.Node + assert.NilError(t, yaml.Unmarshal([]byte(baseYAML), &baseNode)) + assert.NilError(t, yaml.Unmarshal([]byte(overrideYAML), &overrideNode)) + + // Merge maps + mapResult, err := MergeYaml(baseMap, overrideMap, tree.NewPath()) + assert.NilError(t, err) + + // Merge nodes - unwrap DocumentNode + bn := &baseNode + if bn.Kind == yaml.DocumentNode && len(bn.Content) > 0 { + bn = bn.Content[0] + } + on := &overrideNode + if on.Kind == yaml.DocumentNode && len(on.Content) > 0 { + on = on.Content[0] + } + + nodeResult, err := MergeNodes(bn, on, tree.NewPath()) + assert.NilError(t, err) + + // Decode node result to map + var nodeMap map[string]any + assert.NilError(t, nodeResult.Decode(&nodeMap)) + + // Compare using JSON marshalling to normalize types + mapJSON, _ := json.Marshal(mapResult) + nodeJSON, _ := json.Marshal(nodeMap) + assert.Equal(t, string(mapJSON), string(nodeJSON)) +} + +func TestMergeNodes_SimpleScalar(t *testing.T) { + base := ` +services: + web: + image: nginx +` + override := ` +services: + web: + image: alpine +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_AddService(t *testing.T) { + base := ` +services: + web: + image: nginx +` + override := ` +services: + db: + image: postgres +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_MergeLabels(t *testing.T) { + base := ` +services: + web: + image: nginx + labels: + foo: bar +` + override := ` +services: + web: + labels: + baz: qux +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_MergeDependsOn(t *testing.T) { + base := ` +services: + web: + image: nginx + depends_on: + - db + db: + image: postgres +` + 
override := ` +services: + web: + depends_on: + - cache + cache: + image: redis +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_MergeBuild(t *testing.T) { + base := ` +services: + web: + build: . +` + override := ` +services: + web: + build: + dockerfile: Dockerfile.dev +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_OverrideCommand(t *testing.T) { + base := ` +services: + web: + image: nginx + command: ["nginx", "-g", "daemon off;"] +` + override := ` +services: + web: + command: ["echo", "hello"] +` + testMergeParity(t, base, override) +} + +func TestMergeNodes_MergeEnvironment(t *testing.T) { + base := ` +services: + web: + image: nginx + environment: + FOO: bar +` + override := ` +services: + web: + environment: + BAZ: qux +` + testMergeParity(t, base, override) +} diff --git a/transform/ports.go b/transform/ports.go index 68e26f3d..b13ff302 100644 --- a/transform/ports.go +++ b/transform/ports.go @@ -21,7 +21,7 @@ import ( "github.com/compose-spec/compose-go/v2/tree" "github.com/compose-spec/compose-go/v2/types" - "github.com/go-viper/mapstructure/v2" + "go.yaml.in/yaml/v4" ) func transformPorts(data any, p tree.Path, ignoreParseError bool) (any, error) { @@ -76,15 +76,12 @@ func transformPorts(data any, p tree.Path, ignoreParseError bool) (any, error) { } func encode(v any) (map[string]any, error) { - m := map[string]any{} - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - Result: &m, - TagName: "yaml", - }) + b, err := yaml.Marshal(v) if err != nil { return nil, err } - err = decoder.Decode(v) + var m map[string]any + err = yaml.Unmarshal(b, &m) return m, err } diff --git a/transform/ports_test.go b/transform/ports_test.go index 27af78a8..2856cbe0 100644 --- a/transform/ports_test.go +++ b/transform/ports_test.go @@ -44,14 +44,14 @@ func Test_transformPorts(t *testing.T) { "mode": "ingress", "protocol": "tcp", "published": "8080", - "target": uint32(80), + "target": 80, }, map[string]any{ "host_ip": 
"127.0.0.1", "mode": "ingress", "protocol": "tcp", "published": "8081", - "target": uint32(81), + "target": 81, }, }, }, diff --git a/types/build.go b/types/build.go index 98931400..11f2e742 100644 --- a/types/build.go +++ b/types/build.go @@ -16,6 +16,8 @@ package types +import "go.yaml.in/yaml/v4" + // BuildConfig is a type for build type BuildConfig struct { Context string `yaml:"context,omitempty" json:"context,omitempty"` @@ -46,3 +48,22 @@ type BuildConfig struct { Extensions Extensions `yaml:"#extensions,inline,omitempty" json:"-"` } + +func (b *BuildConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + b.Context = node.Value + return nil + } + type plain BuildConfig + if err := node.Decode((*plain)(b)); err != nil { + return WrapNodeError(node, err) + } + if b.Context == "" { + b.Context = "." + } + if b.Dockerfile == "" && b.DockerfileInline == "" { + b.Dockerfile = "Dockerfile" + } + return nil +} diff --git a/types/bytes.go b/types/bytes.go index 1b2cd419..8f9e7c8d 100644 --- a/types/bytes.go +++ b/types/bytes.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/docker/go-units" + "go.yaml.in/yaml/v4" ) // UnitBytes is the bytes type @@ -35,14 +36,21 @@ func (u UnitBytes) MarshalJSON() ([]byte, error) { return []byte(fmt.Sprintf(`"%d"`, u)), nil } -func (u *UnitBytes) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case int: - *u = UnitBytes(v) - case string: - b, err := units.RAMInBytes(fmt.Sprint(value)) +func (u *UnitBytes) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Tag { + case "!!int": + var i int64 + if err := node.Decode(&i); err != nil { + return WrapNodeError(node, err) + } + *u = UnitBytes(i) + default: + b, err := units.RAMInBytes(node.Value) + if err != nil { + return WrapNodeError(node, err) + } *u = UnitBytes(b) - return err } return nil } diff --git a/types/command.go b/types/command.go index 
559dc305..c2af706a 100644 --- a/types/command.go +++ b/types/command.go @@ -16,7 +16,10 @@ package types -import "github.com/mattn/go-shellwords" +import ( + "github.com/mattn/go-shellwords" + "go.yaml.in/yaml/v4" +) // ShellCommand is a string or list of string args. // @@ -67,18 +70,19 @@ func (s ShellCommand) MarshalYAML() (interface{}, error) { return []string(s), nil } -func (s *ShellCommand) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case string: - cmd, err := shellwords.Parse(v) +func (s *ShellCommand) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.ScalarNode: + cmd, err := shellwords.Parse(node.Value) if err != nil { - return err + return WrapNodeError(node, err) } *s = cmd - case []interface{}: - cmd := make([]string, len(v)) - for i, s := range v { - cmd[i] = s.(string) + case yaml.SequenceNode: + cmd := make([]string, len(node.Content)) + for i, item := range node.Content { + cmd[i] = item.Value } *s = cmd } diff --git a/types/config.go b/types/config.go index 9a0fdaf2..9cb5ed34 100644 --- a/types/config.go +++ b/types/config.go @@ -21,7 +21,7 @@ import ( "runtime" "strings" - "github.com/go-viper/mapstructure/v2" + "go.yaml.in/yaml/v4" ) // isCaseInsensitiveEnvVars is true on platforms where environment variable names are treated case-insensitively. 
@@ -138,7 +138,11 @@ func (c Config) MarshalJSON() ([]byte, error) { func (e Extensions) Get(name string, target interface{}) (bool, error) { if v, ok := e[name]; ok { - err := mapstructure.Decode(v, target) + b, err := yaml.Marshal(v) + if err != nil { + return true, err + } + err = yaml.Unmarshal(b, target) return true, err } return false, nil diff --git a/types/cpus.go b/types/cpus.go index f32c6e62..9c904817 100644 --- a/types/cpus.go +++ b/types/cpus.go @@ -17,29 +17,27 @@ package types import ( - "fmt" "strconv" + + "go.yaml.in/yaml/v4" ) type NanoCPUs float32 -func (n *NanoCPUs) DecodeMapstructure(a any) error { - switch v := a.(type) { - case string: - f, err := strconv.ParseFloat(v, 64) - if err != nil { - return err +func (n *NanoCPUs) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + var f float64 + if err := node.Decode(&f); err != nil { + if node.Kind == yaml.ScalarNode { + parsed, parseErr := strconv.ParseFloat(node.Value, 64) + if parseErr == nil { + *n = NanoCPUs(parsed) + return nil + } } - *n = NanoCPUs(f) - case int: - *n = NanoCPUs(v) - case float32: - *n = NanoCPUs(v) - case float64: - *n = NanoCPUs(v) - default: - return fmt.Errorf("unexpected value type %T for cpus", v) + return WrapNodeError(node, err) } + *n = NanoCPUs(f) return nil } diff --git a/types/device.go b/types/device.go index 5b30cc0c..2550a418 100644 --- a/types/device.go +++ b/types/device.go @@ -17,9 +17,10 @@ package types import ( - "fmt" "strconv" "strings" + + "go.yaml.in/yaml/v4" ) type DeviceRequest struct { @@ -30,24 +31,64 @@ type DeviceRequest struct { Options Mapping `yaml:"options,omitempty" json:"options,omitempty"` } +func (d *DeviceRequest) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + type plain DeviceRequest + if err := node.Decode((*plain)(d)); err != nil { + return WrapNodeError(node, err) + } + if d.Count == 0 && len(d.IDs) == 0 { + d.Count = -1 + } + return nil +} + type DeviceCount int64 -func (c 
*DeviceCount) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case int: - *c = DeviceCount(v) - case string: - if strings.ToLower(v) == "all" { - *c = -1 - return nil - } - i, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return fmt.Errorf("invalid value %q, the only value allowed is 'all' or a number", v) +func (c *DeviceCount) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode && strings.ToLower(node.Value) == "all" { + *c = -1 + return nil + } + var i int64 + if err := node.Decode(&i); err != nil { + // Try parsing as string (e.g., count: "1") + if node.Kind == yaml.ScalarNode { + parsed, parseErr := strconv.ParseInt(node.Value, 10, 64) + if parseErr == nil { + *c = DeviceCount(parsed) + return nil + } } - *c = DeviceCount(i) - default: - return fmt.Errorf("invalid type %T for device count", v) + return NodeErrorf(node, "invalid value %q, the only value allowed is 'all' or a number", node.Value) } + *c = DeviceCount(i) return nil } + +// GpuDevices is a slice of DeviceRequest that handles the short syntax +// `gpus: all` which expands to `[{count: -1}]`. 
+type GpuDevices []DeviceRequest + +func (g *GpuDevices) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + // Short syntax: gpus: all + *g = []DeviceRequest{{Count: -1}} + return nil + } + if node.Kind == yaml.SequenceNode { + var result []DeviceRequest + for _, item := range node.Content { + var d DeviceRequest + if err := item.Decode(&d); err != nil { + return WrapNodeError(item, err) + } + result = append(result, d) + } + *g = result + return nil + } + return NodeErrorf(node, "gpus must be a string or sequence") +} diff --git a/types/duration.go b/types/duration.go index c1c39730..78903c18 100644 --- a/types/duration.go +++ b/types/duration.go @@ -18,11 +18,11 @@ package types import ( "encoding/json" - "fmt" "strings" "time" "github.com/xhit/go-str2duration/v2" + "go.yaml.in/yaml/v4" ) // Duration is a thin wrapper around time.Duration with improved JSON marshalling @@ -32,10 +32,11 @@ func (d Duration) String() string { return time.Duration(d).String() } -func (d *Duration) DecodeMapstructure(value interface{}) error { - v, err := str2duration.ParseDuration(fmt.Sprint(value)) +func (d *Duration) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + v, err := str2duration.ParseDuration(node.Value) if err != nil { - return err + return WrapNodeError(node, err) } *d = Duration(v) return nil diff --git a/types/envfile.go b/types/envfile.go index a7d239ee..bc140f1d 100644 --- a/types/envfile.go +++ b/types/envfile.go @@ -16,8 +16,28 @@ package types +import "go.yaml.in/yaml/v4" + type EnvFile struct { Path string `yaml:"path,omitempty" json:"path,omitempty"` Required OptOut `yaml:"required,omitempty" json:"required,omitzero"` Format string `yaml:"format,omitempty" json:"format,omitempty"` } + +func (e *EnvFile) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + e.Path = node.Value + e.Required = true + return nil + } + type 
plain EnvFile + if err := node.Decode((*plain)(e)); err != nil { + return WrapNodeError(node, err) + } + // Default required to true if not explicitly set + if node.Kind == yaml.MappingNode && !hasKey(node, "required") { + e.Required = true + } + return nil +} diff --git a/types/healthcheck.go b/types/healthcheck.go index c6c3b37e..737d0990 100644 --- a/types/healthcheck.go +++ b/types/healthcheck.go @@ -17,7 +17,7 @@ package types import ( - "fmt" + "go.yaml.in/yaml/v4" ) // HealthCheckConfig the healthcheck configuration for a service @@ -36,18 +36,19 @@ type HealthCheckConfig struct { // HealthCheckTest is the command run to test the health of a service type HealthCheckTest []string -func (l *HealthCheckTest) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case string: - *l = []string{"CMD-SHELL", v} - case []interface{}: - seq := make([]string, len(v)) - for i, e := range v { - seq[i] = e.(string) +func (l *HealthCheckTest) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.ScalarNode: + *l = []string{"CMD-SHELL", node.Value} + case yaml.SequenceNode: + seq := make([]string, len(node.Content)) + for i, item := range node.Content { + seq[i] = item.Value } *l = seq default: - return fmt.Errorf("unexpected value type %T for healthcheck.test", value) + return NodeErrorf(node, "unexpected node kind %d for healthcheck.test", node.Kind) } return nil } diff --git a/types/hostList.go b/types/hostList.go index 9bc0fbc5..3216b62d 100644 --- a/types/hostList.go +++ b/types/hostList.go @@ -21,6 +21,8 @@ import ( "fmt" "sort" "strings" + + "go.yaml.in/yaml/v4" ) // HostsList is a list of colon-separated host-ip mappings @@ -81,47 +83,50 @@ func (h HostsList) MarshalJSON() ([]byte, error) { var hostListSerapators = []string{"=", ":"} -func (h *HostsList) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - list := make(HostsList, len(v)) - 
for i, e := range v { - if e == nil { - e = "" - } - switch t := e.(type) { - case string: - list[i] = []string{t} - case []any: - hosts := make([]string, len(t)) - for j, h := range t { - hosts[j] = fmt.Sprint(h) +func (h *HostsList) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.MappingNode: + list := make(HostsList, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + switch v.Kind { + case yaml.ScalarNode: + if v.Tag == "!!null" || v.Value == "" { + list[k] = []string{""} + } else { + list[k] = []string{v.Value} } - list[i] = hosts + case yaml.SequenceNode: + hosts := make([]string, len(v.Content)) + for j, item := range v.Content { + hosts[j] = item.Value + } + list[k] = hosts default: - return fmt.Errorf("unexpected value type %T for extra_hosts entry", value) + return NodeErrorf(v, "unexpected value type for extra_hosts entry") } } err := list.cleanup() if err != nil { - return err + return WrapNodeError(node, err) } *h = list - return nil - case []interface{}: - s := make([]string, len(v)) - for i, e := range v { - s[i] = fmt.Sprint(e) + case yaml.SequenceNode: + s := make([]string, len(node.Content)) + for i, item := range node.Content { + s[i] = item.Value } - list, err := NewHostsList(s) + l, err := NewHostsList(s) if err != nil { - return err + return WrapNodeError(node, err) } - *h = list - return nil + *h = l default: - return fmt.Errorf("unexpected value type %T for extra_hosts", value) + return NodeErrorf(node, "unexpected node kind %d for extra_hosts", node.Kind) } + return nil } func (h HostsList) cleanup() error { diff --git a/types/labels.go b/types/labels.go index 7ea5edc4..264d8c09 100644 --- a/types/labels.go +++ b/types/labels.go @@ -19,6 +19,8 @@ package types import ( "fmt" "strings" + + "go.yaml.in/yaml/v4" ) // Labels is a mapping type for labels @@ -60,36 +62,24 @@ func (l Labels) ToMappingWithEquals() 
MappingWithEquals { return mapping } -// label value can be a string | number | boolean | null (empty) -func labelValue(e interface{}) string { - if e == nil { - return "" - } - switch v := e.(type) { - case string: - return v - default: - return fmt.Sprint(v) - } -} - -func (l *Labels) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - labels := make(map[string]string, len(v)) - for k, e := range v { - labels[k] = labelValue(e) +func (l *Labels) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.MappingNode: + var m map[string]string + if err := node.Decode(&m); err != nil { + return err } - *l = labels - case []interface{}: - labels := make(map[string]string, len(v)) - for _, s := range v { - k, e, _ := strings.Cut(fmt.Sprint(s), "=") - labels[k] = labelValue(e) + *l = m + case yaml.SequenceNode: + labels := make(map[string]string, len(node.Content)) + for _, item := range node.Content { + k, e, _ := strings.Cut(item.Value, "=") + labels[k] = e } *l = labels default: - return fmt.Errorf("unexpected value type %T for labels", value) + return NodeErrorf(node, "unexpected node kind %d for labels", node.Kind) } return nil } diff --git a/types/labels_test.go b/types/labels_test.go index 9a4bf5c5..d633192c 100644 --- a/types/labels_test.go +++ b/types/labels_test.go @@ -19,15 +19,17 @@ package types import ( "testing" + "go.yaml.in/yaml/v4" "gotest.tools/v3/assert" ) func TestDecodeLabel(t *testing.T) { - l := Labels{} - err := l.DecodeMapstructure([]any{ - "a=b", - "c", - }) + input := ` +- a=b +- c +` + var l Labels + err := yaml.Unmarshal([]byte(input), &l) assert.NilError(t, err) assert.Equal(t, l["a"], "b") assert.Equal(t, l["c"], "") diff --git a/types/mapping.go b/types/mapping.go index fb14974f..9f7b629f 100644 --- a/types/mapping.go +++ b/types/mapping.go @@ -20,7 +20,8 @@ import ( "fmt" "sort" "strings" - "unicode" + + "go.yaml.in/yaml/v4" ) // 
MappingWithEquals is a mapping type that can be converted from a list of @@ -83,48 +84,39 @@ func (m MappingWithEquals) ToMapping() Mapping { return o } -func (m *MappingWithEquals) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - mapping := make(MappingWithEquals, len(v)) - for k, e := range v { - mapping[k] = mappingValue(e) +func (m *MappingWithEquals) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.MappingNode: + mapping := make(MappingWithEquals, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + if v.Tag == "!!null" { + mapping[k] = nil + } else { + s := v.Value + mapping[k] = &s + } } *m = mapping - case []interface{}: - mapping := make(MappingWithEquals, len(v)) - for _, s := range v { - k, e, ok := strings.Cut(fmt.Sprint(s), "=") - if k != "" && unicode.IsSpace(rune(k[len(k)-1])) { - return fmt.Errorf("environment variable %s is declared with a trailing space", k) - } + case yaml.SequenceNode: + mapping := make(MappingWithEquals, len(node.Content)) + for _, item := range node.Content { + k, e, ok := strings.Cut(item.Value, "=") if !ok { mapping[k] = nil } else { - mapping[k] = mappingValue(e) + mapping[k] = &e } } *m = mapping default: - return fmt.Errorf("unexpected value type %T for mapping", value) + return NodeErrorf(node, "unexpected node kind %d for mapping", node.Kind) } return nil } -// label value can be a string | number | boolean | null -func mappingValue(e interface{}) *string { - if e == nil { - return nil - } - switch v := e.(type) { - case string: - return &v - default: - s := fmt.Sprint(v) - return &s - } -} - // Mapping is a mapping type that can be converted from a list of // key[=value] strings. 
// For the key with an empty value (`key=`), or key without value (`key`), the @@ -189,42 +181,34 @@ func (m Mapping) Merge(o Mapping) Mapping { return m } -func (m *Mapping) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - mapping := make(Mapping, len(v)) - for k, e := range v { - if e == nil { - e = "" +func (m *Mapping) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.MappingNode: + mapping := make(Mapping, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + if v.Tag == "!!null" { + mapping[k] = "" + } else { + mapping[k] = v.Value } - mapping[k] = fmt.Sprint(e) } *m = mapping - case []interface{}: - *m = decodeMapping(v, "=") - default: - return fmt.Errorf("unexpected value type %T for mapping", value) - } - return nil -} - -// Generate a mapping by splitting strings at any of seps, which will be tried -// in-order for each input string. (For example, to allow the preferred 'host=ip' -// in 'extra_hosts', as well as 'host:ip' for backwards compatibility.) -func decodeMapping(v []interface{}, seps ...string) map[string]string { - mapping := make(Mapping, len(v)) - for _, s := range v { - for i, sep := range seps { - k, e, ok := strings.Cut(fmt.Sprint(s), sep) - if ok { - // Mapping found with this separator, stop here. - mapping[k] = e - break - } else if i == len(seps)-1 { - // No more separators to try, map to empty string. 
- mapping[k] = "" + case yaml.SequenceNode: + mapping := make(Mapping, len(node.Content)) + for _, item := range node.Content { + parts := strings.SplitN(item.Value, "=", 2) + if len(parts) == 1 { + mapping[parts[0]] = "" + } else { + mapping[parts[0]] = parts[1] } } + *m = mapping + default: + return NodeErrorf(node, "unexpected node kind %d for mapping", node.Kind) } - return mapping + return nil } diff --git a/types/models.go b/types/models.go index 4f144c0a..55462c8f 100644 --- a/types/models.go +++ b/types/models.go @@ -16,6 +16,38 @@ package types +import ( + "fmt" + + "go.yaml.in/yaml/v4" +) + +// ServiceModels is a map of model names to service model configurations. +// It supports both list syntax (models: [foo]) and map syntax. +type ServiceModels map[string]*ServiceModelConfig + +func (m *ServiceModels) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.SequenceNode: + models := make(ServiceModels, len(node.Content)) + for _, item := range node.Content { + models[item.Value] = nil + } + *m = models + case yaml.MappingNode: + type plain ServiceModels + var p plain + if err := node.Decode(&p); err != nil { + return err + } + *m = ServiceModels(p) + default: + return fmt.Errorf("models must be a mapping or sequence, got %v", node.Kind) + } + return nil +} + type ModelConfig struct { Name string `yaml:"name,omitempty" json:"name,omitempty"` Model string `yaml:"model,omitempty" json:"model,omitempty"` diff --git a/types/options.go b/types/options.go index 9aadb89c..e59c2109 100644 --- a/types/options.go +++ b/types/options.go @@ -16,51 +16,53 @@ package types -import "fmt" +import ( + "go.yaml.in/yaml/v4" +) // Options is a mapping type for options we pass as-is to container runtime type Options map[string]string -func (d *Options) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - m := make(map[string]string) - for key, e := range v { - if e == nil 
{ - m[key] = "" - } else { - m[key] = fmt.Sprint(e) - } +func (d *Options) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind != yaml.MappingNode { + return NodeErrorf(node, "invalid node kind %d for options", node.Kind) + } + m := make(map[string]string, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + if v.Tag == "!!null" { + m[k] = "" + } else { + m[k] = v.Value } - *d = m - case map[string]string: - *d = v - default: - return fmt.Errorf("invalid type %T for options", value) } + *d = m return nil } // MultiOptions allow option to be repeated type MultiOptions map[string][]string -func (d *MultiOptions) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case map[string]interface{}: - m := make(map[string][]string) - for key, e := range v { - switch e := e.(type) { - case []interface{}: - for _, v := range e { - m[key] = append(m[key], fmt.Sprint(v)) - } - default: - m[key] = append(m[key], fmt.Sprint(e)) +func (d *MultiOptions) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind != yaml.MappingNode { + return NodeErrorf(node, "invalid node kind %d for options", node.Kind) + } + m := make(map[string][]string, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + switch v.Kind { + case yaml.SequenceNode: + for _, item := range v.Content { + m[k] = append(m[k], item.Value) } + default: + m[k] = append(m[k], v.Value) } - *d = m - default: - return fmt.Errorf("invalid type %T for options", value) } + *d = m return nil } diff --git a/types/ssh.go b/types/ssh.go index 6d0edb69..59d1daf6 100644 --- a/types/ssh.go +++ b/types/ssh.go @@ -18,6 +18,9 @@ package types import ( "fmt" + "strings" + + "go.yaml.in/yaml/v4" ) type SSHKey struct { @@ -53,21 +56,34 @@ func (s SSHKey) MarshalJSON() ([]byte, error) { return 
[]byte(fmt.Sprintf(`%q: %s`, s.ID, s.Path)), nil } -func (s *SSHConfig) DecodeMapstructure(value interface{}) error { - v, ok := value.(map[string]any) - if !ok { - return fmt.Errorf("invalid ssh config type %T", value) - } - result := make(SSHConfig, len(v)) - i := 0 - for id, path := range v { - key := SSHKey{ID: id} - if path != nil { - key.Path = fmt.Sprint(path) +func (s *SSHConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.MappingNode: + result := make(SSHConfig, len(node.Content)/2) + for i := 0; i+1 < len(node.Content); i += 2 { + k := node.Content[i].Value + v := node.Content[i+1] + key := SSHKey{ID: k} + if v.Tag != "!!null" && v.Value != "" { + key.Path = v.Value + } + result[i/2] = key + } + *s = result + case yaml.SequenceNode: + result := make(SSHConfig, len(node.Content)) + for i, item := range node.Content { + id, path, ok := strings.Cut(item.Value, "=") + key := SSHKey{ID: id} + if ok { + key.Path = path + } + result[i] = key } - result[i] = key - i++ + *s = result + default: + return NodeErrorf(node, "invalid node kind %d for ssh config", node.Kind) } - *s = result return nil } diff --git a/types/stringOrList.go b/types/stringOrList.go index a6720df0..7c1d0836 100644 --- a/types/stringOrList.go +++ b/types/stringOrList.go @@ -16,27 +16,33 @@ package types -import "fmt" +import ( + "go.yaml.in/yaml/v4" +) // StringList is a type for fields that can be a string or list of strings type StringList []string -func (l *StringList) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case string: - *l = []string{v} - case []interface{}: - list := make([]string, len(v)) - for i, e := range v { - val, ok := e.(string) - if !ok { - return fmt.Errorf("invalid type %T for string list", value) +func (l *StringList) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.ScalarNode: + if node.Value == "" { + *l = nil + } else 
{ + *l = []string{node.Value} + } + case yaml.SequenceNode: + list := make([]string, len(node.Content)) + for i, item := range node.Content { + if item.Kind != yaml.ScalarNode { + return NodeErrorf(item, "invalid type for string list") } - list[i] = val + list[i] = item.Value } *l = list default: - return fmt.Errorf("invalid type %T for string list", value) + return NodeErrorf(node, "invalid node kind %d for string list", node.Kind) } return nil } @@ -44,18 +50,19 @@ func (l *StringList) DecodeMapstructure(value interface{}) error { // StringOrNumberList is a type for fields that can be a list of strings or numbers type StringOrNumberList []string -func (l *StringOrNumberList) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case string: - *l = []string{v} - case []interface{}: - list := make([]string, len(v)) - for i, e := range v { - list[i] = fmt.Sprint(e) +func (l *StringOrNumberList) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.ScalarNode: + *l = []string{node.Value} + case yaml.SequenceNode: + list := make([]string, len(node.Content)) + for i, item := range node.Content { + list[i] = item.Value } *l = list default: - return fmt.Errorf("invalid type %T for string list", value) + return NodeErrorf(node, "invalid node kind %d for string list", node.Kind) } return nil } diff --git a/types/types.go b/types/types.go index fd4f3513..c87000e9 100644 --- a/types/types.go +++ b/types/types.go @@ -26,6 +26,7 @@ import ( "github.com/docker/go-connections/nat" "github.com/xhit/go-str2duration/v2" + "go.yaml.in/yaml/v4" ) // ServiceConfig is the configuration of one service @@ -75,71 +76,71 @@ type ServiceConfig struct { // If set, overrides ENTRYPOINT from the image. // // Set to `[]` or an empty string to clear the entrypoint from the image. - Entrypoint ShellCommand `yaml:"entrypoint,omitempty" json:"entrypoint"` // NOTE: we can NOT omitempty for JSON! 
see ShellCommand type for details. - Provider *ServiceProviderConfig `yaml:"provider,omitempty" json:"provider,omitempty"` - Environment MappingWithEquals `yaml:"environment,omitempty" json:"environment,omitempty"` - EnvFiles []EnvFile `yaml:"env_file,omitempty" json:"env_file,omitempty"` - Expose StringOrNumberList `yaml:"expose,omitempty" json:"expose,omitempty"` - Extends *ExtendsConfig `yaml:"extends,omitempty" json:"extends,omitempty"` - ExternalLinks []string `yaml:"external_links,omitempty" json:"external_links,omitempty"` - ExtraHosts HostsList `yaml:"extra_hosts,omitempty" json:"extra_hosts,omitempty"` - GroupAdd []string `yaml:"group_add,omitempty" json:"group_add,omitempty"` - Gpus []DeviceRequest `yaml:"gpus,omitempty" json:"gpus,omitempty"` - Hostname string `yaml:"hostname,omitempty" json:"hostname,omitempty"` - HealthCheck *HealthCheckConfig `yaml:"healthcheck,omitempty" json:"healthcheck,omitempty"` - Image string `yaml:"image,omitempty" json:"image,omitempty"` - Init *bool `yaml:"init,omitempty" json:"init,omitempty"` - Ipc string `yaml:"ipc,omitempty" json:"ipc,omitempty"` - Isolation string `yaml:"isolation,omitempty" json:"isolation,omitempty"` - Labels Labels `yaml:"labels,omitempty" json:"labels,omitempty"` - LabelFiles []string `yaml:"label_file,omitempty" json:"label_file,omitempty"` - CustomLabels Labels `yaml:"-" json:"-"` - Links []string `yaml:"links,omitempty" json:"links,omitempty"` - Logging *LoggingConfig `yaml:"logging,omitempty" json:"logging,omitempty"` - LogDriver string `yaml:"log_driver,omitempty" json:"log_driver,omitempty"` - LogOpt map[string]string `yaml:"log_opt,omitempty" json:"log_opt,omitempty"` - MemLimit UnitBytes `yaml:"mem_limit,omitempty" json:"mem_limit,omitempty"` - MemReservation UnitBytes `yaml:"mem_reservation,omitempty" json:"mem_reservation,omitempty"` - MemSwapLimit UnitBytes `yaml:"memswap_limit,omitempty" json:"memswap_limit,omitempty"` - MemSwappiness UnitBytes `yaml:"mem_swappiness,omitempty" 
json:"mem_swappiness,omitempty"` - MacAddress string `yaml:"mac_address,omitempty" json:"mac_address,omitempty"` - Models map[string]*ServiceModelConfig `yaml:"models,omitempty" json:"models,omitempty"` - Net string `yaml:"net,omitempty" json:"net,omitempty"` - NetworkMode string `yaml:"network_mode,omitempty" json:"network_mode,omitempty"` - Networks map[string]*ServiceNetworkConfig `yaml:"networks,omitempty" json:"networks,omitempty"` - OomKillDisable bool `yaml:"oom_kill_disable,omitempty" json:"oom_kill_disable,omitempty"` - OomScoreAdj int64 `yaml:"oom_score_adj,omitempty" json:"oom_score_adj,omitempty"` - Pid string `yaml:"pid,omitempty" json:"pid,omitempty"` - PidsLimit int64 `yaml:"pids_limit,omitempty" json:"pids_limit,omitempty"` - Platform string `yaml:"platform,omitempty" json:"platform,omitempty"` - Ports []ServicePortConfig `yaml:"ports,omitempty" json:"ports,omitempty"` - Privileged bool `yaml:"privileged,omitempty" json:"privileged,omitempty"` - PullPolicy string `yaml:"pull_policy,omitempty" json:"pull_policy,omitempty"` - ReadOnly bool `yaml:"read_only,omitempty" json:"read_only,omitempty"` - Restart string `yaml:"restart,omitempty" json:"restart,omitempty"` - Runtime string `yaml:"runtime,omitempty" json:"runtime,omitempty"` - Scale *int `yaml:"scale,omitempty" json:"scale,omitempty"` - Secrets []ServiceSecretConfig `yaml:"secrets,omitempty" json:"secrets,omitempty"` - SecurityOpt []string `yaml:"security_opt,omitempty" json:"security_opt,omitempty"` - ShmSize UnitBytes `yaml:"shm_size,omitempty" json:"shm_size,omitempty"` - StdinOpen bool `yaml:"stdin_open,omitempty" json:"stdin_open,omitempty"` - StopGracePeriod *Duration `yaml:"stop_grace_period,omitempty" json:"stop_grace_period,omitempty"` - StopSignal string `yaml:"stop_signal,omitempty" json:"stop_signal,omitempty"` - StorageOpt map[string]string `yaml:"storage_opt,omitempty" json:"storage_opt,omitempty"` - Sysctls Mapping `yaml:"sysctls,omitempty" json:"sysctls,omitempty"` - Tmpfs 
StringList `yaml:"tmpfs,omitempty" json:"tmpfs,omitempty"` - Tty bool `yaml:"tty,omitempty" json:"tty,omitempty"` - Ulimits map[string]*UlimitsConfig `yaml:"ulimits,omitempty" json:"ulimits,omitempty"` - UseAPISocket bool `yaml:"use_api_socket,omitempty" json:"use_api_socket,omitempty"` - User string `yaml:"user,omitempty" json:"user,omitempty"` - UserNSMode string `yaml:"userns_mode,omitempty" json:"userns_mode,omitempty"` - Uts string `yaml:"uts,omitempty" json:"uts,omitempty"` - VolumeDriver string `yaml:"volume_driver,omitempty" json:"volume_driver,omitempty"` - Volumes []ServiceVolumeConfig `yaml:"volumes,omitempty" json:"volumes,omitempty"` - VolumesFrom []string `yaml:"volumes_from,omitempty" json:"volumes_from,omitempty"` - WorkingDir string `yaml:"working_dir,omitempty" json:"working_dir,omitempty"` - PostStart []ServiceHook `yaml:"post_start,omitempty" json:"post_start,omitempty"` - PreStop []ServiceHook `yaml:"pre_stop,omitempty" json:"pre_stop,omitempty"` + Entrypoint ShellCommand `yaml:"entrypoint,omitempty" json:"entrypoint"` // NOTE: we can NOT omitempty for JSON! see ShellCommand type for details. 
+ Provider *ServiceProviderConfig `yaml:"provider,omitempty" json:"provider,omitempty"` + Environment MappingWithEquals `yaml:"environment,omitempty" json:"environment,omitempty"` + EnvFiles []EnvFile `yaml:"env_file,omitempty" json:"env_file,omitempty"` + Expose StringOrNumberList `yaml:"expose,omitempty" json:"expose,omitempty"` + Extends *ExtendsConfig `yaml:"extends,omitempty" json:"extends,omitempty"` + ExternalLinks []string `yaml:"external_links,omitempty" json:"external_links,omitempty"` + ExtraHosts HostsList `yaml:"extra_hosts,omitempty" json:"extra_hosts,omitempty"` + GroupAdd []string `yaml:"group_add,omitempty" json:"group_add,omitempty"` + Gpus GpuDevices `yaml:"gpus,omitempty" json:"gpus,omitempty"` + Hostname string `yaml:"hostname,omitempty" json:"hostname,omitempty"` + HealthCheck *HealthCheckConfig `yaml:"healthcheck,omitempty" json:"healthcheck,omitempty"` + Image string `yaml:"image,omitempty" json:"image,omitempty"` + Init *bool `yaml:"init,omitempty" json:"init,omitempty"` + Ipc string `yaml:"ipc,omitempty" json:"ipc,omitempty"` + Isolation string `yaml:"isolation,omitempty" json:"isolation,omitempty"` + Labels Labels `yaml:"labels,omitempty" json:"labels,omitempty"` + LabelFiles StringList `yaml:"label_file,omitempty" json:"label_file,omitempty"` + CustomLabels Labels `yaml:"-" json:"-"` + Links []string `yaml:"links,omitempty" json:"links,omitempty"` + Logging *LoggingConfig `yaml:"logging,omitempty" json:"logging,omitempty"` + LogDriver string `yaml:"log_driver,omitempty" json:"log_driver,omitempty"` + LogOpt map[string]string `yaml:"log_opt,omitempty" json:"log_opt,omitempty"` + MemLimit UnitBytes `yaml:"mem_limit,omitempty" json:"mem_limit,omitempty"` + MemReservation UnitBytes `yaml:"mem_reservation,omitempty" json:"mem_reservation,omitempty"` + MemSwapLimit UnitBytes `yaml:"memswap_limit,omitempty" json:"memswap_limit,omitempty"` + MemSwappiness UnitBytes `yaml:"mem_swappiness,omitempty" json:"mem_swappiness,omitempty"` + MacAddress 
string `yaml:"mac_address,omitempty" json:"mac_address,omitempty"` + Models ServiceModels `yaml:"models,omitempty" json:"models,omitempty"` + Net string `yaml:"net,omitempty" json:"net,omitempty"` + NetworkMode string `yaml:"network_mode,omitempty" json:"network_mode,omitempty"` + Networks ServiceNetworks `yaml:"networks,omitempty" json:"networks,omitempty"` + OomKillDisable bool `yaml:"oom_kill_disable,omitempty" json:"oom_kill_disable,omitempty"` + OomScoreAdj int64 `yaml:"oom_score_adj,omitempty" json:"oom_score_adj,omitempty"` + Pid string `yaml:"pid,omitempty" json:"pid,omitempty"` + PidsLimit int64 `yaml:"pids_limit,omitempty" json:"pids_limit,omitempty"` + Platform string `yaml:"platform,omitempty" json:"platform,omitempty"` + Ports ServicePorts `yaml:"ports,omitempty" json:"ports,omitempty"` + Privileged bool `yaml:"privileged,omitempty" json:"privileged,omitempty"` + PullPolicy string `yaml:"pull_policy,omitempty" json:"pull_policy,omitempty"` + ReadOnly bool `yaml:"read_only,omitempty" json:"read_only,omitempty"` + Restart string `yaml:"restart,omitempty" json:"restart,omitempty"` + Runtime string `yaml:"runtime,omitempty" json:"runtime,omitempty"` + Scale *int `yaml:"scale,omitempty" json:"scale,omitempty"` + Secrets []ServiceSecretConfig `yaml:"secrets,omitempty" json:"secrets,omitempty"` + SecurityOpt []string `yaml:"security_opt,omitempty" json:"security_opt,omitempty"` + ShmSize UnitBytes `yaml:"shm_size,omitempty" json:"shm_size,omitempty"` + StdinOpen bool `yaml:"stdin_open,omitempty" json:"stdin_open,omitempty"` + StopGracePeriod *Duration `yaml:"stop_grace_period,omitempty" json:"stop_grace_period,omitempty"` + StopSignal string `yaml:"stop_signal,omitempty" json:"stop_signal,omitempty"` + StorageOpt map[string]string `yaml:"storage_opt,omitempty" json:"storage_opt,omitempty"` + Sysctls Mapping `yaml:"sysctls,omitempty" json:"sysctls,omitempty"` + Tmpfs StringList `yaml:"tmpfs,omitempty" json:"tmpfs,omitempty"` + Tty bool `yaml:"tty,omitempty" 
json:"tty,omitempty"` + Ulimits map[string]*UlimitsConfig `yaml:"ulimits,omitempty" json:"ulimits,omitempty"` + UseAPISocket bool `yaml:"use_api_socket,omitempty" json:"use_api_socket,omitempty"` + User string `yaml:"user,omitempty" json:"user,omitempty"` + UserNSMode string `yaml:"userns_mode,omitempty" json:"userns_mode,omitempty"` + Uts string `yaml:"uts,omitempty" json:"uts,omitempty"` + VolumeDriver string `yaml:"volume_driver,omitempty" json:"volume_driver,omitempty"` + Volumes []ServiceVolumeConfig `yaml:"volumes,omitempty" json:"volumes,omitempty"` + VolumesFrom []string `yaml:"volumes_from,omitempty" json:"volumes_from,omitempty"` + WorkingDir string `yaml:"working_dir,omitempty" json:"working_dir,omitempty"` + PostStart []ServiceHook `yaml:"post_start,omitempty" json:"post_start,omitempty"` + PreStop []ServiceHook `yaml:"pre_stop,omitempty" json:"pre_stop,omitempty"` Extensions Extensions `yaml:"#extensions,inline,omitempty" json:"-"` } @@ -322,6 +323,33 @@ type DeviceMapping struct { Extensions Extensions `yaml:"#extensions,inline,omitempty" json:"-"` } +func (d *DeviceMapping) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + // Short syntax: /dev/fuse or /dev/fuse:/dev/fuse or /dev/fuse:/dev/fuse:rwm + parts := strings.Split(node.Value, ":") + switch len(parts) { + case 3: + d.Source = parts[0] + d.Target = parts[1] + d.Permissions = parts[2] + case 2: + d.Source = parts[0] + d.Target = parts[1] + d.Permissions = "rwm" + case 1: + d.Source = parts[0] + d.Target = parts[0] + d.Permissions = "rwm" + default: + return NodeErrorf(node, "confusing device mapping, please use long syntax: %s", node.Value) + } + return nil + } + type plain DeviceMapping + return WrapNodeError(node, node.Decode((*plain)(d))) +} + // WeightDevice is a structure that holds device:weight pair type WeightDevice struct { Path string @@ -442,6 +470,32 @@ type PlacementPreferences struct { Extensions Extensions 
`yaml:"#extensions,inline,omitempty" json:"-"` } +// ServiceNetworks is a map of network names to service network configurations. +// It supports both list syntax (networks: [front, back]) and map syntax. +type ServiceNetworks map[string]*ServiceNetworkConfig + +func (n *ServiceNetworks) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.SequenceNode: + networks := make(ServiceNetworks, len(node.Content)) + for _, item := range node.Content { + networks[item.Value] = nil + } + *n = networks + case yaml.MappingNode: + type plain ServiceNetworks + var m plain + if err := node.Decode(&m); err != nil { + return err + } + *n = ServiceNetworks(m) + default: + return fmt.Errorf("networks must be a mapping or sequence, got %v", node.Kind) + } + return nil +} + // ServiceNetworkConfig is the network configuration for a service type ServiceNetworkConfig struct { Aliases []string `yaml:"aliases,omitempty" json:"aliases,omitempty"` @@ -506,6 +560,71 @@ func convertPortToPortConfig(port nat.Port, portBindings map[nat.Port][]nat.Port return portConfigs } +func (p *ServicePortConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + configs, err := ParsePortConfig(node.Value) + if err != nil { + return WrapNodeError(node, err) + } + if len(configs) == 1 { + *p = configs[0] + return nil + } + return NodeErrorf(node, "port range %q expands to multiple entries, use sequence form", node.Value) + } + type plain ServicePortConfig + if err := node.Decode((*plain)(p)); err != nil { + return WrapNodeError(node, err) + } + if p.Protocol == "" { + p.Protocol = "tcp" + } + if p.Mode == "" { + p.Mode = "ingress" + } + return nil +} + +// ServicePorts is a sequence of ServicePortConfig that handles port range expansion +// during YAML unmarshaling. Port ranges like "80-82:8080-8082" expand to +// multiple ServicePortConfig entries. 
+type ServicePorts []ServicePortConfig + +func (sp *ServicePorts) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind != yaml.SequenceNode { + return NodeErrorf(node, "ports must be a sequence") + } + var result []ServicePortConfig + for _, item := range node.Content { + itemNode := resolveYAMLNode(item) + if itemNode.Kind == yaml.ScalarNode { + // Could be a port range like "80-82:8080-8082" + configs, err := ParsePortConfig(itemNode.Value) + if err != nil { + return WrapNodeError(itemNode, err) + } + result = append(result, configs...) + } else { + var port ServicePortConfig + type plain ServicePortConfig + if err := itemNode.Decode((*plain)(&port)); err != nil { + return WrapNodeError(itemNode, err) + } + if port.Protocol == "" { + port.Protocol = "tcp" + } + if port.Mode == "" { + port.Mode = "ingress" + } + result = append(result, port) + } + } + *sp = result + return nil +} + // ServiceVolumeConfig are references to a volume used by a service type ServiceVolumeConfig struct { Type string `yaml:"type,omitempty" json:"type,omitempty"` @@ -540,6 +659,36 @@ func (s ServiceVolumeConfig) String() string { return fmt.Sprintf("%s:%s:%s", s.Source, s.Target, strings.Join(options, ",")) } +func (s *ServiceVolumeConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + if ParseVolumeFunc == nil { + return NodeErrorf(node, "volume short syntax %q requires ParseVolume function", node.Value) + } + parsed, err := ParseVolumeFunc(node.Value) + if err != nil { + return WrapNodeError(node, err) + } + *s = parsed + } else { + type plain ServiceVolumeConfig + if err := node.Decode((*plain)(s)); err != nil { + return WrapNodeError(node, err) + } + // Default create_host_path=true for bind volumes when bind section + // exists but create_host_path is not explicitly set + if s.Bind != nil { + if bindNode := findYAMLKey(node, "bind"); bindNode != nil { + chpNode := findYAMLKey(bindNode, 
"create_host_path") + if chpNode == nil { + s.Bind.CreateHostPath = true + } + } + } + } + return nil +} + const ( // VolumeTypeBind is the type for mounting host dir VolumeTypeBind = "bind" @@ -638,20 +787,22 @@ type FileReferenceConfig struct { Extensions Extensions `yaml:"#extensions,inline,omitempty" json:"-"` } -func (f *FileMode) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case *FileMode: - return nil - case string: - i, err := strconv.ParseInt(v, 8, 64) +func (f *FileMode) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Tag == "!!int" { + // YAML integer - let yaml/v4 handle octal (0-prefix), hex (0x-prefix), etc. + var i int64 + if err := node.Decode(&i); err != nil { + return WrapNodeError(node, err) + } + *f = FileMode(i) + } else { + // String — parse as octal (e.g., "0440") + i, err := strconv.ParseInt(node.Value, 8, 64) if err != nil { - return err + return WrapNodeError(node, err) } *f = FileMode(i) - case int: - *f = FileMode(v) - default: - return fmt.Errorf("unexpected value type %T for mode", value) } return nil } @@ -670,12 +821,49 @@ func (f *FileMode) String() string { return fmt.Sprintf("0%o", int64(*f)) } +func (f *FileReferenceConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + f.Source = node.Value + return nil + } + type plain FileReferenceConfig + return WrapNodeError(node, node.Decode((*plain)(f))) +} + // ServiceConfigObjConfig is the config obj configuration for a service type ServiceConfigObjConfig FileReferenceConfig +func (s *ServiceConfigObjConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + s.Source = node.Value + return nil + } + type plain ServiceConfigObjConfig + return WrapNodeError(node, node.Decode((*plain)(s))) +} + // ServiceSecretConfig is the secret configuration for a service type ServiceSecretConfig FileReferenceConfig 
+func (s *ServiceSecretConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + s.Source = node.Value + s.Target = fmt.Sprintf("/run/secrets/%s", s.Source) + return nil + } + type plain ServiceSecretConfig + if err := node.Decode((*plain)(s)); err != nil { + return WrapNodeError(node, err) + } + if s.Target == "" { + s.Target = fmt.Sprintf("/run/secrets/%s", s.Source) + } + return nil +} + // UlimitsConfig the ulimit configuration type UlimitsConfig struct { Single int `yaml:"single,omitempty" json:"single,omitempty"` @@ -685,29 +873,20 @@ type UlimitsConfig struct { Extensions Extensions `yaml:"#extensions,inline,omitempty" json:"-"` } -func (u *UlimitsConfig) DecodeMapstructure(value interface{}) error { - switch v := value.(type) { - case *UlimitsConfig: - // this call to DecodeMapstructure is triggered after initial value conversion as we use a map[string]*UlimitsConfig - return nil - case int: +func (u *UlimitsConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + var v int + if err := node.Decode(&v); err != nil { + return WrapNodeError(node, err) + } u.Single = v u.Soft = 0 u.Hard = 0 - case map[string]any: - u.Single = 0 - soft, ok := v["soft"] - if ok { - u.Soft = soft.(int) - } - hard, ok := v["hard"] - if ok { - u.Hard = hard.(int) - } - default: - return fmt.Errorf("unexpected value type %T for ulimit", value) + return nil } - return nil + type plain UlimitsConfig + return WrapNodeError(node, node.Decode((*plain)(u))) } // MarshalYAML makes UlimitsConfig implement yaml.Marshaller @@ -780,6 +959,22 @@ type VolumeConfig struct { // not managed, and should already exist. 
type External bool +func (e *External) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.ScalarNode: + var b bool + if err := node.Decode(&b); err != nil { + return err + } + *e = External(b) + case yaml.MappingNode: + // Legacy syntax: external: {name: foo} — treat as external: true + *e = true + } + return nil +} + // CredentialSpecConfig for credential spec on Windows type CredentialSpecConfig struct { Config string `yaml:"config,omitempty" json:"config,omitempty"` // Config was added in API v1.40 @@ -817,6 +1012,46 @@ const ( type DependsOnConfig map[string]ServiceDependency +func (d *DependsOnConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + switch node.Kind { + case yaml.SequenceNode: + config := make(DependsOnConfig, len(node.Content)) + for _, item := range node.Content { + config[item.Value] = ServiceDependency{ + Condition: ServiceConditionStarted, + Required: true, + } + } + *d = config + case yaml.MappingNode: + type plain DependsOnConfig + var p plain + if err := node.Decode(&p); err != nil { + return WrapNodeError(node, err) + } + for k, v := range p { + if v.Condition == "" { + v.Condition = ServiceConditionStarted + } + p[k] = v + } + // Set required=true for entries that didn't explicitly set it + for i := 0; i+1 < len(node.Content); i += 2 { + key := node.Content[i].Value + dep := p[key] + if !hasKey(node.Content[i+1], "required") { + dep.Required = true + } + p[key] = dep + } + *d = DependsOnConfig(p) + default: + return NodeErrorf(node, "unexpected node kind %d for depends_on", node.Kind) + } + return nil +} + type ServiceDependency struct { Condition string `yaml:"condition,omitempty" json:"condition,omitempty"` Restart bool `yaml:"restart,omitempty" json:"restart,omitempty"` @@ -876,3 +1111,13 @@ type IncludeConfig struct { ProjectDirectory string `yaml:"project_directory,omitempty" json:"project_directory,omitempty"` EnvFile StringList 
`yaml:"env_file,omitempty" json:"env_file,omitempty"` } + +func (ic *IncludeConfig) UnmarshalYAML(value *yaml.Node) error { + node := resolveYAMLNode(value) + if node.Kind == yaml.ScalarNode { + ic.Path = StringList{node.Value} + return nil + } + type plain IncludeConfig + return WrapNodeError(node, node.Decode((*plain)(ic))) +} diff --git a/types/unmarshal_test.go b/types/unmarshal_test.go new file mode 100644 index 00000000..7d66d44e --- /dev/null +++ b/types/unmarshal_test.go @@ -0,0 +1,541 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package types + +import ( + "testing" + "time" + + "go.yaml.in/yaml/v4" + "gotest.tools/v3/assert" +) + +// Labels + +func TestUnmarshalLabels_Map(t *testing.T) { + var labels Labels + err := yaml.Unmarshal([]byte(`foo: bar`), &labels) + assert.NilError(t, err) + assert.DeepEqual(t, labels, Labels{"foo": "bar"}) +} + +func TestUnmarshalLabels_List(t *testing.T) { + var labels Labels + err := yaml.Unmarshal([]byte("- foo=bar\n- baz=qux"), &labels) + assert.NilError(t, err) + assert.DeepEqual(t, labels, Labels{"foo": "bar", "baz": "qux"}) +} + +// Mapping + +func TestUnmarshalMapping_Map(t *testing.T) { + var m Mapping + err := yaml.Unmarshal([]byte("foo: bar\nbaz: qux"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, Mapping{"foo": "bar", "baz": "qux"}) +} + +func TestUnmarshalMapping_MapNullValue(t *testing.T) { + var m Mapping + err := yaml.Unmarshal([]byte("foo:\nbaz: qux"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, Mapping{"foo": "", "baz": "qux"}) +} + +func TestUnmarshalMapping_List(t *testing.T) { + var m Mapping + err := yaml.Unmarshal([]byte("- foo=bar\n- baz=qux"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, Mapping{"foo": "bar", "baz": "qux"}) +} + +func TestUnmarshalMapping_ListNoValue(t *testing.T) { + var m Mapping + err := yaml.Unmarshal([]byte("- foo"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, Mapping{"foo": ""}) +} + +// MappingWithEquals + +func TestUnmarshalMappingWithEquals_Map(t *testing.T) { + var m MappingWithEquals + err := yaml.Unmarshal([]byte("foo: bar\nbaz:"), &m) + assert.NilError(t, err) + bar := "bar" + assert.DeepEqual(t, m, MappingWithEquals{"foo": &bar, "baz": nil}) +} + +func TestUnmarshalMappingWithEquals_List(t *testing.T) { + var m MappingWithEquals + err := yaml.Unmarshal([]byte("- foo=bar\n- baz"), &m) + assert.NilError(t, err) + bar := "bar" + assert.DeepEqual(t, m, MappingWithEquals{"foo": &bar, "baz": nil}) +} + +func TestUnmarshalMappingWithEquals_ListEmptyValue(t 
*testing.T) { + var m MappingWithEquals + err := yaml.Unmarshal([]byte("- foo="), &m) + assert.NilError(t, err) + empty := "" + assert.DeepEqual(t, m, MappingWithEquals{"foo": &empty}) +} + +// ShellCommand + +func TestUnmarshalShellCommand_String(t *testing.T) { + var cmd ShellCommand + err := yaml.Unmarshal([]byte(`echo "hello world"`), &cmd) + assert.NilError(t, err) + assert.DeepEqual(t, cmd, ShellCommand{"echo", "hello world"}) +} + +func TestUnmarshalShellCommand_List(t *testing.T) { + var cmd ShellCommand + err := yaml.Unmarshal([]byte("- echo\n- hello world"), &cmd) + assert.NilError(t, err) + assert.DeepEqual(t, cmd, ShellCommand{"echo", "hello world"}) +} + +// UnitBytes + +func TestUnmarshalUnitBytes_Integer(t *testing.T) { + var u UnitBytes + err := yaml.Unmarshal([]byte("1024"), &u) + assert.NilError(t, err) + assert.Equal(t, u, UnitBytes(1024)) +} + +func TestUnmarshalUnitBytes_String(t *testing.T) { + var u UnitBytes + err := yaml.Unmarshal([]byte("1GB"), &u) + assert.NilError(t, err) + assert.Equal(t, u, UnitBytes(1073741824)) // docker/go-units uses binary: 1GB = 1GiB +} + +// Duration + +func TestUnmarshalDuration_String(t *testing.T) { + var d Duration + err := yaml.Unmarshal([]byte("1m30s"), &d) + assert.NilError(t, err) + assert.Equal(t, d, Duration(90*time.Second)) +} + +func TestUnmarshalDuration_Seconds(t *testing.T) { + var d Duration + err := yaml.Unmarshal([]byte("30s"), &d) + assert.NilError(t, err) + assert.Equal(t, d, Duration(30*time.Second)) +} + +// NanoCPUs + +func TestUnmarshalNanoCPUs_Float(t *testing.T) { + var n NanoCPUs + err := yaml.Unmarshal([]byte("1.5"), &n) + assert.NilError(t, err) + assert.Equal(t, n, NanoCPUs(1.5)) +} + +func TestUnmarshalNanoCPUs_Float64(t *testing.T) { + var n NanoCPUs + err := yaml.Unmarshal([]byte("0.5"), &n) + assert.NilError(t, err) + assert.Equal(t, n, NanoCPUs(0.5)) +} + +func TestUnmarshalNanoCPUs_Integer(t *testing.T) { + var n NanoCPUs + err := yaml.Unmarshal([]byte("2"), &n) + 
assert.NilError(t, err) + assert.Equal(t, n, NanoCPUs(2)) +} + +// DeviceCount + +func TestUnmarshalDeviceCount_Integer(t *testing.T) { + var c DeviceCount + err := yaml.Unmarshal([]byte("3"), &c) + assert.NilError(t, err) + assert.Equal(t, c, DeviceCount(3)) +} + +func TestUnmarshalDeviceCount_All(t *testing.T) { + var c DeviceCount + err := yaml.Unmarshal([]byte("all"), &c) + assert.NilError(t, err) + assert.Equal(t, c, DeviceCount(-1)) +} + +// HealthCheckTest + +func TestUnmarshalHealthCheckTest_String(t *testing.T) { + var h HealthCheckTest + err := yaml.Unmarshal([]byte("curl -f http://localhost/"), &h) + assert.NilError(t, err) + assert.DeepEqual(t, h, HealthCheckTest{"CMD-SHELL", "curl -f http://localhost/"}) +} + +func TestUnmarshalHealthCheckTest_List(t *testing.T) { + var h HealthCheckTest + err := yaml.Unmarshal([]byte("- CMD\n- curl\n- -f\n- http://localhost/"), &h) + assert.NilError(t, err) + assert.DeepEqual(t, h, HealthCheckTest{"CMD", "curl", "-f", "http://localhost/"}) +} + +// HostsList + +func TestUnmarshalHostsList_Map(t *testing.T) { + var h HostsList + err := yaml.Unmarshal([]byte("myhost: 192.168.1.1"), &h) + assert.NilError(t, err) + assert.DeepEqual(t, h, HostsList{"myhost": {"192.168.1.1"}}) +} + +func TestUnmarshalHostsList_ListEquals(t *testing.T) { + var h HostsList + err := yaml.Unmarshal([]byte("- myhost=192.168.1.1"), &h) + assert.NilError(t, err) + assert.DeepEqual(t, h, HostsList{"myhost": {"192.168.1.1"}}) +} + +func TestUnmarshalHostsList_ListColon(t *testing.T) { + var h HostsList + err := yaml.Unmarshal([]byte("- myhost:192.168.1.1"), &h) + assert.NilError(t, err) + assert.DeepEqual(t, h, HostsList{"myhost": {"192.168.1.1"}}) +} + +// StringList + +func TestUnmarshalStringList_String(t *testing.T) { + var s StringList + err := yaml.Unmarshal([]byte("hello"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, StringList{"hello"}) +} + +func TestUnmarshalStringList_List(t *testing.T) { + var s StringList + err := 
yaml.Unmarshal([]byte("- hello\n- world"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, StringList{"hello", "world"}) +} + +// StringOrNumberList + +func TestUnmarshalStringOrNumberList_String(t *testing.T) { + var s StringOrNumberList + err := yaml.Unmarshal([]byte("8080"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, StringOrNumberList{"8080"}) +} + +func TestUnmarshalStringOrNumberList_List(t *testing.T) { + var s StringOrNumberList + err := yaml.Unmarshal([]byte("- 8080\n- 9090"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, StringOrNumberList{"8080", "9090"}) +} + +func TestUnmarshalStringOrNumberList_MixedList(t *testing.T) { + var s StringOrNumberList + err := yaml.Unmarshal([]byte("- 8080\n- http"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, StringOrNumberList{"8080", "http"}) +} + +// Options + +func TestUnmarshalOptions_Map(t *testing.T) { + var o Options + err := yaml.Unmarshal([]byte("foo: bar\nbaz: qux"), &o) + assert.NilError(t, err) + assert.DeepEqual(t, o, Options{"foo": "bar", "baz": "qux"}) +} + +func TestUnmarshalOptions_MapNullValue(t *testing.T) { + var o Options + err := yaml.Unmarshal([]byte("foo:\nbaz: qux"), &o) + assert.NilError(t, err) + assert.DeepEqual(t, o, Options{"foo": "", "baz": "qux"}) +} + +// MultiOptions + +func TestUnmarshalMultiOptions_Scalar(t *testing.T) { + var m MultiOptions + err := yaml.Unmarshal([]byte("foo: bar\nbaz: qux"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, MultiOptions{"foo": {"bar"}, "baz": {"qux"}}) +} + +func TestUnmarshalMultiOptions_Sequence(t *testing.T) { + var m MultiOptions + err := yaml.Unmarshal([]byte("foo:\n - bar\n - baz"), &m) + assert.NilError(t, err) + assert.DeepEqual(t, m, MultiOptions{"foo": {"bar", "baz"}}) +} + +// SSHConfig + +func TestUnmarshalSSHConfig_Map(t *testing.T) { + var s SSHConfig + err := yaml.Unmarshal([]byte("default: /home/user/.ssh/id_rsa"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, SSHConfig{{ID: 
"default", Path: "/home/user/.ssh/id_rsa"}}) +} + +func TestUnmarshalSSHConfig_MapNoPath(t *testing.T) { + var s SSHConfig + err := yaml.Unmarshal([]byte("default:"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, SSHConfig{{ID: "default"}}) +} + +func TestUnmarshalSSHConfig_List(t *testing.T) { + var s SSHConfig + err := yaml.Unmarshal([]byte("- default=/home/user/.ssh/id_rsa"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, SSHConfig{{ID: "default", Path: "/home/user/.ssh/id_rsa"}}) +} + +func TestUnmarshalSSHConfig_ListNoPath(t *testing.T) { + var s SSHConfig + err := yaml.Unmarshal([]byte("- default"), &s) + assert.NilError(t, err) + assert.DeepEqual(t, s, SSHConfig{{ID: "default"}}) +} + +// FileMode + +func TestUnmarshalFileMode_OctalString(t *testing.T) { + var m FileMode + err := yaml.Unmarshal([]byte(`"0755"`), &m) + assert.NilError(t, err) + assert.Equal(t, m, FileMode(0o755)) +} + +func TestUnmarshalFileMode_Integer(t *testing.T) { + var m FileMode + err := yaml.Unmarshal([]byte("0755"), &m) + assert.NilError(t, err) + // yaml/v4 treats 0755 as octal (YAML 1.2 compat), decoding to 493 + assert.Equal(t, m, FileMode(0o755)) +} + +// UlimitsConfig + +func TestUnmarshalUlimitsConfig_Integer(t *testing.T) { + var u UlimitsConfig + err := yaml.Unmarshal([]byte("1024"), &u) + assert.NilError(t, err) + assert.Equal(t, u.Single, 1024) + assert.Equal(t, u.Soft, 0) + assert.Equal(t, u.Hard, 0) +} + +func TestUnmarshalUlimitsConfig_Map(t *testing.T) { + var u UlimitsConfig + err := yaml.Unmarshal([]byte("soft: 1024\nhard: 2048"), &u) + assert.NilError(t, err) + assert.Equal(t, u.Single, 0) + assert.Equal(t, u.Soft, 1024) + assert.Equal(t, u.Hard, 2048) +} + +// BuildConfig + +func TestUnmarshalBuildConfig_String(t *testing.T) { + var b BuildConfig + err := yaml.Unmarshal([]byte("./dir"), &b) + assert.NilError(t, err) + assert.Equal(t, b.Context, "./dir") +} + +func TestUnmarshalBuildConfig_Map(t *testing.T) { + var b BuildConfig + err := 
yaml.Unmarshal([]byte("context: ./dir\ndockerfile: Dockerfile.dev"), &b) + assert.NilError(t, err) + assert.Equal(t, b.Context, "./dir") + assert.Equal(t, b.Dockerfile, "Dockerfile.dev") +} + +// DependsOnConfig + +func TestUnmarshalDependsOnConfig_List(t *testing.T) { + var d DependsOnConfig + err := yaml.Unmarshal([]byte("- db\n- redis"), &d) + assert.NilError(t, err) + assert.Equal(t, len(d), 2) + assert.Equal(t, d["db"].Condition, ServiceConditionStarted) + assert.Equal(t, d["db"].Required, true) + assert.Equal(t, d["redis"].Condition, ServiceConditionStarted) + assert.Equal(t, d["redis"].Required, true) +} + +func TestUnmarshalDependsOnConfig_Map(t *testing.T) { + var d DependsOnConfig + err := yaml.Unmarshal([]byte("db:\n condition: service_healthy"), &d) + assert.NilError(t, err) + assert.Equal(t, d["db"].Condition, ServiceConditionHealthy) + assert.Equal(t, d["db"].Required, true) // default when not explicitly set +} + +func TestUnmarshalDependsOnConfig_MapExplicitRequired(t *testing.T) { + var d DependsOnConfig + err := yaml.Unmarshal([]byte("db:\n condition: service_healthy\n required: false"), &d) + assert.NilError(t, err) + assert.Equal(t, d["db"].Condition, ServiceConditionHealthy) + assert.Equal(t, d["db"].Required, false) +} + +// EnvFile + +func TestUnmarshalEnvFile_String(t *testing.T) { + var e EnvFile + err := yaml.Unmarshal([]byte(".env"), &e) + assert.NilError(t, err) + assert.Equal(t, e.Path, ".env") + assert.Equal(t, bool(e.Required), true) +} + +func TestUnmarshalEnvFile_Map(t *testing.T) { + var e EnvFile + err := yaml.Unmarshal([]byte("path: .env\nrequired: false"), &e) + assert.NilError(t, err) + assert.Equal(t, e.Path, ".env") + assert.Equal(t, bool(e.Required), false) +} + +func TestUnmarshalEnvFile_MapNoRequired(t *testing.T) { + var e EnvFile + err := yaml.Unmarshal([]byte("path: .env"), &e) + assert.NilError(t, err) + assert.Equal(t, e.Path, ".env") + assert.Equal(t, bool(e.Required), true) // defaults to true +} + +// IncludeConfig 
+ +func TestUnmarshalIncludeConfig_String(t *testing.T) { + var ic IncludeConfig + err := yaml.Unmarshal([]byte("docker-compose.yml"), &ic) + assert.NilError(t, err) + assert.DeepEqual(t, ic.Path, StringList{"docker-compose.yml"}) +} + +func TestUnmarshalIncludeConfig_Map(t *testing.T) { + var ic IncludeConfig + err := yaml.Unmarshal([]byte("path:\n - docker-compose.yml\nproject_directory: ./subdir"), &ic) + assert.NilError(t, err) + assert.DeepEqual(t, ic.Path, StringList{"docker-compose.yml"}) + assert.Equal(t, ic.ProjectDirectory, "./subdir") +} + +// ServiceSecretConfig + +func TestUnmarshalServiceSecretConfig_String(t *testing.T) { + var s ServiceSecretConfig + err := yaml.Unmarshal([]byte("my_secret"), &s) + assert.NilError(t, err) + assert.Equal(t, s.Source, "my_secret") +} + +func TestUnmarshalServiceSecretConfig_Map(t *testing.T) { + var s ServiceSecretConfig + err := yaml.Unmarshal([]byte("source: my_secret\ntarget: /run/secrets/my_secret"), &s) + assert.NilError(t, err) + assert.Equal(t, s.Source, "my_secret") + assert.Equal(t, s.Target, "/run/secrets/my_secret") +} + +// ServiceConfigObjConfig + +func TestUnmarshalServiceConfigObjConfig_String(t *testing.T) { + var c ServiceConfigObjConfig + err := yaml.Unmarshal([]byte("my_config"), &c) + assert.NilError(t, err) + assert.Equal(t, c.Source, "my_config") +} + +func TestUnmarshalServiceConfigObjConfig_Map(t *testing.T) { + var c ServiceConfigObjConfig + err := yaml.Unmarshal([]byte("source: my_config\ntarget: /etc/my_config"), &c) + assert.NilError(t, err) + assert.Equal(t, c.Source, "my_config") + assert.Equal(t, c.Target, "/etc/my_config") +} + +// FileReferenceConfig + +func TestUnmarshalFileReferenceConfig_String(t *testing.T) { + var f FileReferenceConfig + err := yaml.Unmarshal([]byte("my_ref"), &f) + assert.NilError(t, err) + assert.Equal(t, f.Source, "my_ref") +} + +func TestUnmarshalFileReferenceConfig_Map(t *testing.T) { + var f FileReferenceConfig + err := yaml.Unmarshal([]byte("source: 
my_ref\ntarget: /path/to/ref"), &f) + assert.NilError(t, err) + assert.Equal(t, f.Source, "my_ref") + assert.Equal(t, f.Target, "/path/to/ref") +} + +// ServicePortConfig + +func TestUnmarshalServicePortConfig_String(t *testing.T) { + var p ServicePortConfig + err := yaml.Unmarshal([]byte(`"8080:80"`), &p) + assert.NilError(t, err) + assert.Equal(t, p.Target, uint32(80)) + assert.Equal(t, p.Published, "8080") + assert.Equal(t, p.Protocol, "tcp") +} + +func TestUnmarshalServicePortConfig_Map(t *testing.T) { + var p ServicePortConfig + err := yaml.Unmarshal([]byte("target: 80\npublished: \"8080\"\nprotocol: tcp"), &p) + assert.NilError(t, err) + assert.Equal(t, p.Target, uint32(80)) + assert.Equal(t, p.Published, "8080") + assert.Equal(t, p.Protocol, "tcp") +} + +func TestUnmarshalServicePortConfig_StringTargetOnly(t *testing.T) { + var p ServicePortConfig + err := yaml.Unmarshal([]byte(`"80"`), &p) + assert.NilError(t, err) + assert.Equal(t, p.Target, uint32(80)) + assert.Equal(t, p.Protocol, "tcp") +} + +// ServiceVolumeConfig (map form only, string form requires ParseVolumeFunc) + +func TestUnmarshalServiceVolumeConfig_Map(t *testing.T) { + var v ServiceVolumeConfig + err := yaml.Unmarshal([]byte("type: bind\nsource: ./data\ntarget: /data\nread_only: true"), &v) + assert.NilError(t, err) + assert.Equal(t, v.Type, "bind") + assert.Equal(t, v.Source, "./data") + assert.Equal(t, v.Target, "/data") + assert.Equal(t, v.ReadOnly, true) +} diff --git a/types/yaml.go b/types/yaml.go new file mode 100644 index 00000000..2c3a52da --- /dev/null +++ b/types/yaml.go @@ -0,0 +1,134 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package types + +import ( + "errors" + "fmt" + "strings" + + "go.yaml.in/yaml/v4" +) + +// NodeError is an error that carries yaml source location (file, line, column). +type NodeError struct { + Line int + Column int + Source string + Err error +} + +func (e *NodeError) Error() string { + if e.Source != "" { + return fmt.Sprintf("%s:%d:%d: %s", e.Source, e.Line, e.Column, e.Err) + } + return fmt.Sprintf("line %d, column %d: %s", e.Line, e.Column, e.Err) +} + +func (e *NodeError) Unwrap() error { + return e.Err +} + +// NodeErrorf creates a NodeError from a yaml.Node and a formatted message. +func NodeErrorf(node *yaml.Node, format string, args ...any) error { + return &NodeError{ + Line: node.Line, + Column: node.Column, + Err: fmt.Errorf(format, args...), + } +} + +// WrapNodeError wraps an existing error with yaml.Node source location. +func WrapNodeError(node *yaml.Node, err error) error { + if err == nil { + return nil + } + return &NodeError{ + Line: node.Line, + Column: node.Column, + Err: err, + } +} + +// WithSource enriches any NodeError instances found in the error chain or +// message with the given source filename. It handles errors wrapped by +// yaml/v4's LoadErrors which break the standard errors.As chain. 
+func WithSource(err error, source string) error { + if err == nil { + return nil + } + // Direct match or wrapped + var ne *NodeError + if errors.As(err, &ne) { + enriched := &NodeError{ + Line: ne.Line, + Column: ne.Column, + Source: source, + Err: ne.Err, + } + return fmt.Errorf("%w", enriched) + } + // yaml/v4 LoadErrors wraps errors in a way that breaks errors.As. + // Check if the error message already has line info from our NodeError.Error(). + // If so, enrich the message with the source file. + msg := err.Error() + if strings.Contains(msg, "line ") && strings.Contains(msg, "column ") { + return fmt.Errorf("%s: %s", source, msg) + } + return err +} + +// resolveYAMLNode unwraps a DocumentNode wrapper that yaml/v4 passes to +// UnmarshalYAML methods. If the node is a DocumentNode with a single child, +// the child is returned; otherwise the node is returned as-is. +func resolveYAMLNode(node *yaml.Node) *yaml.Node { + if node.Kind == yaml.DocumentNode && len(node.Content) == 1 { + return node.Content[0] + } + return node +} + +// hasKey checks if a MappingNode contains a specific key +func hasKey(node *yaml.Node, key string) bool { + if node.Kind != yaml.MappingNode { + return false + } + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value == key { + return true + } + } + return false +} + +// findYAMLKey finds a key in a MappingNode and returns the value node. +func findYAMLKey(node *yaml.Node, key string) *yaml.Node { + if node == nil || node.Kind != yaml.MappingNode { + return nil + } + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value == key { + return node.Content[i+1] + } + } + return nil +} + +// ParseVolumeFunc is a package-level hook for parsing volume short syntax. +// It is set by the loader package to break the circular dependency between +// types and format (format imports types, so types cannot import format). 
+var ParseVolumeFunc func(string) (ServiceVolumeConfig, error) diff --git a/validation/external.go b/validation/external.go index b74d551a..9982ad8e 100644 --- a/validation/external.go +++ b/validation/external.go @@ -29,7 +29,20 @@ func checkExternal(v map[string]any, p tree.Path) error { if !ok { return nil } - if !b.(bool) { + + // Handle legacy syntax: external: {name: foo} + if ext, ok := b.(map[string]any); ok { + if extName, extNamed := ext["name"]; extNamed { + name, named := v["name"] + if named && extName != name { + return fmt.Errorf("%s: name and external.name conflict; only use name", p) + } + } + // Treat as external: true for the remaining checks + b = true + } + + if bVal, ok := b.(bool); ok && !bVal { return nil }