Skip to content

Commit 5dc674c

Browse files
authored
fix: check enum type recursively (#64)
1 parent 4378967 commit 5dc674c

4 files changed

Lines changed: 139 additions & 2 deletions

File tree

internal/defs/types.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ func (self *Type) Tag() Tag {
127127
}
128128
}
129129

130+
func (self *Type) IsEnum() bool {
131+
switch self.T {
132+
case T_enum:
133+
return true
134+
case T_pointer:
135+
return self.V.IsEnum()
136+
default:
137+
return false
138+
}
139+
}
140+
130141
func (self *Type) Free() {
131142
typePool.Put(self)
132143
}

internal/defs/types_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,25 @@ func TestTypes_MapKeyType(t *testing.T) {
3737
require.NoError(t, err)
3838
fmt.Println(tt)
3939
}
40+
41+
func TestTypes_Enum(t *testing.T) {
42+
type EnumType int64
43+
type Int32 int32
44+
type StructWithEnum struct {
45+
A EnumType `frugal:"1,optional,EnumType"`
46+
B *EnumType `frugal:"2,optional,EnumType"`
47+
C Int32 `frugal:"3,optional,Int32"`
48+
D int64 `frugal:"4,optional,i64"`
49+
}
50+
ff, err := DoResolveFields(reflect.TypeOf(StructWithEnum{}))
51+
require.NoError(t, err)
52+
require.Len(t, ff, 4)
53+
require.True(t, ff[0].Type.IsEnum())
54+
require.Equal(t, ff[0].Type.T, T_enum)
55+
require.True(t, ff[1].Type.IsEnum())
56+
require.Equal(t, ff[1].Type.T, T_pointer)
57+
require.Equal(t, ff[1].Type.V.T, T_enum)
58+
require.False(t, ff[2].Type.IsEnum())
59+
require.False(t, ff[3].Type.IsEnum())
60+
61+
}

internal/reflect/decoder_test.go

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ import (
2727
"github.com/stretchr/testify/require"
2828
)
2929

30-
func TestDecode(t *testing.T) {
30+
func init() {
3131
rand.Seed(time.Now().Unix())
32+
}
3233

34+
func TestDecode(t *testing.T) {
3335
type testcase struct {
3436
name string
3537
update func(p *TestTypes)
@@ -197,6 +199,108 @@ func TestDecode(t *testing.T) {
197199
}
198200
}
199201

202+
func TestDecodeOptional(t *testing.T) {
203+
type testcase struct {
204+
name string
205+
update func(p *TestTypesOptional)
206+
test func(t *testing.T, p1 *TestTypesOptional)
207+
}
208+
209+
var (
210+
vInt16 = int16(rand.Uint32() & 0xffff)
211+
vInt32 = int32(rand.Uint32())
212+
vInt64 = int64(rand.Uint64())
213+
vFloat64 = math.Float64frombits(rand.Uint64())
214+
vTrue = true
215+
vString = "hello"
216+
vByte = int8(0x55)
217+
vEnum = Numberz(int32(rand.Uint32()))
218+
)
219+
220+
for math.IsNaN(vFloat64) { // fix test failure
221+
vFloat64 = math.Float64frombits(rand.Uint64())
222+
}
223+
224+
testcases := []testcase{
225+
{
226+
name: "case_bool",
227+
update: func(p0 *TestTypesOptional) { p0.FBool = &vTrue },
228+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vTrue, *p1.FBool) },
229+
},
230+
{
231+
name: "case_string",
232+
update: func(p0 *TestTypesOptional) { p0.String_ = &vString },
233+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vString, *p1.String_) },
234+
},
235+
{
236+
name: "case_byte",
237+
update: func(p0 *TestTypesOptional) { p0.FByte = &vByte },
238+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.FByte) },
239+
},
240+
{
241+
name: "case_int8",
242+
update: func(p0 *TestTypesOptional) { p0.I8 = &vByte },
243+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.I8) },
244+
},
245+
{
246+
name: "case_int16",
247+
update: func(p0 *TestTypesOptional) { p0.I16 = &vInt16 },
248+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt16, *p1.I16) },
249+
},
250+
{
251+
name: "case_int32",
252+
update: func(p0 *TestTypesOptional) { p0.I32 = &vInt32 },
253+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt32, *p1.I32) },
254+
},
255+
{
256+
name: "case_int64",
257+
update: func(p0 *TestTypesOptional) { p0.I64 = &vInt64 },
258+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.I64) },
259+
},
260+
{
261+
name: "case_float64",
262+
update: func(p0 *TestTypesOptional) { p0.Double = &vFloat64 },
263+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vFloat64, *p1.Double) },
264+
},
265+
{
266+
name: "case_enum",
267+
update: func(p0 *TestTypesOptional) { p0.Enum = &vEnum },
268+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vEnum, *p1.Enum) },
269+
},
270+
{
271+
name: "case_typedef",
272+
update: func(p0 *TestTypesOptional) { p0.UID = &vInt64 },
273+
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.UID) },
274+
},
275+
}
276+
for _, tc := range testcases {
277+
name := tc.name
278+
updatef := tc.update
279+
testf := tc.test
280+
t.Run(name, func(t *testing.T) {
281+
p0 := NewTestTypesOptional()
282+
updatef(p0) // update by testcase func
283+
284+
b := make([]byte, EncodedSize(p0))
285+
n, err := Encode(b, p0)
286+
require.NoError(t, err)
287+
require.Equal(t, len(b), n)
288+
289+
// verify by gopkg thrift
290+
n, err = thrift.Binary.Skip(b, thrift.TType(tSTRUCT))
291+
require.NoError(t, err)
292+
require.Equal(t, n, len(b))
293+
294+
p1 := &TestTypesOptional{}
295+
n, err = Decode(b, p1)
296+
require.NoError(t, err)
297+
require.Equal(t, len(b), n)
298+
299+
testf(t, p1) // test by testcase func
300+
})
301+
}
302+
}
303+
200304
func TestDecodeRequired(t *testing.T) {
201305
type S0 struct {
202306
V *bool `frugal:"1,optional,bool"`

internal/reflect/ttype.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func newTType(x *defs.Type) *tType {
160160
t.T = ttype(x.Tag())
161161
t.WT = t.T
162162
t.Tag = x.T
163-
if t.Tag == defs.T_enum {
163+
if x.IsEnum() {
164164
t.T = tENUM
165165
}
166166
t.RT = x.S

0 commit comments

Comments
 (0)