From ef3b04adac2806421c5975211b280dea971eab02 Mon Sep 17 00:00:00 2001 From: Elliot Courant Date: Fri, 4 Apr 2025 20:07:58 -0500 Subject: [PATCH] feat(map): Handle json.Number when merging maps If you unmarshall a json object using the json.Decoder and the UseNumber function, it will create numbers on the map as a `json.Number` object. This object has helper functions for handling ints and floats. So if a json.Number is encountered, try to handle the type of the number based on the destination. --- json_test.go | 36 ++++++++++++++++++++++++++++++++++++ map.go | 24 ++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 json_test.go diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000..610ab84 --- /dev/null +++ b/json_test.go @@ -0,0 +1,36 @@ +package mergo_test + +import ( + "bytes" + "encoding/json" + "testing" + + "dario.cat/mergo" +) + +func TestJsonNumber(t *testing.T) { + jsonSampleData := ` +{ + "amount": 1234 +} +` + + type Data struct { + Amount int64 `json:"amount"` + } + + foo := make(map[string]interface{}) + + decoder := json.NewDecoder(bytes.NewReader([]byte(jsonSampleData))) + decoder.UseNumber() + decoder.Decode(&foo) + + data := Data{} + err := mergo.Map(&data, foo) + if err != nil { + t.Errorf("failed to merge with json.Number: %+v", err) + } + if data.Amount != 1234 { + t.Errorf("merged amount does not match the json value! expected: 1234 got: %v", data.Amount) + } +} diff --git a/map.go b/map.go index 759b4f7..29fc962 100644 --- a/map.go +++ b/map.go @@ -9,6 +9,7 @@ package mergo import ( + "encoding/json" "fmt" "reflect" "unicode" @@ -111,6 +112,29 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf return } } else { + // If the map side of the merge is a json number then we can use the + // functions on the json number to merge the data depending on the + // destination type. + if number, ok := srcElement.Interface().(json.Number); ok { + switch dstKind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, err := number.Int64() + if err != nil { + return fmt.Errorf("...: %+v", err) + } + dstElement.SetInt(value) + continue + case reflect.Float32, reflect.Float64: + value, err := number.Float64() + if err != nil { + return fmt.Errorf("...: %+v", err) + } + dstElement.SetFloat(value) + continue + } + } + // But if we can't do that then fallback to the normal type mismatch + // failure. return fmt.Errorf("type mismatch on %s field: found %v, expected %v", fieldName, srcKind, dstKind) } }