diff --git a/loader/extends.go b/loader/extends.go index c4fd0be2..d85e84ba 100644 --- a/loader/extends.go +++ b/loader/extends.go @@ -27,7 +27,7 @@ import ( "github.com/compose-spec/compose-go/v2/types" ) -func ApplyExtends(ctx context.Context, dict map[string]any, opts *Options, tracker *cycleTracker, post ...PostProcessor) error { +func ApplyExtends(ctx context.Context, dict map[string]any, opts *Options, tracker *cycleTracker, post PostProcessor) error { a, ok := dict["services"] if !ok { return nil @@ -37,7 +37,7 @@ func ApplyExtends(ctx context.Context, dict map[string]any, opts *Options, track return fmt.Errorf("services must be a mapping") } for name := range services { - merged, err := applyServiceExtends(ctx, name, services, opts, tracker, post...) + merged, err := applyServiceExtends(ctx, name, services, opts, tracker, post) if err != nil { return err } @@ -47,7 +47,7 @@ func ApplyExtends(ctx context.Context, dict map[string]any, opts *Options, track return nil } -func applyServiceExtends(ctx context.Context, name string, services map[string]any, opts *Options, tracker *cycleTracker, post ...PostProcessor) (any, error) { +func applyServiceExtends(ctx context.Context, name string, services map[string]any, opts *Options, tracker *cycleTracker, post PostProcessor) (any, error) { s := services[name] if s == nil { return nil, nil @@ -81,7 +81,7 @@ func applyServiceExtends(ctx context.Context, name string, services map[string]a var ( base any - processor PostProcessor = NoopPostProcessor{} + processor = post ) if file != nil { @@ -114,16 +114,15 @@ func applyServiceExtends(ctx context.Context, name string, services map[string]a } source := deepClone(base).(map[string]any) - for _, processor := range post { - err = processor.Apply(map[string]any{ - "services": map[string]any{ - name: source, - }, - }) - if err != nil { - return nil, err - } + err = post.Apply(map[string]any{ + "services": map[string]any{ + name: source, + }, + }) + if err != nil { + return nil, err } + merged, err := override.ExtendService(source, service) if err != nil { return nil, err diff --git a/loader/loader.go b/loader/loader.go index 99b15364..91f687f9 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -427,7 +427,7 @@ func loadYamlFile(ctx context.Context, file.Content = content } - processRawYaml := func(raw interface{}, processors ...PostProcessor) error { + processRawYaml := func(raw interface{}, processor PostProcessor) error { converted, err := convertToStringKeysRecursive(raw, "") if err != nil { return err @@ -447,16 +447,14 @@ func loadYamlFile(ctx context.Context, fixEmptyNotNull(cfg) if !opts.SkipExtends { - err = ApplyExtends(ctx, cfg, opts, ct, processors...) + err = ApplyExtends(ctx, cfg, opts, ct, processor) if err != nil { return err } } - for _, processor := range processors { - if err := processor.Apply(dict); err != nil { - return err - } + if err := processor.Apply(dict); err != nil { + return err } if !opts.SkipInclude { @@ -519,7 +517,7 @@ func loadYamlFile(ctx context.Context, } } } else { - if err := processRawYaml(file.Config); err != nil { + if err := processRawYaml(file.Config, NoopPostProcessor{}); err != nil { return nil, nil, err } } diff --git a/loader/loader_test.go b/loader/loader_test.go index 101afa3d..66449711 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -3917,65 +3917,3 @@ services: assert.Equal(t, build.Provenance, "mode=max") assert.Equal(t, build.SBOM, "true") } - -func TestOverrideMiddle(t *testing.T) { - pwd := t.TempDir() - base := filepath.Join(pwd, "base.yaml") - err := os.WriteFile(base, []byte(` -services: - base: - volumes: - - /foo:/foo -`), 0o700) - assert.NilError(t, err) - - override := filepath.Join(pwd, "override.yaml") - err = os.WriteFile(override, []byte(` -services: - override: - extends: - file: ./base.yaml - service: base - volumes: !override - - /bar:/bar -`), 0o700) - assert.NilError(t, err) - - compose := filepath.Join(pwd, "compose.yaml") - err = os.WriteFile(compose, []byte(` -name: test -services: - test: - image: test - extends: - file: ./override.yaml - service: override - volumes: - - /zot:/zot -`), 0o700) - assert.NilError(t, err) - - project, err := LoadWithContext(context.TODO(), types.ConfigDetails{ - WorkingDir: pwd, - ConfigFiles: []types.ConfigFile{ - {Filename: compose}, - }, - }) - assert.NilError(t, err) - test := project.Services["test"] - assert.Equal(t, len(test.Volumes), 2) - assert.DeepEqual(t, test.Volumes, []types.ServiceVolumeConfig{ - { - Type: "bind", - Source: "/bar", - Target: "/bar", - Bind: &types.ServiceVolumeBind{CreateHostPath: true}, - }, - { - Type: "bind", - Source: "/zot", - Target: "/zot", - Bind: &types.ServiceVolumeBind{CreateHostPath: true}, - }, - }) -} diff --git a/loader/override_test.go b/loader/override_test.go index 035fb296..7a4b62e3 100644 --- a/loader/override_test.go +++ b/loader/override_test.go @@ -18,6 +18,8 @@ package loader import ( "context" + "os" + "path/filepath" "testing" "github.com/compose-spec/compose-go/v2/types" @@ -197,3 +199,125 @@ services: assert.NilError(t, err) assert.Equal(t, len(p.Services["test"].Volumes), 1) } + +// see https://github.com/docker/compose/issues/13298 +func TestOverrideMiddle(t *testing.T) { + pwd := t.TempDir() + base := filepath.Join(pwd, "base.yaml") + err := os.WriteFile(base, []byte(` +services: + base: + volumes: + - /foo:/foo + networks: + - foo +`), 0o700) + assert.NilError(t, err) + + override := filepath.Join(pwd, "override.yaml") + err = os.WriteFile(override, []byte(` +services: + override: + extends: + file: ./base.yaml + service: base + volumes: !override + - /bar:/bar + networks: !override + - bar +`), 0o700) + assert.NilError(t, err) + + compose := filepath.Join(pwd, "compose.yaml") + err = os.WriteFile(compose, []byte(` +name: test +services: + test: + image: test + extends: + file: ./override.yaml + service: override + volumes: + - /zot:/zot + networks: !override + - zot + +networks: + zot: {} +`), 0o700) + assert.NilError(t, err) + + project, err := LoadWithContext(context.TODO(), types.ConfigDetails{ + WorkingDir: pwd, + ConfigFiles: []types.ConfigFile{ + {Filename: compose}, + }, + }) + assert.NilError(t, err) + test := project.Services["test"] + assert.Equal(t, len(test.Volumes), 2) + assert.DeepEqual(t, test.Volumes, []types.ServiceVolumeConfig{ + { + Type: "bind", + Source: "/bar", + Target: "/bar", + Bind: &types.ServiceVolumeBind{CreateHostPath: true}, + }, + { + Type: "bind", + Source: "/zot", + Target: "/zot", + Bind: &types.ServiceVolumeBind{CreateHostPath: true}, + }, + }) + assert.DeepEqual(t, test.Networks, map[string]*types.ServiceNetworkConfig{ + "zot": nil, + }) +} + +// https://github.com/docker/compose/issues/13346 +func TestOverrideSelfExtends(t *testing.T) { + yaml := ` +name: test-override-extends +services: + depend_base: + image: nginx + ports: + - "8092:80" + depend_one: + image: nginx + ports: + - "8091:80" + depend_two: + extends: + service: depend_one + main_one: + image: nginx + depends_on: + - depend_one + ports: + - "8090:80" + main_two: + extends: main_one + depends_on: !override + - depend_two + main: + extends: + service: main_two + depends_on: + - depend_base +` + p, err := LoadWithContext(context.Background(), types.ConfigDetails{ + ConfigFiles: []types.ConfigFile{ + { + Filename: "-", + Content: []byte(yaml), + }, + }, + }) + assert.NilError(t, err) + assert.DeepEqual(t, p.Services["main"].DependsOn, types.DependsOnConfig{ + "depend_base": {Condition: "service_started", Required: true}, + "depend_two": {Condition: "service_started", Required: true}, + }) +}