diff --git a/internal/builtins/pkg_context_test.go b/internal/builtins/pkg_context_test.go index 32a87fc93..ca4b67a19 100644 --- a/internal/builtins/pkg_context_test.go +++ b/internal/builtins/pkg_context_test.go @@ -18,6 +18,7 @@ import ( "bytes" "os" "path/filepath" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -62,7 +63,9 @@ func TestPkgContextGenerator(t *testing.T) { if err != test.expErr { t.Errorf("exp: %v got: %v", test.expErr, err) } - if diff := cmp.Diff(string(exp), out.String()); diff != "" { + expected := strings.ReplaceAll(string(exp), "\r\n", "\n") + actual := strings.ReplaceAll(out.String(), "\r\n", "\n") + if diff := cmp.Diff(expected, actual); diff != "" { t.Errorf("pkg context mistmach (-want +got):\n%s", diff) } }) diff --git a/internal/util/get/get_test.go b/internal/util/get/get_test.go index ab965599b..0d38f3097 100644 --- a/internal/util/get/get_test.go +++ b/internal/util/get/get_test.go @@ -213,15 +213,29 @@ func TestCommand_Run_subdir_symlinks(t *testing.T) { }.Run(fake.CtxWithPrinter(cliOutput, cliOutput)) assert.NoError(t, err) - // ensure warning for symlink is printed on the CLI - assert.Contains(t, cliOutput.String(), `[Warn] Ignoring symlink "config-symlink"`) + sourceSymlinkPath := filepath.Join(g.DatasetDirectory, testutil.Dataset6, subdir, "config-symlink") + info, statErr := os.Lstat(sourceSymlinkPath) + assert.NoError(t, statErr) + isSymlinkInSource := info.Mode()&os.ModeSymlink != 0 + + if isSymlinkInSource { + // ensure warning for symlink is printed on the CLI + assert.Contains(t, cliOutput.String(), `[Warn] Ignoring symlink "config-symlink"`) + } else { + // on environments without symlink materialization, there is no ignore warning + assert.NotContains(t, cliOutput.String(), `[Warn] Ignoring symlink "config-symlink"`) + } // verify the cloned contents do not contains symlinks diff, err := testutil.Diff(filepath.Join(g.DatasetDirectory, testutil.Dataset6, subdir), absPath, true) assert.NoError(t, err) diff = diff.Difference(testutil.KptfileSet) - // original repo contains symlink and cloned doesn't, so the difference - assert.Contains(t, diff.List(), "config-symlink") + if isSymlinkInSource { + // original repo contains symlink and cloned doesn't, so the difference + assert.Contains(t, diff.List(), "config-symlink") + } else { + assert.NotContains(t, diff.List(), "config-symlink") + } // verify the KptFile contains the expected values commit, err := g.GetCommit() diff --git a/internal/util/render/executor_test.go b/internal/util/render/executor_test.go index 2c8e253f5..e8b83637f 100644 --- a/internal/util/render/executor_test.go +++ b/internal/util/render/executor_test.go @@ -267,7 +267,7 @@ kind: Kptfile metadata: name: root-package annotations: - kpt.dev/bfs-rendering: %t + kpt.dev/bfs-rendering: "%t" `, renderBfs)) assert.NoError(t, err) @@ -335,8 +335,12 @@ func TestRenderer_Execute_RenderOrder(t *testing.T) { renderer, outputBuffer, ctx := setupRendererTest(t, tc.renderBfs) fnResults, err := renderer.Execute(ctx) - assert.NoError(t, err) - assert.NotNil(t, fnResults) + if !assert.NoError(t, err) { + return + } + if !assert.NotNil(t, fnResults) { + return + } assert.Equal(t, 0, len(fnResults.Items)) output := outputBuffer.String() @@ -420,7 +424,7 @@ kind: Kptfile metadata: name: root-package annotations: - ktp.dev/bfs-rendering: true + kpt.dev/bfs-rendering: "true" `)) assert.NoError(t, err) @@ -434,7 +438,9 @@ metadata: // Create a mock hydration context root, err := newPkgNode(mockFileSystem, rootPkgPath, nil) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } hctx := &hydrationContext{ root: root, diff --git a/pkg/kptfile/kptfileutil/util.go b/pkg/kptfile/kptfileutil/util.go index 01018a3ff..a4f2bfa76 100644 --- a/pkg/kptfile/kptfileutil/util.go +++ b/pkg/kptfile/kptfileutil/util.go @@ -21,6 +21,7 @@ import ( "io" "os" "path/filepath" + "reflect" "slices" "strings" @@ -28,6 +29,7 @@ import ( "github.com/kptdev/kpt/internal/util/git" kptfilev1 "github.com/kptdev/kpt/pkg/api/kptfile/v1" "github.com/kptdev/kpt/pkg/lib/errors" + "github.com/kptdev/krm-functions-sdk/go/fn" "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/kustomize/kyaml/filesys" "sigs.k8s.io/kustomize/kyaml/sets" @@ -44,6 +46,13 @@ var SupportedKptfileVersions = []schema.GroupVersionKind{ kptfilev1.KptFileGVK(), } +var sdkInternalKptfileAnnotations = []string{ + "config.kubernetes.io/index", + "internal.config.kubernetes.io/index", + "internal.config.kubernetes.io/path", + "internal.config.kubernetes.io/seqindent", +} + // KptfileError records errors regarding reading or parsing of a Kptfile. type KptfileError struct { Path types.UniquePath @@ -78,6 +87,20 @@ func (e *UnknownKptfileResourceError) Error() string { func WriteFile(dir string, k any) error { const op errors.Op = "kptfileutil.WriteFile" + if kf, ok := k.(*kptfilev1.KptFile); ok { + if err := writeKptfilePreservingFormat(dir, kf); err != nil { + return errors.E(op, types.UniquePath(dir), err) + } + return nil + } + + if kf, ok := k.(kptfilev1.KptFile); ok { + if err := writeKptfilePreservingFormat(dir, &kf); err != nil { + return errors.E(op, types.UniquePath(dir), err) + } + return nil + } + b, err := yaml.MarshalWithOptions(k, &yaml.EncoderOptions{SeqIndent: yaml.WideSequenceStyle}) if err != nil { return err @@ -94,6 +117,157 @@ func WriteFile(dir string, k any) error { return nil } +func writeKptfilePreservingFormat(dir string, kf *kptfilev1.KptFile) error { + kptfilePath := filepath.Join(dir, kptfilev1.KptFileName) + if _, err := os.Stat(dir); err != nil { + return err + } + + content, err := os.ReadFile(kptfilePath) + if err != nil { + if goerrors.Is(err, os.ErrNotExist) { + b, marshalErr := yaml.MarshalWithOptions(kf, &yaml.EncoderOptions{SeqIndent: yaml.WideSequenceStyle}) + if marshalErr != nil { + return marshalErr + } + return os.WriteFile(kptfilePath, b, 0600) + } + return err + } + + existingResources := map[string]string{kptfilev1.KptFileName: string(content)} + existingKptfile, err := fn.NewKptfileFromPackage(existingResources) + if err != nil { + return err + } + if err := applyTypedKptfileToSDK(existingKptfile, kf); err != nil { + return err + } + if err := existingKptfile.WriteToPackage(existingResources); err != nil { + return err + } + return os.WriteFile(kptfilePath, []byte(existingResources[kptfilev1.KptFileName]), 0600) +} + +func applyTypedKptfileToSDK(sdkKptfile *fn.Kptfile, desired *kptfilev1.KptFile) error { + if sdkKptfile == nil || sdkKptfile.Obj == nil { + return fmt.Errorf("cannot update empty sdk Kptfile") + } + + if err := sdkKptfile.Obj.SetNestedString(desired.APIVersion, "apiVersion"); err != nil { + return err + } + if err := sdkKptfile.Obj.SetNestedString(desired.Kind, "kind"); err != nil { + return err + } + if err := sdkKptfile.Obj.SetNestedString(desired.Name, "metadata", "name"); err != nil { + return err + } + + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Annotations, "metadata", "annotations"); err != nil { + return err + } + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Labels, "metadata", "labels"); err != nil { + return err + } + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Pipeline, "pipeline"); err != nil { + return err + } + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Info, "info"); err != nil { + return err + } + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Inventory, "inventory"); err != nil { + return err + } + if err := setOrRemoveNestedField(sdkKptfile.Obj, desired.Status, "status"); err != nil { + return err + } + + if err := setOrRemoveUpstream(sdkKptfile.Obj, desired.Upstream); err != nil { + return err + } + + if err := setOrRemoveUpstreamLock(sdkKptfile.Obj, desired.UpstreamLock); err != nil { + return err + } + + return nil +} + +func setOrRemoveNestedField(obj *fn.KubeObject, val any, fields ...string) error { + if val == nil || reflect.ValueOf(val).IsZero() { + _, err := obj.RemoveNestedField(fields...) + return err + } + return obj.SetNestedField(val, fields...) +} + +func setOrRemoveNestedString(obj *fn.KubeObject, value string, fields ...string) error { + if strings.TrimSpace(value) == "" { + _, err := obj.RemoveNestedField(fields...) + return err + } + return obj.SetNestedString(value, fields...) +} + +func setOrRemoveUpstream(obj *fn.KubeObject, upstream *kptfilev1.Upstream) error { + if upstream == nil { + _, err := obj.RemoveNestedField("upstream") + return err + } + + obj.UpsertMap("upstream") + if err := setOrRemoveNestedString(obj, string(upstream.Type), "upstream", "type"); err != nil { + return err + } + if err := setOrRemoveNestedString(obj, string(upstream.UpdateStrategy), "upstream", "updateStrategy"); err != nil { + return err + } + + if upstream.Git == nil { + _, err := obj.RemoveNestedField("upstream", "git") + return err + } + + obj.UpsertMap("upstream").UpsertMap("git") + if err := setOrRemoveNestedString(obj, upstream.Git.Repo, "upstream", "git", "repo"); err != nil { + return err + } + if err := setOrRemoveNestedString(obj, upstream.Git.Directory, "upstream", "git", "directory"); err != nil { + return err + } + return setOrRemoveNestedString(obj, upstream.Git.Ref, "upstream", "git", "ref") +} + +func setOrRemoveUpstreamLock(obj *fn.KubeObject, upstreamLock *kptfilev1.Locator) error { + if upstreamLock == nil { + _, err := obj.RemoveNestedField("upstreamLock") + return err + } + + obj.UpsertMap("upstreamLock") + if err := setOrRemoveNestedString(obj, string(upstreamLock.Type), "upstreamLock", "type"); err != nil { + return err + } + + if upstreamLock.Git == nil { + _, err := obj.RemoveNestedField("upstreamLock", "git") + return err + } + + obj.UpsertMap("upstreamLock").UpsertMap("git") + if err := setOrRemoveNestedString(obj, upstreamLock.Git.Repo, "upstreamLock", "git", "repo"); err != nil { + return err + } + if err := setOrRemoveNestedString(obj, upstreamLock.Git.Directory, "upstreamLock", "git", "directory"); err != nil { + return err + } + if err := setOrRemoveNestedString(obj, upstreamLock.Git.Ref, "upstreamLock", "git", "ref"); err != nil { + return err + } + return setOrRemoveNestedString(obj, upstreamLock.Git.Commit, "upstreamLock", "git", "commit") +} + // ValidateInventory returns true and a nil error if the passed inventory // is valid; otherwiste, false and the reason the inventory is not valid // is returned. A valid inventory must have a non-empty namespace, name, @@ -299,18 +473,88 @@ func DecodeKptfile(in io.Reader) (*kptfilev1.KptFile, error) { if err != nil { return kf, err } - if err := checkKptfileVersion(c); err != nil { + if err := validateKptfileContent(c); err != nil { return kf, err } - d := yaml.NewDecoder(bytes.NewBuffer(c)) - d.KnownFields(true) - if err := d.Decode(kf); err != nil { + kubeObjects, err := fn.ReadKubeObjectsFromFile(kptfilev1.KptFileName, string(c)) + if err != nil { + return kf, err + } + + sdkKptfile, err := fn.NewKptfileFromKubeObjectList(kubeObjects) + if err != nil { return kf, err } + + if err := sdkKptfile.Obj.As(kf); err != nil { + return kf, err + } + + stripSDKInternalKptfileAnnotations(kf) + return kf, nil } +// UpdateKptfileContent updates Kptfile YAML content in-memory using SDK Kptfile +// read/write APIs while preserving existing YAML document structure and comments. +func UpdateKptfileContent(content string, mutator func(*kptfilev1.KptFile)) (string, error) { + if err := validateKptfileContent([]byte(content)); err != nil { + return "", err + } + + resources := map[string]string{kptfilev1.KptFileName: content} + sdkKptfile, err := fn.NewKptfileFromPackage(resources) + if err != nil { + return "", err + } + + typedKptfile := &kptfilev1.KptFile{} + if err := sdkKptfile.Obj.As(typedKptfile); err != nil { + return "", err + } + stripSDKInternalKptfileAnnotations(typedKptfile) + + mutator(typedKptfile) + + if err := applyTypedKptfileToSDK(sdkKptfile, typedKptfile); err != nil { + return "", err + } + + if err := sdkKptfile.WriteToPackage(resources); err != nil { + return "", err + } + + return resources[kptfilev1.KptFileName], nil +} + +func validateKptfileContent(content []byte) error { + if err := checkKptfileVersion(content); err != nil { + return err + } + + d := yaml.NewDecoder(bytes.NewBuffer(content)) + d.KnownFields(true) + if err := d.Decode(&kptfilev1.KptFile{}); err != nil { + return err + } + + return nil +} + +func stripSDKInternalKptfileAnnotations(kf *kptfilev1.KptFile) { + if kf == nil || kf.ObjectMeta.Annotations == nil { + return + } + + for _, key := range sdkInternalKptfileAnnotations { + delete(kf.ObjectMeta.Annotations, key) + } + if len(kf.ObjectMeta.Annotations) == 0 { + kf.ObjectMeta.Annotations = nil + } +} + // checkKptfileVersion verifies the apiVersion and kind of the resource // within the Kptfile. If the legacy version is found, the DeprecatedKptfileError // is returned. If the currently supported apiVersion and kind is found, no diff --git a/pkg/kptfile/kptfileutil/util_test.go b/pkg/kptfile/kptfileutil/util_test.go index 424ef1162..53222a214 100644 --- a/pkg/kptfile/kptfileutil/util_test.go +++ b/pkg/kptfile/kptfileutil/util_test.go @@ -595,11 +595,390 @@ status: t.FailNow() } - assert.Equal(t, strings.TrimSpace(tc.expected)+"\n", string(c)) + expectedObj := map[string]any{} + err = yaml.Unmarshal([]byte(strings.TrimSpace(tc.expected)), &expectedObj) + if !assert.NoError(t, err) { + t.FailNow() + } + + actualObj := map[string]any{} + err = yaml.Unmarshal(c, &actualObj) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, expectedObj, actualObj) + }) + } +} + +func TestUpdateKptfile_PreservesCommentsAndFormatting(t *testing.T) { + writeKptfileToTemp := func(tt *testing.T, content string) string { + dir := tt.TempDir() + err := os.WriteFile(filepath.Join(dir, kptfilev1.KptFileName), []byte(content), 0600) + if !assert.NoError(tt, err) { + tt.FailNow() + } + return dir + } + + originDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 +`) + + updatedDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 + commit: abcdef +`) + + localDir := writeKptfileToTemp(t, ` +# local package level comment +apiVersion: kpt.dev/v1 # api comment +kind: Kptfile +metadata: + name: sample + +# preserve this section comment +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 # keep inline comment +`) + + err := UpdateKptfile(localDir, updatedDir, originDir, true) + if !assert.NoError(t, err) { + t.FailNow() + } + + contentBytes, err := os.ReadFile(filepath.Join(localDir, kptfilev1.KptFileName)) + if !assert.NoError(t, err) { + t.FailNow() + } + content := string(contentBytes) + + assert.Contains(t, content, "# local package level comment") + assert.Contains(t, content, "apiVersion: kpt.dev/v1 # api comment") + assert.Contains(t, content, "# preserve this section comment") + assert.Contains(t, content, "ref: v1.1.0 # keep inline comment") + assert.Contains(t, content, "commit: abcdef") +} + +func TestUpdateKptfile_PreservesExactFormattingAndComments(t *testing.T) { + writeKptfileToTemp := func(tt *testing.T, content string) string { + dir := tt.TempDir() + err := os.WriteFile(filepath.Join(dir, kptfilev1.KptFileName), []byte(content), 0600) + if !assert.NoError(tt, err) { + tt.FailNow() + } + return dir + } + + originDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 + commit: abc123 +`) + + updatedDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 + commit: def456 +`) + + localDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 # keep api inline comment +kind: Kptfile +metadata: + name: sample +# preserve this comment block +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 # keep ref inline comment + +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 + commit: abc123 # keep commit inline comment +`) + + err := UpdateKptfile(localDir, updatedDir, originDir, true) + if !assert.NoError(t, err) { + t.FailNow() + } + + contentBytes, err := os.ReadFile(filepath.Join(localDir, kptfilev1.KptFileName)) + if !assert.NoError(t, err) { + t.FailNow() + } + + want := ` +apiVersion: kpt.dev/v1 # keep api inline comment +kind: Kptfile +metadata: + name: sample +# preserve this comment block +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 # keep ref inline comment +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 + commit: def456 # keep commit inline comment +` + + assert.Equal(t, strings.TrimSpace(want), strings.TrimSpace(string(contentBytes))) +} + +func TestWriteFile_ReturnsErrorWhenDirectoryMissing(t *testing.T) { + nonExistentDir := filepath.Join(t.TempDir(), "does-not-exist") + + err := WriteFile(nonExistentDir, DefaultKptfile("sample")) + assert.Error(t, err) +} + +func TestUpdateKptfile_ReturnsErrorOnInvalidLocalKptfile(t *testing.T) { + writeKptfileToTemp := func(tt *testing.T, content string) string { + dir := tt.TempDir() + err := os.WriteFile(filepath.Join(dir, kptfilev1.KptFileName), []byte(content), 0600) + if !assert.NoError(tt, err) { + tt.FailNow() + } + return dir + } + + originDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +`) + + updatedDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +`) + + localDir := writeKptfileToTemp(t, ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: [bad +`) + + err := UpdateKptfile(localDir, updatedDir, originDir, true) + assert.Error(t, err) +} + +func TestUpdateKptfileContent_UsesDecodeValidation(t *testing.T) { + testCases := map[string]struct { + content string + expectedErr any + expectedDecodeError string + }{ + "deprecated version": { + content: ` +apiVersion: kpt.dev/v1alpha2 +kind: Kptfile +metadata: + name: sample +`, + expectedErr: &DeprecatedKptfileError{}, + expectedDecodeError: "old resource version \"v1alpha2\" found in Kptfile", + }, + "unknown kind": { + content: ` +apiVersion: kpt.dev/v1 +kind: ConfigMap +metadata: + name: sample +`, + expectedErr: &UnknownKptfileResourceError{}, + expectedDecodeError: "unknown resource type \"kpt.dev/v1, Kind=ConfigMap\" found in Kptfile", + }, + "unknown field": { + content: ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +unexpectedField: true +`, + expectedDecodeError: "yaml: unmarshal errors:\n line 6: field unexpectedField not found in type v1.KptFile", + }, + } + + for tn, tc := range testCases { + t.Run(tn, func(t *testing.T) { + _, decodeErr := DecodeKptfile(strings.NewReader(tc.content)) + _, updateErr := UpdateKptfileContent(tc.content, func(*kptfilev1.KptFile) {}) + + if !assert.EqualError(t, decodeErr, tc.expectedDecodeError) { + t.FailNow() + } + if !assert.EqualError(t, updateErr, decodeErr.Error()) { + t.FailNow() + } + if tc.expectedErr != nil { + assert.IsType(t, tc.expectedErr, decodeErr) + assert.IsType(t, tc.expectedErr, updateErr) + } }) } } +func TestUpdateKptfileContent_StripsSDKInternalAnnotations(t *testing.T) { + t.Run("preserves user annotations", func(t *testing.T) { + content := ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample + annotations: + config.kubernetes.io/index: "0" + internal.config.kubernetes.io/path: Kptfile + user.example.com/keep: value +` + + updatedContent, err := UpdateKptfileContent(content, func(kf *kptfilev1.KptFile) { + kf.Name = "updated-sample" + }) + if !assert.NoError(t, err) { + t.FailNow() + } + + updatedKf, err := DecodeKptfile(strings.NewReader(updatedContent)) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, "updated-sample", updatedKf.Name) + if assert.NotNil(t, updatedKf.Annotations) { + assert.Equal(t, "value", updatedKf.Annotations["user.example.com/keep"]) + for _, key := range sdkInternalKptfileAnnotations { + assert.NotContains(t, updatedKf.Annotations, key) + } + } + assert.NotContains(t, updatedContent, "config.kubernetes.io/index") + assert.NotContains(t, updatedContent, "internal.config.kubernetes.io/path") + }) + + t.Run("removes empty annotation map", func(t *testing.T) { + content := ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample + annotations: + config.kubernetes.io/index: "0" + internal.config.kubernetes.io/index: "0" +` + + updatedContent, err := UpdateKptfileContent(content, func(*kptfilev1.KptFile) {}) + if !assert.NoError(t, err) { + t.FailNow() + } + + updatedKf, err := DecodeKptfile(strings.NewReader(updatedContent)) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Nil(t, updatedKf.Annotations) + assert.NotContains(t, updatedContent, "annotations:") + }) + + t.Run("handles missing annotations safely", func(t *testing.T) { + content := ` +apiVersion: kpt.dev/v1 +kind: Kptfile +metadata: + name: sample +` + + updatedContent, err := UpdateKptfileContent(content, func(kf *kptfilev1.KptFile) { + kf.Name = "updated-sample" + }) + if !assert.NoError(t, err) { + t.FailNow() + } + + updatedKf, err := DecodeKptfile(strings.NewReader(updatedContent)) + if !assert.NoError(t, err) { + t.FailNow() + } + + assert.Equal(t, "updated-sample", updatedKf.Name) + assert.Nil(t, updatedKf.Annotations) + }) +} + func TestMerge(t *testing.T) { testCases := map[string]struct { origin string diff --git a/pkg/lib/kptops/clone.go b/pkg/lib/kptops/clone.go index 0cd51c402..f210f48b7 100644 --- a/pkg/lib/kptops/clone.go +++ b/pkg/lib/kptops/clone.go @@ -21,49 +21,35 @@ import ( kptfilev1 "github.com/kptdev/kpt/pkg/api/kptfile/v1" "github.com/kptdev/kpt/pkg/kptfile/kptfileutil" - "sigs.k8s.io/kustomize/kyaml/yaml" ) func UpdateUpstream(kptfileContents string, name string, upstream kptfilev1.Upstream, lock kptfilev1.Locator) (string, error) { - kptfile, err := kptfileutil.DecodeKptfile(strings.NewReader(kptfileContents)) - if err != nil { - return "", fmt.Errorf("cannot parse Kptfile: %w", err) - } - // Normalize the repository URL and directory path normalizeGitFields(&upstream) normalizeGitLockFields(&lock) // Use separate function for lock - // populate the cloneFrom values so we know where the package came from - kptfile.UpstreamLock = &lock - kptfile.Upstream = &upstream - if name != "" { - kptfile.Name = name - } - - b, err := yaml.MarshalWithOptions(kptfile, &yaml.EncoderOptions{SeqIndent: yaml.WideSequenceStyle}) - if err != nil { - return "", fmt.Errorf("cannot save Kptfile: %w", err) - } - - return string(b), nil + return updateKptfileContentsPreservingFormat(kptfileContents, func(kptfile *kptfilev1.KptFile) { + kptfile.UpstreamLock = &lock + kptfile.Upstream = &upstream + if name != "" { + kptfile.Name = name + } + }) } func UpdateName(kptfileContents string, name string) (string, error) { - kptfile, err := kptfileutil.DecodeKptfile(strings.NewReader(kptfileContents)) - if err != nil { - return "", fmt.Errorf("cannot parse Kptfile: %w", err) - } - - // update the name of the package - kptfile.Name = name + return updateKptfileContentsPreservingFormat(kptfileContents, func(kptfile *kptfilev1.KptFile) { + kptfile.Name = name + }) +} - b, err := yaml.MarshalWithOptions(kptfile, &yaml.EncoderOptions{SeqIndent: yaml.WideSequenceStyle}) +func updateKptfileContentsPreservingFormat(kptfileContents string, mutator func(*kptfilev1.KptFile)) (string, error) { + out, err := kptfileutil.UpdateKptfileContent(kptfileContents, mutator) if err != nil { - return "", fmt.Errorf("cannot save Kptfile: %w", err) + return "", fmt.Errorf("cannot update Kptfile: %w", err) } - return string(b), nil + return out, nil } func UpdateKptfileUpstream(name string, contents map[string]string, upstream kptfilev1.Upstream, lock kptfilev1.Locator) error { diff --git a/pkg/lib/kptops/clone_test.go b/pkg/lib/kptops/clone_test.go index d2123bbfb..492893d30 100644 --- a/pkg/lib/kptops/clone_test.go +++ b/pkg/lib/kptops/clone_test.go @@ -15,6 +15,7 @@ package kptops import ( + "strings" "testing" kptfilev1 "github.com/kptdev/kpt/pkg/api/kptfile/v1" @@ -79,3 +80,93 @@ func TestNormalizeGitLockFields(t *testing.T) { t.Errorf("Expected unchanged repo URL, got %q", lock.Git.Repo) } } + +func TestUpdateUpstream_PreservesCommentsAndFormatting(t *testing.T) { + input := ` +apiVersion: kpt.dev/v1 # api inline comment +kind: Kptfile +metadata: + name: sample +# upstream comment +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.0.0 # ref inline comment +` + + upstream := kptfilev1.Upstream{ + Type: kptfilev1.GitOrigin, + Git: &kptfilev1.Git{ + Repo: "https://github.com/example/repo", + Directory: "/package", + Ref: "v1.1.0", + }, + } + + lock := kptfilev1.Locator{ + Type: kptfilev1.GitOrigin, + Git: &kptfilev1.GitLock{ + Repo: "https://github.com/example/repo", + Directory: "/package", + Ref: "v1.1.0", + Commit: "abcdef", + }, + } + + got, err := UpdateUpstream(input, "", upstream, lock) + if err != nil { + t.Fatalf("UpdateUpstream returned error: %v", err) + } + + want := ` +apiVersion: kpt.dev/v1 # api inline comment +kind: Kptfile +metadata: + name: sample +# upstream comment +upstream: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 # ref inline comment +upstreamLock: + type: git + git: + repo: https://github.com/example/repo.git + directory: package + ref: v1.1.0 + commit: abcdef +` + + if strings.TrimSpace(got) != strings.TrimSpace(want) { + t.Fatalf("updated Kptfile mismatch\nwant:\n%s\n\ngot:\n%s", want, got) + } +} + +func TestUpdateName_PreservesCommentsAndFormatting(t *testing.T) { + input := ` +apiVersion: kpt.dev/v1 # api inline comment +kind: Kptfile +metadata: + name: old-name # name inline comment +` + + got, err := UpdateName(input, "new-name") + if err != nil { + t.Fatalf("UpdateName returned error: %v", err) + } + + want := ` +apiVersion: kpt.dev/v1 # api inline comment +kind: Kptfile +metadata: + name: new-name # name inline comment +` + + if strings.TrimSpace(got) != strings.TrimSpace(want) { + t.Fatalf("updated Kptfile mismatch\nwant:\n%s\n\ngot:\n%s", want, got) + } +} diff --git a/pkg/lib/kptops/render_test.go b/pkg/lib/kptops/render_test.go index 8170e236b..cee3916b9 100644 --- a/pkg/lib/kptops/render_test.go +++ b/pkg/lib/kptops/render_test.go @@ -78,10 +78,10 @@ func TestRender(t *testing.T) { t.Errorf("Render failed: %v", err) } - got := output.String() - want := readFile(t, filepath.Join(testdata, test.name, test.want)) + got := strings.ReplaceAll(output.String(), "\r\n", "\n") + want := strings.ReplaceAll(string(readFile(t, filepath.Join(testdata, test.name, test.want))), "\r\n", "\n") - if diff := cmp.Diff(string(want), got); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected result (-want, +got): %s", diff) } })