Skip to content

Commit ceb94a1

Browse files
committed
auto anchor
1 parent 9cbf5d4 commit ceb94a1

File tree

3 files changed

+71
-5
lines changed

3 files changed

+71
-5
lines changed

decode.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ type Decoder struct {
2727
reader io.Reader
2828
referenceReaders []io.Reader
2929
anchorNodeMap map[string]ast.Node
30-
aliasValueMap map[*ast.AliasNode]any
30+
aliasValueMap map[string]any
3131
anchorValueMap map[string]reflect.Value
3232
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
3333
toCommentMap CommentMap
@@ -51,7 +51,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
5151
return &Decoder{
5252
reader: r,
5353
anchorNodeMap: map[string]ast.Node{},
54-
aliasValueMap: make(map[*ast.AliasNode]any),
54+
aliasValueMap: make(map[string]any),
5555
anchorValueMap: map[string]reflect.Value{},
5656
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
5757
opts: opts,
@@ -447,13 +447,18 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
447447
return nil, err
448448
}
449449
d.anchorNodeMap[anchorName] = n.Value
450+
d.anchorValueMap[anchorName] = reflect.ValueOf(anchorValue)
450451
return anchorValue, nil
451452
case *ast.AliasNode:
452-
if v, exists := d.aliasValueMap[n]; exists {
453+
if v, exists := d.anchorValueMap[n.Value.String()]; exists {
454+
return v.Interface(), nil
455+
}
456+
text := n.String()
457+
if v, exists := d.aliasValueMap[text]; exists {
453458
return v, nil
454459
}
455460
// To handle the case where alias is processed recursively, the result of alias can be set to nil in advance.
456-
d.aliasValueMap[n] = nil
461+
d.aliasValueMap[text] = nil
457462

458463
aliasName := n.Value.GetToken().Value
459464
node, exists := d.anchorNodeMap[aliasName]
@@ -465,7 +470,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
465470
return nil, err
466471
}
467472
// once the correct alias value is obtained, overwrite with that value.
468-
d.aliasValueMap[n] = aliasValue
473+
d.aliasValueMap[text] = aliasValue
469474
return aliasValue, nil
470475
case *ast.LiteralNode:
471476
return n.Value.GetValue(), nil
@@ -1985,6 +1990,7 @@ func (d *Decoder) decodeInit() error {
19851990

19861991
func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
19871992
d.decodeDepth = 0
1993+
d.aliasValueMap = make(map[string]any)
19881994
if len(d.parsedFile.Docs) <= d.streamIndex {
19891995
return io.EOF
19901996
}

encode.go

+53
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ type Encoder struct {
3434
isFlowStyle bool
3535
isJSONStyle bool
3636
useJSONMarshaler bool
37+
useAutoAnchor bool
3738
anchorCallback func(*ast.AnchorNode, interface{}) error
3839
anchorPtrToNameMap map[uintptr]string
40+
anchorNameRefMap map[string]struct{}
3941
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
4042
useLiteralStyleIfMultiline bool
4143
commentMap map[*Path][]*Comment
@@ -56,6 +58,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
5658
opts: opts,
5759
indent: DefaultIndentSpaces,
5860
anchorPtrToNameMap: map[uintptr]string{},
61+
anchorNameRefMap: make(map[string]struct{}),
5962
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
6063
line: 1,
6164
column: 1,
@@ -111,6 +114,10 @@ func (e *Encoder) EncodeToNodeContext(ctx context.Context, v interface{}) (ast.N
111114
return nil, err
112115
}
113116
}
117+
if _, err := e.encodeValue(ctx, reflect.ValueOf(v), 1); err != nil {
118+
return nil, err
119+
}
120+
e.anchorPtrToNameMap = make(map[uintptr]string)
114121
node, err := e.encodeValue(ctx, reflect.ValueOf(v), 1)
115122
if err != nil {
116123
return nil, err
@@ -448,6 +455,7 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
448455
case reflect.Ptr:
449456
anchorName := e.anchorPtrToNameMap[v.Pointer()]
450457
if anchorName != "" {
458+
e.anchorNameRefMap[anchorName] = struct{}{}
451459
aliasName := anchorName
452460
alias := ast.Alias(token.New("*", "*", e.pos(column)))
453461
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
@@ -464,6 +472,14 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
464472
if mapSlice, ok := v.Interface().(MapSlice); ok {
465473
return e.encodeMapSlice(ctx, mapSlice, column)
466474
}
475+
anchorName := e.anchorPtrToNameMap[v.Pointer()]
476+
if anchorName != "" {
477+
e.anchorNameRefMap[anchorName] = struct{}{}
478+
aliasName := anchorName
479+
alias := ast.Alias(token.New("*", "*", e.pos(column)))
480+
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
481+
return alias, nil
482+
}
467483
return e.encodeSlice(ctx, v)
468484
case reflect.Array:
469485
return e.encodeArray(ctx, v)
@@ -478,6 +494,13 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
478494
}
479495
return e.encodeStruct(ctx, v, column)
480496
case reflect.Map:
497+
anchorName := e.anchorPtrToNameMap[v.Pointer()]
498+
if anchorName != "" {
499+
aliasName := anchorName
500+
alias := ast.Alias(token.New("*", "*", e.pos(column)))
501+
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
502+
return alias, nil
503+
}
481504
return e.encodeMap(ctx, v, column), nil
482505
default:
483506
return nil, fmt.Errorf("unknown value type %s", v.Type().String())
@@ -662,11 +685,21 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int
662685
if e.isMapNode(value) {
663686
value.AddColumn(e.indent)
664687
}
688+
if _, exists := e.anchorNameRefMap[fmt.Sprint(key)]; exists {
689+
anchorName := fmt.Sprint(key)
690+
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
691+
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
692+
anchorNode.Value = value
693+
value = anchorNode
694+
}
665695
node.Values = append(node.Values, ast.MappingValue(
666696
nil,
667697
e.encodeString(fmt.Sprint(key), column),
668698
value,
669699
))
700+
if ptr := e.toPointer(v); ptr != 0 {
701+
e.anchorPtrToNameMap[ptr] = fmt.Sprint(key)
702+
}
670703
}
671704
return node
672705
}
@@ -868,3 +901,23 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column
868901
}
869902
return node, nil
870903
}
904+
905+
func (e *Encoder) toPointer(v reflect.Value) uintptr {
906+
if e.isInvalidValue(v) {
907+
return 0
908+
}
909+
910+
switch v.Type().Kind() {
911+
case reflect.Ptr:
912+
return v.Pointer()
913+
case reflect.Interface:
914+
return e.toPointer(v.Elem())
915+
case reflect.Slice:
916+
return v.Pointer()
917+
case reflect.Array:
918+
return v.Pointer()
919+
case reflect.Map:
920+
return v.Pointer()
921+
}
922+
return 0
923+
}

option.go

+7
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ func Flow(isFlowStyle bool) EncodeOption {
143143
}
144144
}
145145

146+
func UseAutoAnchor() EncodeOption {
147+
return func(e *Encoder) error {
148+
e.useAutoAnchor = true
149+
return nil
150+
}
151+
}
152+
146153
// UseLiteralStyleIfMultiline causes encoding multiline strings with a literal syntax,
147154
// no matter what characters they include
148155
func UseLiteralStyleIfMultiline(useLiteralStyleIfMultiline bool) EncodeOption {

0 commit comments

Comments
 (0)