diff --git a/encoder.go b/encoder.go index 52f2c10..a9ea413 100644 --- a/encoder.go +++ b/encoder.go @@ -3,12 +3,55 @@ package schema import ( "errors" "fmt" + "net/url" "reflect" "strconv" + "strings" ) type encoderFunc func(reflect.Value) string +// UrlValues represents url.Values which could be encoded with custom order. +type UrlValues struct { + keys []string + values map[string][]string +} + +// Values returns map[string][]string which can be used as url.Values. +func (v *UrlValues) Values() map[string][]string { + return v.values +} + +// Encode encodes the values into URL encoded form ("foo=quux&bar=baz") sorted by custom order. +func (v *UrlValues) Encode() string { + if len(v.values) == 0 { + return "" + } + var buf strings.Builder + for _, k := range v.keys { + vs := v.values[k] + keyEscaped := url.QueryEscape(k) + for _, v := range vs { + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(keyEscaped) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(v)) + } + } + return buf.String() +} + +func (v *UrlValues) removeKey(key string) { + for i, x := range v.keys { + if x == key { + v.keys = append(v.keys[:i], v.keys[i+1:]...) + return + } + } +} + // Encoder encodes values from a struct into url.Values. type Encoder struct { cache *cache @@ -23,14 +66,30 @@ func NewEncoder() *Encoder { // Encode encodes a struct into map[string][]string. // // Intended for use with url.Values. -func (e *Encoder) Encode(src interface{}, dst map[string][]string) error { +func (e *Encoder) Encode(src any, dst map[string][]string) error { v := reflect.ValueOf(src) + values := &UrlValues{ + values: dst, + } - return e.encode(v, dst) + return e.encode(v, values) +} + +// EncodeValues encodes a struct into UrlValues which will keep the order of the struct's fields. +func (e *Encoder) EncodeValues(src any) (*UrlValues, error) { + v := reflect.ValueOf(src) + values := &UrlValues{ + values: map[string][]string{}, + } + + if err := e.encode(v, values); err != nil { + return nil, err + } + return values, nil } // RegisterEncoder registers a converter for encoding a custom type. -func (e *Encoder) RegisterEncoder(value interface{}, encoder func(reflect.Value) string) { +func (e *Encoder) RegisterEncoder(value any, encoder func(reflect.Value) string) { e.regenc[reflect.TypeOf(value)] = encoder } @@ -75,7 +134,7 @@ func isZero(v reflect.Value) bool { return v.Interface() == z.Interface() } -func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { +func (e *Encoder) encode(v reflect.Value, values *UrlValues) error { if v.Kind() == reflect.Ptr { v = v.Elem() } @@ -94,7 +153,7 @@ func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { // Encode struct pointer types if the field is a valid pointer and a struct. if isValidStructPointer(v.Field(i)) && !e.hasCustomEncoder(v.Field(i).Type()) { - err := e.encode(v.Field(i).Elem(), dst) + err := e.encode(v.Field(i).Elem(), values) if err != nil { errors[v.Field(i).Elem().Type().String()] = err } @@ -110,12 +169,15 @@ func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { continue } - dst[name] = append(dst[name], value) + if _, ok := values.values[name]; !ok { + values.keys = append(values.keys, name) + } + values.values[name] = append(values.values[name], value) continue } if v.Field(i).Type().Kind() == reflect.Struct { - err := e.encode(v.Field(i), dst) + err := e.encode(v.Field(i), values) if err != nil { errors[v.Field(i).Type().String()] = err } @@ -132,13 +194,18 @@ func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { } // Encode a slice. - if v.Field(i).Len() == 0 && opts.Contains("omitempty") { + sliceLen := v.Field(i).Len() + if sliceLen == 0 && opts.Contains("omitempty") { continue } - dst[name] = []string{} - for j := 0; j < v.Field(i).Len(); j++ { - dst[name] = append(dst[name], encFunc(v.Field(i).Index(j))) + if _, ok := values.values[name]; ok { + values.removeKey(name) + } + values.keys = append(values.keys, name) + values.values[name] = make([]string, 0, sliceLen) + for j := 0; j < sliceLen; j++ { + values.values[name] = append(values.values[name], encFunc(v.Field(i).Index(j))) } } diff --git a/encoder_test.go b/encoder_test.go index 092f0de..683c6b7 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -523,3 +523,59 @@ func TestRegisterEncoderWithPtrType(t *testing.T) { valExists(t, "DateStart", ss.DateStart.time.String(), vals) valExists(t, "DateEnd", "", vals) } + +func TestUrlValues(t *testing.T) { + v1 := UrlValues{ + keys: []string{"a"}, + values: map[string][]string{ + "a": {"some&value"}, + }, + } + v1Encoded, v1Expect := v1.Encode(), "a=some%26value" + if v1Encoded != v1Expect { + t.Fatalf("Expected: %v, got: %v", v1Expect, v1Encoded) + } + + v2 := UrlValues{ + keys: []string{"z", "a", "s", "x"}, + values: map[string][]string{ + "a": {"valueA", "value%b"}, + "s": {"valueS"}, + "x": {""}, + "z": {"value$Z"}, + }, + } + v2Encoded, v2Expect := v2.Encode(), "z=value%24Z&a=valueA&a=value%25b&s=valueS&x=" + if v2Encoded != v2Expect { + t.Fatalf("Expected: %v, got: %v", v2Expect, v2Encoded) + } +} + +func TestEncodeValues(t *testing.T) { + type S1 struct { + Order []string `schema:"order"` + Asc int `schema:"asc"` + PubKey string `schema:"pubkey"` + Method string `schema:"method"` + } + + s1 := S1{ + Order: []string{"name1", "name2"}, + Asc: 1, + PubKey: "example-pubkey-foobar", + Method: "HMAC-256", + } + + encoder := NewEncoder() + values, err := encoder.EncodeValues(s1) + noError(t, err) + expectOrder := []string{"order", "asc", "pubkey", "method"} + if len(values.keys) != len(expectOrder) { + t.Fatalf("Expected length of %v, but got %v", len(expectOrder), len(values.keys)) + } + for i, k := range values.keys { + if expectOrder[i] != k { + t.Fatalf("Expected: %v, got: %v", expectOrder[i], k) + } + } +}