diff --git a/README.md b/README.md index 9c75dd9a..7e64ed53 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,7 @@ types: * integers * `bool` * `string` +* `[]byte` * slices * `trino.Numeric` - a string representation of a number * `time.Time` - passed to Trino as a timestamp with a time zone @@ -324,7 +325,11 @@ following types: * `json.Number` for any numeric Trino types * `[]interface{}` for Trino arrays * `map[string]interface{}` for Trino maps -* `string` for other Trino types, as character, date, time, or timestamp +* `string` for other Trino types, as character, date, time, or timestamp. + +> [!NOTE] +> `VARBINARY` columns are returned as base64-encoded strings when used within +> `ROW`, `MAP`, or `ARRAY` values. ## License diff --git a/trino/integration_test.go b/trino/integration_test.go index d97720c5..6426e74c 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -24,6 +24,7 @@ import ( "crypto/x509/pkix" "database/sql" "database/sql/driver" + "encoding/json" "encoding/pem" "errors" "flag" @@ -34,6 +35,7 @@ import ( "math/big" "net/http" "os" + "reflect" "strconv" "strings" "testing" @@ -110,6 +112,14 @@ func TestMain(m *testing.M) { "8080/tcp", "8443/tcp", }, + }, func(hc *docker.HostConfig) { + hc.Ulimits = []docker.ULimit{ + { + Name: "nofile", + Hard: 4096, + Soft: 4096, + }, + } }) if err != nil { log.Fatalf("Could not start resource: %s", err) @@ -545,6 +555,8 @@ func TestIntegrationTypeConversion(t *testing.T) { var ( goTime time.Time nullTime NullTime + goBytes []byte + nullBytes []byte goString string nullString sql.NullString nullStringSlice NullSliceString @@ -564,6 +576,8 @@ func TestIntegrationTypeConversion(t *testing.T) { SELECT TIMESTAMP '2017-07-10 01:02:03.004 UTC', CAST(NULL AS TIMESTAMP), + CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY), + CAST(NULL AS VARBINARY), CAST('string' AS VARCHAR), CAST(NULL AS VARCHAR), ARRAY['A', 'B', NULL], @@ -581,6 +595,8 @@ func TestIntegrationTypeConversion(t *testing.T) { `).Scan( &goTime, &nullTime, + &goBytes, + &nullBytes, &goString, &nullString, &nullStringSlice, @@ -599,6 +615,172 @@ func TestIntegrationTypeConversion(t *testing.T) { if err != nil { t.Fatal(err) } + + // Compare the actual and expected values. + expectedTime := time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC) + if !goTime.Equal(expectedTime) { + t.Errorf("expected GoTime to be %v, got %v", expectedTime, goTime) + } + + expectedBytes := []byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff} + if !bytes.Equal(goBytes, expectedBytes) { + t.Errorf("expected GoBytes to be %v, got %v", expectedBytes, goBytes) + } + + if nullBytes != nil { + t.Errorf("expected NullBytes to be nil, got %v", nullBytes) + } + + if goString != "string" { + t.Errorf("expected GoString to be %q, got %q", "string", goString) + } + + if nullString.Valid { + t.Errorf("expected NullString.Valid to be false, got true") + } + + if !reflect.DeepEqual(nullStringSlice.SliceString, []sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}}) { + t.Errorf("expected NullStringSlice.SliceString to be %v, got %v", + []sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}}, + nullStringSlice.SliceString) + } + if !nullStringSlice.Valid { + t.Errorf("expected NullStringSlice.Valid to be true, got false") + } + + expectedSlice2String := [][]sql.NullString{{{String: "A", Valid: true}}, {}} + if !reflect.DeepEqual(nullStringSlice2.Slice2String, expectedSlice2String) { + t.Errorf("expected NullStringSlice2.Slice2String to be %v, got %v", expectedSlice2String, nullStringSlice2.Slice2String) + } + if !nullStringSlice2.Valid { + t.Errorf("expected NullStringSlice2.Valid to be true, got false") + } + + expectedSlice3String := [][][]sql.NullString{{{{String: "A", Valid: true}}, {}}, {}} + if !reflect.DeepEqual(nullStringSlice3.Slice3String, expectedSlice3String) { + t.Errorf("expected NullStringSlice3.Slice3String to be %v, got %v", expectedSlice3String, nullStringSlice3.Slice3String) + } + if !nullStringSlice3.Valid { + t.Errorf("expected NullStringSlice3.Valid to be true, got false") + } + + expectedSliceInt64 := []sql.NullInt64{{Int64: 1, Valid: true}, {Int64: 2, Valid: true}, {Valid: false}} + if !reflect.DeepEqual(nullInt64Slice.SliceInt64, expectedSliceInt64) { + t.Errorf("expected NullInt64Slice.SliceInt64 to be %v, got %v", expectedSliceInt64, nullInt64Slice.SliceInt64) + } + if !nullInt64Slice.Valid { + t.Errorf("expected NullInt64Slice.Valid to be true, got false") + } + + expectedSlice2Int64 := [][]sql.NullInt64{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}} + if !reflect.DeepEqual(nullInt64Slice2.Slice2Int64, expectedSlice2Int64) { + t.Errorf("expected NullInt64Slice2.Slice2Int64 to be %v, got %v", expectedSlice2Int64, nullInt64Slice2.Slice2Int64) + } + if !nullInt64Slice2.Valid { + t.Errorf("expected NullInt64Slice2.Valid to be true, got false") + } + + expectedSlice3Int64 := [][][]sql.NullInt64{{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}}, {}} + if !reflect.DeepEqual(nullInt64Slice3.Slice3Int64, expectedSlice3Int64) { + t.Errorf("expected NullInt64Slice3.Slice3Int64 to be %v, got %v", expectedSlice3Int64, nullInt64Slice3.Slice3Int64) + } + if !nullInt64Slice3.Valid { + t.Errorf("expected NullInt64Slice3.Valid to be true, got false") + } + + expectedSliceFloat64 := []sql.NullFloat64{{Float64: 1.0, Valid: true}, {Float64: 2.0, Valid: true}, {Valid: false}} + if !reflect.DeepEqual(nullFloat64Slice.SliceFloat64, expectedSliceFloat64) { + t.Errorf("expected NullFloat64Slice.SliceFloat64 to be %v, got %v", expectedSliceFloat64, nullFloat64Slice.SliceFloat64) + } + if !nullFloat64Slice.Valid { + t.Errorf("expected NullFloat64Slice.Valid to be true, got false") + } + + expectedSlice2Float64 := [][]sql.NullFloat64{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}} + if !reflect.DeepEqual(nullFloat64Slice2.Slice2Float64, expectedSlice2Float64) { + t.Errorf("expected NullFloat64Slice2.Slice2Float64 to be %v, got %v", expectedSlice2Float64, nullFloat64Slice2.Slice2Float64) + } + if !nullFloat64Slice2.Valid { + t.Errorf("expected NullFloat64Slice2.Valid to be true, got false") + } + + expectedSlice3Float64 := [][][]sql.NullFloat64{{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}}, {}} + if !reflect.DeepEqual(nullFloat64Slice3.Slice3Float64, expectedSlice3Float64) { + t.Errorf("expected NullFloat64Slice3.Slice3Float64 to be %v, got %v", expectedSlice3Float64, nullFloat64Slice3.Slice3Float64) + } + if !nullFloat64Slice3.Valid { + t.Errorf("expected NullFloat64Slice3.Valid to be true, got false") + } + + expectedMap := map[string]interface{}{"a": "c", "b": "d"} + if !reflect.DeepEqual(goMap, expectedMap) { + t.Errorf("expected GoMap to be %v, got %v", expectedMap, goMap) + } + + if nullMap.Valid { + t.Errorf("expected NullMap.Valid to be false, got true") + } + + expectedRow := []interface{}{json.Number("1"), "a", "2017-07-10 01:02:03.004000 UTC", []interface{}{"c"}} + if !reflect.DeepEqual(goRow, expectedRow) { + t.Errorf("expected GoRow to be %v, got %v", expectedRow, goRow) + } +} + +func TestComplexTypes(t *testing.T) { + // This test has been created to showcase some issues with parsing + // complex types. It is not intended to be a comprehensive test of + // the parsing logic, but rather to provide a reference for future + // changes to the parsing logic. + // + // The current implementation of the parsing logic reads the value + // in the same format as the JSON response from Trino. This means + // that we don't go further to parse values as their structured types. + // For example, a row like `ROW(1, X'0000')` is read as + // a list of a `json.Number(1)` and a base64-encoded string. + t.Skip("skipping failing test") + + dsn := *integrationServerFlag + db := integrationOpen(t, dsn) + + for _, tt := range []struct { + name string + query string + expected interface{} + }{ + { + name: "row containing scalar values", + query: `SELECT ROW(1, 'a', X'0000')`, + expected: []interface{}{1, "a", []byte{0x00, 0x00}}, + }, + { + name: "nested row", + query: `SELECT ROW(ROW(1, 'a'), ROW(2, 'b'))`, + expected: []interface{}{[]interface{}{1, "a"}, []interface{}{2, "b"}}, + }, + { + name: "map with scalar values", + query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[1, 2])`, + expected: map[string]interface{}{"a": 1, "b": 2}, + }, + { + name: "map with nested row", + query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[ROW(1, 'a'), ROW(2, 'b')])`, + expected: map[string]interface{}{"a": []interface{}{1, "a"}, "b": []interface{}{2, "b"}}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var result interface{} + err := db.QueryRow(tt.query).Scan(&result) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } } func TestIntegrationArgsConversion(t *testing.T) { @@ -615,8 +797,9 @@ func TestIntegrationArgsConversion(t *testing.T) { CAST(1 AS DOUBLE), TIMESTAMP '2017-07-10 01:02:03.004 UTC', CAST('string' AS VARCHAR), + CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY), ARRAY['A', 'B'] - )) AS t(col_tiny, col_small, col_int, col_big, col_real, col_double, col_ts, col_varchar, col_array ) + )) AS t(col_tiny, col_small, col_int, col_big, col_real, col_double, col_ts, col_varchar, col_varbinary, col_array ) WHERE 1=1 AND col_tiny = ? AND col_small = ? @@ -626,6 +809,7 @@ func TestIntegrationArgsConversion(t *testing.T) { AND col_double = cast(? as double) AND col_ts = ? AND col_varchar = ? + AND col_varbinary = ? AND col_array = ?`, int16(1), int16(1), @@ -635,7 +819,9 @@ func TestIntegrationArgsConversion(t *testing.T) { Numeric("1"), time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC), "string", - []string{"A", "B"}).Scan(&value) + []byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff}, + []string{"A", "B"}, + ).Scan(&value) if err != nil { t.Fatal(err) } diff --git a/trino/serial.go b/trino/serial.go index 5a778a55..579f587b 100644 --- a/trino/serial.go +++ b/trino/serial.go @@ -15,6 +15,7 @@ package trino import ( + "encoding/hex" "encoding/json" "fmt" "math" @@ -148,9 +149,11 @@ func Serial(v interface{}) (string, error) { case string: return "'" + strings.Replace(x, "'", "''", -1) + "'", nil - // TODO - []byte should probably be matched to 'VARBINARY' in trino case []byte: - return "", UnsupportedArgError{"[]byte"} + if x == nil { + return "NULL", nil + } + return "X'" + hex.EncodeToString(x) + "'", nil case trinoDate: return fmt.Sprintf("DATE '%04d-%02d-%02d'", x.year, x.month, x.day), nil diff --git a/trino/serial_test.go b/trino/serial_test.go index aa91145f..45a0fd97 100644 --- a/trino/serial_test.go +++ b/trino/serial_test.go @@ -46,6 +46,21 @@ func TestSerial(t *testing.T) { value: `hello "world"`, expectedSerial: `'hello "world"'`, }, + { + name: "basic binary", + value: []byte{0x01, 0x02, 0x03}, + expectedSerial: `X'010203'`, + }, + { + name: "empty binary", + value: []byte{}, + expectedSerial: `X''`, + }, + { + name: "nil binary", + value: []byte(nil), + expectedSerial: `NULL`, + }, { name: "int8", value: int8(100), diff --git a/trino/trino.go b/trino/trino.go index 6f323162..e12257c9 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -56,6 +56,7 @@ import ( "crypto/x509" "database/sql" "database/sql/driver" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -1511,8 +1512,10 @@ func getScanType(typeNames []string) (reflect.Type, error) { switch typeNames[0] { case "boolean": v = sql.NullBool{} - case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown": + case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown": v = sql.NullString{} + case "varbinary": + v = []byte{} case "tinyint", "smallint": v = sql.NullInt32{} case "integer": @@ -1596,12 +1599,14 @@ func (c *typeConverter) ConvertValue(v interface{}) (driver.Value, error) { return nil, err } return vv.Bool, err - case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "Geometry", "SphericalGeography", "unknown": + case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "Geometry", "SphericalGeography", "unknown": vv, err := scanNullString(v) if !vv.Valid { return nil, err } return vv.String, err + case "varbinary": + return scanNullBytes(v) case "tinyint", "smallint", "integer", "bigint": vv, err := scanNullInt64(v) if !vv.Valid { @@ -1771,6 +1776,26 @@ func scanNullString(v interface{}) (sql.NullString, error) { return sql.NullString{Valid: true, String: vv}, nil } +func scanNullBytes(v interface{}) ([]byte, error) { + if v == nil { + return nil, nil + } + + // VARBINARY values come back as a base64 encoded string. + vv, ok := v.(string) + if !ok { + return nil, fmt.Errorf("cannot convert %v (%T) to []byte", v, v) + } + + // Decode the base64 encoded string into a []byte. + decoded, err := base64.StdEncoding.DecodeString(vv) + if err != nil { + return nil, fmt.Errorf("cannot decode base64 string into []byte: %w", err) + } + + return decoded, nil +} + // NullSliceString represents a slice of string that may be null. type NullSliceString struct { SliceString []sql.NullString diff --git a/trino/trino_test.go b/trino/trino_test.go index f1623a37..e95d0637 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -764,7 +764,7 @@ func TestQueryColumns(t *testing.T) { 0, false, 0, - reflect.TypeOf(sql.NullString{}), + reflect.TypeOf([]byte{}), }, { "JSON",