Skip to content

Commit 47fff9b

Browse files
authored
fix(binding): use UnmarshalText for types like enum (#1359)
1 parent 5513927 commit 47fff9b

File tree

4 files changed

+275
-8
lines changed

4 files changed

+275
-8
lines changed

pkg/app/server/binding/binder_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
package binding
4242

4343
import (
44+
"encoding"
4445
"encoding/json"
46+
"errors"
4547
"fmt"
4648
"mime/multipart"
4749
"net/url"
@@ -1665,6 +1667,41 @@ func TestBind_NormalizeContentType(t *testing.T) {
16651667
assert.DeepEqual(t, "version", result.Version)
16661668
}
16671669

1670+
type TestEnumType int32
1671+
1672+
var _ encoding.TextUnmarshaler = (*TestEnumType)(nil)
1673+
1674+
func (p *TestEnumType) UnmarshalText(v []byte) error {
1675+
switch string(v) {
1676+
case "one":
1677+
*p = 1
1678+
case "two":
1679+
*p = 2
1680+
default:
1681+
return errors.New("invalid")
1682+
}
1683+
return nil
1684+
}
1685+
1686+
func TestBind_TextUnmarshaler(t *testing.T) {
1687+
type Query struct {
1688+
A TestEnumType `query:"a"`
1689+
B TestEnumType `query:"b"`
1690+
C *TestEnumType `query:"c"`
1691+
D *TestEnumType `query:"d"`
1692+
}
1693+
q := &Query{}
1694+
req := newMockRequest().SetRequestURI("http://example.com?a=1&b=one&c=2&d=two")
1695+
err := DefaultBinder().BindQuery(req.Req, q)
1696+
assert.Nil(t, err)
1697+
assert.DeepEqual(t, TestEnumType(1), q.A)
1698+
assert.DeepEqual(t, TestEnumType(1), q.B)
1699+
assert.NotNil(t, q.C)
1700+
assert.NotNil(t, q.D)
1701+
assert.DeepEqual(t, TestEnumType(2), *q.C)
1702+
assert.DeepEqual(t, TestEnumType(2), *q.D)
1703+
}
1704+
16681705
func Benchmark_Binding(b *testing.B) {
16691706
type Req struct {
16701707
Version string `path:"v"`

pkg/app/server/binding/internal/decoder/base_type_decoder.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
123123
}
124124

125125
// Non-pointer elems
126+
if field.CanAddr() {
127+
if tryTextUnmarshaler(field.Addr(), text) {
128+
return nil
129+
}
130+
}
126131
err = d.decoder.UnmarshalString(text, field, d.config.LooseZeroMode)
127132
if err != nil {
128133
return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err)

pkg/app/server/binding/internal/decoder/util.go

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package decoder
1818

1919
import (
20+
"encoding"
2021
"fmt"
2122
"reflect"
2223
"strings"
@@ -45,32 +46,50 @@ func toDefaultValue(typ reflect.Type, defaultValue string) string {
4546
}
4647

4748
// stringToValue is used to dynamically create reflect.Value for 'text'
48-
func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) {
49-
v = reflect.New(elemType).Elem()
49+
func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (reflect.Value, error) {
5050
if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist {
5151
val, err := customizedFunc(req, params, text)
5252
if err != nil {
5353
return reflect.Value{}, err
5454
}
5555
return val, nil
5656
}
57+
v := reflect.New(elemType)
58+
if tryTextUnmarshaler(v, text) {
59+
return v.Elem(), nil
60+
}
5761
switch elemType.Kind() {
58-
case reflect.Struct:
59-
err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface())
60-
case reflect.Map:
61-
err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface())
62+
case reflect.Struct, reflect.Map:
63+
if err := hJson.Unmarshal(bytesconv.S2b(text), v.Interface()); err != nil {
64+
return reflect.Value{}, err
65+
}
66+
return v.Elem(), nil
67+
6268
case reflect.Array, reflect.Slice:
6369
// do nothing
70+
return v.Elem(), nil
71+
6472
default:
6573
decoder, err := SelectTextDecoder(elemType)
6674
if err != nil {
67-
return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String())
75+
return reflect.Value{}, err
6876
}
77+
v = v.Elem()
6978
err = decoder.UnmarshalString(text, v, config.LooseZeroMode)
7079
if err != nil {
7180
return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err)
7281
}
82+
return v, nil
7383
}
7484

75-
return v, err
85+
}
86+
87+
func tryTextUnmarshaler(v reflect.Value, s string) bool {
88+
enc, ok := v.Interface().(encoding.TextUnmarshaler)
89+
if ok {
90+
if err := enc.UnmarshalText(bytesconv.S2b(s)); err == nil {
91+
return true
92+
}
93+
}
94+
return false
7695
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package decoder
18+
19+
import (
20+
"encoding"
21+
"errors"
22+
"reflect"
23+
"testing"
24+
25+
"github.com/cloudwego/hertz/pkg/common/test/assert"
26+
"github.com/cloudwego/hertz/pkg/protocol"
27+
"github.com/cloudwego/hertz/pkg/route/param"
28+
)
29+
30+
type testTextUnmarshaler struct {
31+
Value string
32+
}
33+
34+
func (t *testTextUnmarshaler) UnmarshalText(text []byte) error {
35+
t.Value = string(text)
36+
return nil
37+
}
38+
39+
var _ encoding.TextUnmarshaler = (*testTextUnmarshaler)(nil)
40+
41+
func TestStringToValue(t *testing.T) {
42+
tests := []struct {
43+
name string
44+
elemType reflect.Type
45+
text string
46+
config *DecodeConfig
47+
expectValue interface{}
48+
expectError bool
49+
}{
50+
{
51+
name: "string type",
52+
elemType: reflect.TypeOf(""),
53+
text: "test string",
54+
expectValue: "test string",
55+
},
56+
{
57+
name: "int type",
58+
elemType: reflect.TypeOf(0),
59+
text: "42",
60+
expectValue: 42,
61+
},
62+
{
63+
name: "bool type",
64+
elemType: reflect.TypeOf(false),
65+
text: "true",
66+
expectValue: true,
67+
},
68+
{
69+
name: "float type",
70+
elemType: reflect.TypeOf(0.0),
71+
text: "3.14",
72+
expectValue: 3.14,
73+
},
74+
{
75+
name: "text unmarshaler",
76+
elemType: reflect.TypeOf(testTextUnmarshaler{}),
77+
text: "custom text",
78+
expectValue: testTextUnmarshaler{Value: "custom text"},
79+
},
80+
{
81+
name: "invalid int",
82+
elemType: reflect.TypeOf(0),
83+
text: "not an int",
84+
expectError: true,
85+
},
86+
{
87+
name: "struct type",
88+
elemType: reflect.TypeOf(struct{ Name string }{}),
89+
text: `{"Name":"test"}`,
90+
expectValue: struct{ Name string }{Name: "test"},
91+
},
92+
{
93+
name: "struct type err",
94+
elemType: reflect.TypeOf(struct{ Name string }{}),
95+
text: `{"Name":1}`,
96+
expectError: true,
97+
},
98+
{
99+
name: "list type",
100+
elemType: reflect.TypeOf([]int{}),
101+
expectValue: *new([]int),
102+
},
103+
{
104+
name: "map type",
105+
elemType: reflect.TypeOf(map[string]interface{}{}),
106+
text: `{"key":"value"}`,
107+
expectValue: map[string]interface{}{"key": "value"},
108+
},
109+
{
110+
name: "unsupported type",
111+
elemType: reflect.TypeOf(complex64(0)),
112+
expectError: true,
113+
},
114+
{
115+
name: "custom type unmarshal func",
116+
elemType: reflect.TypeOf(testTextUnmarshaler{}),
117+
text: "custom func",
118+
config: &DecodeConfig{
119+
TypeUnmarshalFuncs: map[reflect.Type]CustomizeDecodeFunc{
120+
reflect.TypeOf(testTextUnmarshaler{}): func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) {
121+
return reflect.ValueOf(testTextUnmarshaler{Value: "from custom func"}), nil
122+
},
123+
},
124+
},
125+
expectValue: testTextUnmarshaler{Value: "from custom func"},
126+
},
127+
{
128+
name: "custom type unmarshal func err",
129+
elemType: reflect.TypeOf(testTextUnmarshaler{}),
130+
config: &DecodeConfig{
131+
TypeUnmarshalFuncs: map[reflect.Type]CustomizeDecodeFunc{
132+
reflect.TypeOf(testTextUnmarshaler{}): func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) {
133+
return reflect.Value{}, errors.New("err")
134+
},
135+
},
136+
},
137+
expectError: true,
138+
},
139+
}
140+
141+
for _, tt := range tests {
142+
t.Run(tt.name, func(t *testing.T) {
143+
req := &protocol.Request{}
144+
params := param.Params{}
145+
config := tt.config
146+
if config == nil {
147+
config = &DecodeConfig{}
148+
}
149+
val, err := stringToValue(tt.elemType, tt.text, req, params, config)
150+
if tt.expectError {
151+
assert.NotNil(t, err)
152+
return
153+
}
154+
assert.Nil(t, err)
155+
assert.DeepEqual(t, tt.expectValue, val.Interface())
156+
})
157+
}
158+
}
159+
160+
func TestTryTextUnmarshaler(t *testing.T) {
161+
tests := []struct {
162+
name string
163+
value interface{}
164+
text string
165+
expected bool
166+
}{
167+
{
168+
name: "text unmarshaler",
169+
value: &testTextUnmarshaler{},
170+
text: "test text",
171+
expected: true,
172+
},
173+
{
174+
name: "non text unmarshaler",
175+
value: &struct{}{},
176+
text: "test text",
177+
expected: false,
178+
},
179+
{
180+
name: "nil value",
181+
value: nil,
182+
text: "test text",
183+
expected: false,
184+
},
185+
}
186+
187+
for _, tt := range tests {
188+
t.Run(tt.name, func(t *testing.T) {
189+
var v reflect.Value
190+
if tt.value != nil {
191+
v = reflect.ValueOf(tt.value)
192+
} else {
193+
v = reflect.ValueOf(&tt.value).Elem()
194+
}
195+
196+
result := tryTextUnmarshaler(v, tt.text)
197+
assert.DeepEqual(t, tt.expected, result)
198+
199+
if tt.expected && tt.value != nil {
200+
// Verify the value was actually set
201+
unmarshaler := tt.value.(*testTextUnmarshaler)
202+
assert.DeepEqual(t, tt.text, unmarshaler.Value)
203+
}
204+
})
205+
}
206+
}

0 commit comments

Comments
 (0)