Skip to content

Commit 3faa0cd

Browse files
alecsammonmweibel
authored andcommitted
Fix pointer marshal
1 parent dc3eac4 commit 3faa0cd

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

sheriff.go

+11
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,21 @@ func marshalValue(options *Options, v reflect.Value) (interface{}, error) {
282282
// types which are e.g. structs, slices or maps and implement one of the following interfaces should not be
283283
// marshalled by sheriff because they'll be correctly marshalled by json.Marshal instead.
284284
// Otherwise (e.g. net.IP) a byte slice may be output as a list of uints instead of as an IP string.
285+
// This needs to be checked for both value and pointer types.
285286
switch val.(type) {
286287
case json.Marshaler, encoding.TextMarshaler, fmt.Stringer:
287288
return val, nil
288289
}
290+
291+
if v.CanAddr() {
292+
addrVal := v.Addr().Interface()
293+
294+
switch addrVal.(type) {
295+
case json.Marshaler, encoding.TextMarshaler, fmt.Stringer:
296+
return addrVal, nil
297+
}
298+
}
299+
289300
k := v.Kind()
290301

291302
switch k {

sheriff_test.go

+38-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sheriff
22

33
import (
44
"encoding/json"
5+
"fmt"
56
"net"
67
"reflect"
78
"testing"
@@ -579,14 +580,46 @@ type TestMarshal_Embedded struct {
579580
Foo string `json:"foo" groups:"test"`
580581
}
581582

583+
// TestMarshal_EmbeddedCustom is used to test an embedded struct with a custom marshaler that is not a pointer.
584+
type TestMarshal_EmbeddedCustom struct {
585+
Val int
586+
Set bool
587+
}
588+
589+
func (t TestMarshal_EmbeddedCustom) MarshalJSON() ([]byte, error) {
590+
if t.Set {
591+
return []byte(fmt.Sprintf("%d", t.Val)), nil
592+
}
593+
594+
return nil, nil
595+
}
596+
597+
// TestMarshal_EmbeddedCustomPtr is used to test an embedded struct with a custom marshaler that is a pointer.
598+
type TestMarshal_EmbeddedCustomPtr struct {
599+
Val int
600+
Set bool
601+
}
602+
603+
func (t *TestMarshal_EmbeddedCustomPtr) MarshalJSON() ([]byte, error) {
604+
if t.Set {
605+
return []byte(fmt.Sprintf("%d", t.Val)), nil
606+
}
607+
608+
return nil, nil
609+
}
610+
582611
type TestMarshal_EmbeddedParent struct {
583612
*TestMarshal_Embedded
584-
Bar string `json:"bar" groups:"test"`
613+
*TestMarshal_EmbeddedCustom `json:"value"`
614+
*TestMarshal_EmbeddedCustomPtr `json:"value_ptr"`
615+
Bar string `json:"bar" groups:"test"`
585616
}
586617

587618
func TestMarshal_EmbeddedField(t *testing.T) {
588619
v := TestMarshal_EmbeddedParent{
589620
&TestMarshal_Embedded{"Hello"},
621+
&TestMarshal_EmbeddedCustom{10, true},
622+
&TestMarshal_EmbeddedCustomPtr{20, true},
590623
"World",
591624
}
592625
o := &Options{Groups: []string{"test"}}
@@ -598,8 +631,10 @@ func TestMarshal_EmbeddedField(t *testing.T) {
598631
assert.NoError(t, err)
599632

600633
expected, err := json.Marshal(map[string]interface{}{
601-
"bar": "World",
602-
"foo": "Hello",
634+
"bar": "World",
635+
"foo": "Hello",
636+
"value": 10,
637+
"value_ptr": 20,
603638
})
604639
assert.NoError(t, err)
605640

0 commit comments

Comments
 (0)