diff --git a/Makefile b/Makefile index cc5ebba..dc5a4a2 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,7 @@ generate: build ./tests/unknown_fields.go \ ./tests/type_declaration.go \ ./tests/members_escaped.go \ + ./tests/interface_type.go \ ./tests/intern.go \ ./tests/nocopy.go \ ./tests/escaping.go \ diff --git a/gen/encoder.go b/gen/encoder.go index 22db5e9..e2a6665 100644 --- a/gen/encoder.go +++ b/gen/encoder.go @@ -264,23 +264,39 @@ func (g *Generator) genTypeEncoderNoCheck(t reflect.Type, in string, tags fieldT case reflect.Interface: if t.NumMethod() != 0 { if g.interfaceIsEasyjsonMarshaller(t) { - fmt.Fprintln(g.out, ws+in+".MarshalEasyJSON(out)") + fmt.Fprintln(g.out, ws+" if easyjson.IsNilInterface("+in+") {") + fmt.Fprintln(g.out, ws+" out.RawString(`null`)") + fmt.Fprintln(g.out, ws+" } else {") + fmt.Fprintln(g.out, ws+" "+in+".MarshalEasyJSON(out)") + fmt.Fprintln(g.out, ws+" }") } else if g.interfaceIsJSONMarshaller(t) { fmt.Fprintln(g.out, ws+"if m, ok := "+in+".(easyjson.Marshaler); ok {") - fmt.Fprintln(g.out, ws+" m.MarshalEasyJSON(out)") + fmt.Fprintln(g.out, ws+" if easyjson.IsNilInterface("+in+") {") + fmt.Fprintln(g.out, ws+" out.RawString(`null`)") + fmt.Fprintln(g.out, ws+" } else {") + fmt.Fprintln(g.out, ws+" m.MarshalEasyJSON(out)") + fmt.Fprintln(g.out, ws+" }") fmt.Fprintln(g.out, ws+"} else {") - fmt.Fprintln(g.out, ws+in+".MarshalJSON()") + fmt.Fprintln(g.out, ws+" if easyjson.IsNilInterface("+in+") {") + fmt.Fprintln(g.out, ws+" out.RawString(`null`)") + fmt.Fprintln(g.out, ws+" } else {") + fmt.Fprintln(g.out, ws+" "+in+".MarshalJSON()") + fmt.Fprintln(g.out, ws+" }") fmt.Fprintln(g.out, ws+"}") } else { return fmt.Errorf("interface type %v not supported: only interface{} and interfaces that implement json or easyjson Marshaling are allowed", t) } } else { - fmt.Fprintln(g.out, ws+"if m, ok := "+in+".(easyjson.Marshaler); ok {") - fmt.Fprintln(g.out, ws+" m.MarshalEasyJSON(out)") - fmt.Fprintln(g.out, ws+"} else if m, ok := "+in+".(json.Marshaler); ok {") - fmt.Fprintln(g.out, ws+" out.Raw(m.MarshalJSON())") - fmt.Fprintln(g.out, ws+"} else {") - fmt.Fprintln(g.out, ws+" out.Raw(json.Marshal("+in+"))") + fmt.Fprintln(g.out, ws+"if easyjson.IsNilInterface("+in+") {") + fmt.Fprintln(g.out, ws+" out.RawString(`null`)") + fmt.Fprintln(g.out, ws+" } else {") + fmt.Fprintln(g.out, ws+" if m, ok := "+in+".(easyjson.Marshaler); ok {") + fmt.Fprintln(g.out, ws+" m.MarshalEasyJSON(out)") + fmt.Fprintln(g.out, ws+" } else if m, ok := "+in+".(json.Marshaler); ok {") + fmt.Fprintln(g.out, ws+" out.Raw(m.MarshalJSON())") + fmt.Fprintln(g.out, ws+" } else {") + fmt.Fprintln(g.out, ws+" out.Raw(json.Marshal("+in+"))") + fmt.Fprintln(g.out, ws+" }") fmt.Fprintln(g.out, ws+"}") } default: @@ -419,7 +435,6 @@ func (g *Generator) genStructEncoder(t reflect.Type) error { firstCondition := true for i, f := range fs { firstCondition, err = g.genStructFieldEncoder(t, f, i == 0, firstCondition) - if err != nil { return err } diff --git a/helpers.go b/helpers.go index efe34bf..2f607cd 100644 --- a/helpers.go +++ b/helpers.go @@ -3,7 +3,6 @@ package easyjson import ( "io" - "io/ioutil" "net/http" "strconv" "unsafe" @@ -43,14 +42,14 @@ type UnknownsMarshaler interface { MarshalUnknowns(w *jwriter.Writer, first bool) } -func isNilInterface(i interface{}) bool { +func IsNilInterface(i interface{}) bool { return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0 } // Marshal returns data as a single byte slice. Method is suboptimal as the data is likely to be copied // from a chain of smaller chunks. func Marshal(v Marshaler) ([]byte, error) { - if isNilInterface(v) { + if IsNilInterface(v) { return nullBytes, nil } @@ -61,7 +60,7 @@ func Marshal(v Marshaler) ([]byte, error) { // MarshalToWriter marshals the data to an io.Writer. func MarshalToWriter(v Marshaler, w io.Writer) (written int, err error) { - if isNilInterface(v) { + if IsNilInterface(v) { return w.Write(nullBytes) } @@ -75,7 +74,7 @@ func MarshalToWriter(v Marshaler, w io.Writer) (written int, err error) { // false if an error occurred before any http.ResponseWriter methods were actually // invoked (in this case a 500 reply is possible). func MarshalToHTTPResponseWriter(v Marshaler, w http.ResponseWriter) (started bool, written int, err error) { - if isNilInterface(v) { + if IsNilInterface(v) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Length", strconv.Itoa(len(nullBytes))) written, err = w.Write(nullBytes) @@ -104,7 +103,7 @@ func Unmarshal(data []byte, v Unmarshaler) error { // UnmarshalFromReader reads all the data in the reader and decodes as JSON into the object. func UnmarshalFromReader(r io.Reader, v Unmarshaler) error { - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { return err } diff --git a/helpers_test.go b/helpers_test.go index a10a46c..e37a341 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -5,7 +5,7 @@ import "testing" func BenchmarkNilCheck(b *testing.B) { var a *int for i := 0; i < b.N; i++ { - if !isNilInterface(a) { + if !IsNilInterface(a) { b.Fatal("expected it to be nil") } } diff --git a/tests/interface_test.go b/tests/interface_test.go new file mode 100644 index 0000000..96c8678 --- /dev/null +++ b/tests/interface_test.go @@ -0,0 +1,72 @@ +package tests + +import ( + "bytes" + "testing" +) + +func TestInterfaceMarshal(t *testing.T) { + inner := InterfaceType{ + Field1: 1, + } + wrapper := WrapperType{ + Inner: inner, + } + + data, err := wrapper.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + expected := []byte(`{"Inner":{"Field1":1}}`) + if !bytes.Equal(data, expected) { + t.Fatalf("MarshalJSON failed: got=%s, expected=%s", string(data), string(expected)) + } +} + +func TestInterfaceUnmarshal(t *testing.T) { + data := []byte(`{"Inner":{"Field1":1}}`) + + var inner InterfaceType + wrapper := WrapperType{Inner: &inner} + err := wrapper.UnmarshalJSON(data) + if err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + + if inner.Field1 != 1 { + t.Fatalf("UnmarshalJSON failed: got=%d, expected=%d", inner.Field1, 1) + } +} + +func TestInterfaceMarshalNil(t *testing.T) { + var inner *InterfaceType + wrapper := WrapperType{ + Inner: inner, + } + + data, err := wrapper.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + expected := []byte(`{"Inner":null}`) + if !bytes.Equal(data, expected) { + t.Fatalf("MarshalJSON failed: got=%s, expected=%s", string(data), string(expected)) + } +} + +func TestInterfaceUnmarshalNil(t *testing.T) { + data := []byte(`{"Inner":null}`) + + var inner InterfaceType + wrapper := WrapperType{Inner: &inner} + err := wrapper.UnmarshalJSON(data) + if err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + + if inner.Field1 != 0 { + t.Fatalf("UnmarshalJSON failed: got=%d, expected=0", inner.Field1) + } +} diff --git a/tests/interface_type.go b/tests/interface_type.go new file mode 100644 index 0000000..f884be8 --- /dev/null +++ b/tests/interface_type.go @@ -0,0 +1,11 @@ +package tests + +// easyjson:json +type WrapperType struct { + Inner any `json:"Inner"` +} + +// easyjson:json +type InterfaceType struct { + Field1 int `json:"Field1"` +}