Skip to content

Commit d225e24

Browse files
authored
fix comment map (#635)
1 parent c8cc5c5 commit d225e24

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

decode.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/base64"
88
"fmt"
99
"io"
10+
"maps"
1011
"math"
1112
"os"
1213
"path/filepath"
@@ -30,6 +31,7 @@ type Decoder struct {
3031
aliasValueMap map[*ast.AliasNode]any
3132
anchorValueMap map[string]reflect.Value
3233
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
34+
commentMaps []CommentMap
3335
toCommentMap CommentMap
3436
opts []DecodeOption
3537
referenceFiles []string
@@ -1957,6 +1959,12 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
19571959
if v != nil {
19581960
normalizedFile.Docs = append(normalizedFile.Docs, doc)
19591961
}
1962+
cm := CommentMap{}
1963+
maps.Copy(cm, d.toCommentMap)
1964+
d.commentMaps = append(d.commentMaps, cm)
1965+
for k := range d.toCommentMap {
1966+
delete(d.toCommentMap, k)
1967+
}
19601968
}
19611969
return normalizedFile, nil
19621970
}
@@ -1980,9 +1988,6 @@ func (d *Decoder) decodeInit() error {
19801988
return err
19811989
}
19821990
d.parsedFile = file
1983-
for k := range d.toCommentMap {
1984-
delete(d.toCommentMap, k)
1985-
}
19861991
return nil
19871992
}
19881993

@@ -1995,6 +2000,9 @@ func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
19952000
if body == nil {
19962001
return nil
19972002
}
2003+
if len(d.commentMaps) > d.streamIndex {
2004+
maps.Copy(d.toCommentMap, d.commentMaps[d.streamIndex])
2005+
}
19982006
if err := d.decodeValue(ctx, v.Elem(), body); err != nil {
19992007
return err
20002008
}

yaml_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package yaml_test
2+
3+
import (
4+
"io"
5+
"reflect"
6+
"strings"
7+
"testing"
8+
9+
"github.com/goccy/go-yaml"
10+
)
11+
12+
func TestRoundTripWithComment(t *testing.T) {
13+
yml := `
14+
# head comment
15+
key: value # line comment
16+
`
17+
var v struct {
18+
Key string
19+
}
20+
comments := yaml.CommentMap{}
21+
22+
if err := yaml.UnmarshalWithOptions([]byte(yml), &v, yaml.Strict(), yaml.CommentToMap(comments)); err != nil {
23+
t.Fatal(err)
24+
}
25+
out, err := yaml.MarshalWithOptions(v, yaml.WithComment(comments))
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
got := "\n" + string(out)
30+
if yml != got {
31+
t.Fatalf("failed to get round tripped yaml: %s", got)
32+
}
33+
}
34+
35+
func TestStreamDecodingWithComment(t *testing.T) {
36+
yml := `
37+
a:
38+
b:
39+
c: # comment
40+
---
41+
foo: bar # comment
42+
---
43+
- a
44+
- b
45+
- c # comment
46+
`
47+
cm := yaml.CommentMap{}
48+
dec := yaml.NewDecoder(strings.NewReader(yml), yaml.CommentToMap(cm))
49+
var commentPathsWithDocIndex [][]string
50+
for {
51+
var v any
52+
if err := dec.Decode(&v); err != nil {
53+
if err == io.EOF {
54+
break
55+
}
56+
t.Fatal(err)
57+
}
58+
paths := make([]string, 0, len(cm))
59+
for k := range cm {
60+
paths = append(paths, k)
61+
}
62+
commentPathsWithDocIndex = append(commentPathsWithDocIndex, paths)
63+
for k := range cm {
64+
delete(cm, k)
65+
}
66+
}
67+
if !reflect.DeepEqual(commentPathsWithDocIndex, [][]string{
68+
{"$.a.b.c"},
69+
{"$.foo"},
70+
{"$[2]"},
71+
}) {
72+
t.Fatalf("failed to get comment: %v", commentPathsWithDocIndex)
73+
}
74+
}

0 commit comments

Comments
 (0)