Skip to content

Add MarshalBinary/UnmarshalBinary interface support #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: ismail/any_amino
Choose a base branch
from
50 changes: 50 additions & 0 deletions amino.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package amino

import (
"bytes"
"encoding"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -212,6 +213,26 @@ func (cdc *Codec) MarshalBinaryBare(o interface{}) ([]byte, error) {
if err != nil {
return nil, err
}

if rv.Type().Implements(binaryMarshalerType) {
if info.Registered {
pb := info.Prefix.Bytes()
buf.Write(pb)
}

bz, err := rv.Interface().(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
return nil, err
}

_, err = buf.Write(bz)
if err != nil {
return nil, err
}

return buf.Bytes(), nil
}

// in the case of of a repeated struct (e.g. type Alias []SomeStruct),
// we do not need to prepend with `(field_number << 3) | wire_type` as this
// would need to be done for each struct and not only for the first.
Expand Down Expand Up @@ -398,6 +419,35 @@ func (cdc *Codec) UnmarshalBinaryBare(bz []byte, ptr interface{}) error {
return err
}

if rv.CanAddr() {
addr := rv.Addr()
if addr.Type().Implements(binaryUnmarshalerType) {
bz2 := bz

if info.Registered {
pb := info.Prefix.Bytes()
l := len(pb)
if len(bz) < l {
return fmt.Errorf(
"unmarshalBinaryBare expected to read prefix bytes %X (since it is registered concrete) but got %X",
pb, bz,
)
} else {
pb2 := bz[:l]
bz2 = bz[l:]
if !bytes.Equal(pb2, pb) {
return fmt.Errorf(
"unmarshalBinaryBare expected to read prefix bytes %X (since it is registered concrete) but got %X",
pb, pb2,
)
}
}
}

return addr.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(bz2)
}
}

// If registered concrete, consume and verify prefix bytes.
if info.Registered {
aminoAny := &RegisteredAny{}
Expand Down
44 changes: 44 additions & 0 deletions binary_encode_override_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package amino_test

import (
"github.com/stretchr/testify/assert"
amino "github.com/tendermint/go-amino"
"testing"
)

type Thing struct {
Name string
}

// func (thing Thing) MarshalAmino() (string, error) {
// return thing.Name, nil
// }
//
// func (thing Thing) UnmarshalAmino(name string) error {
// thing.Name = name
// return nil
// }

func (thing Thing) MarshalBinary() ([]byte, error) {
return []byte(thing.Name), nil
}

func (thing *Thing) UnmarshalBinary(bz []byte) error {
thing.Name = string(bz)
return nil
}

func TestMarshalBinaryOverride(t *testing.T) {
var cdc = amino.NewCodec()
cdc.RegisterConcrete(&Thing{}, "amino/thing", nil)

thing1 := Thing{Name: "a"}

bz, err := cdc.MarshalBinaryBare(thing1)
assert.Nil(t, err)

var thing2 Thing
err = cdc.UnmarshalBinaryBare(bz, &thing2)
assert.Nil(t, err)
assert.Equal(t, thing1, thing2)
}
11 changes: 7 additions & 4 deletions reflect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package amino

import (
"encoding"
"encoding/json"
"fmt"
"reflect"
Expand All @@ -13,10 +14,12 @@ import (
const printLog = false

var (
timeType = reflect.TypeOf(time.Time{})
jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem()
jsonUnmarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem()
errorType = reflect.TypeOf(new(error)).Elem()
timeType = reflect.TypeOf(time.Time{})
jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem()
jsonUnmarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem()
binaryMarshalerType = reflect.TypeOf(new(encoding.BinaryMarshaler)).Elem()
binaryUnmarshalerType = reflect.TypeOf(new(encoding.BinaryUnmarshaler)).Elem()
errorType = reflect.TypeOf(new(error)).Elem()
)

//----------------------------------------
Expand Down