Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ types:
* integers
* `bool`
* `string`
* `[]byte`
* slices
* `trino.Numeric` - a string representation of a number
* `time.Time` - passed to Trino as a timestamp with a time zone
Expand Down Expand Up @@ -324,7 +325,11 @@ following types:
* `json.Number` for any numeric Trino types
* `[]interface{}` for Trino arrays
* `map[string]interface{}` for Trino maps
* `string` for other Trino types, as character, date, time, or timestamp
* `string` for other Trino types, as character, date, time, or timestamp.

> [!NOTE]
> `VARBINARY` columns are returned as base64-encoded strings when used within
> `ROW`, `MAP`, or `ARRAY` values.

## License

Expand Down
190 changes: 188 additions & 2 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"crypto/x509/pkix"
"database/sql"
"database/sql/driver"
"encoding/json"
"encoding/pem"
"errors"
"flag"
Expand All @@ -34,6 +35,7 @@ import (
"math/big"
"net/http"
"os"
"reflect"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -110,6 +112,14 @@ func TestMain(m *testing.M) {
"8080/tcp",
"8443/tcp",
},
}, func(hc *docker.HostConfig) {
hc.Ulimits = []docker.ULimit{
{
Name: "nofile",
Hard: 4096,
Soft: 4096,
},
}
Comment on lines +115 to +122
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trino requires a minimum of 4096 descriptors in order to startup. On my linux machine, I found docker defaults to a limit of 1024.

I've updated the test configuration such that we can be sure the container is configured with reasonable ulimits.

})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
Expand Down Expand Up @@ -545,6 +555,8 @@ func TestIntegrationTypeConversion(t *testing.T) {
var (
goTime time.Time
nullTime NullTime
goBytes []byte
nullBytes []byte
goString string
nullString sql.NullString
nullStringSlice NullSliceString
Expand All @@ -564,6 +576,8 @@ func TestIntegrationTypeConversion(t *testing.T) {
SELECT
TIMESTAMP '2017-07-10 01:02:03.004 UTC',
CAST(NULL AS TIMESTAMP),
CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY),
CAST(NULL AS VARBINARY),
CAST('string' AS VARCHAR),
CAST(NULL AS VARCHAR),
ARRAY['A', 'B', NULL],
Expand All @@ -581,6 +595,8 @@ func TestIntegrationTypeConversion(t *testing.T) {
`).Scan(
&goTime,
&nullTime,
&goBytes,
&nullBytes,
&goString,
&nullString,
&nullStringSlice,
Expand All @@ -599,6 +615,172 @@ func TestIntegrationTypeConversion(t *testing.T) {
if err != nil {
t.Fatal(err)
}

// Compare the actual and expected values.
expectedTime := time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC)
if !goTime.Equal(expectedTime) {
t.Errorf("expected GoTime to be %v, got %v", expectedTime, goTime)
}

expectedBytes := []byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff}
if !bytes.Equal(goBytes, expectedBytes) {
t.Errorf("expected GoBytes to be %v, got %v", expectedBytes, goBytes)
}

if nullBytes != nil {
t.Errorf("expected NullBytes to be nil, got %v", nullBytes)
}

if goString != "string" {
t.Errorf("expected GoString to be %q, got %q", "string", goString)
}

if nullString.Valid {
t.Errorf("expected NullString.Valid to be false, got true")
}

if !reflect.DeepEqual(nullStringSlice.SliceString, []sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}}) {
t.Errorf("expected NullStringSlice.SliceString to be %v, got %v",
[]sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}},
nullStringSlice.SliceString)
}
if !nullStringSlice.Valid {
t.Errorf("expected NullStringSlice.Valid to be true, got false")
}

expectedSlice2String := [][]sql.NullString{{{String: "A", Valid: true}}, {}}
if !reflect.DeepEqual(nullStringSlice2.Slice2String, expectedSlice2String) {
t.Errorf("expected NullStringSlice2.Slice2String to be %v, got %v", expectedSlice2String, nullStringSlice2.Slice2String)
}
if !nullStringSlice2.Valid {
t.Errorf("expected NullStringSlice2.Valid to be true, got false")
}

expectedSlice3String := [][][]sql.NullString{{{{String: "A", Valid: true}}, {}}, {}}
if !reflect.DeepEqual(nullStringSlice3.Slice3String, expectedSlice3String) {
t.Errorf("expected NullStringSlice3.Slice3String to be %v, got %v", expectedSlice3String, nullStringSlice3.Slice3String)
}
if !nullStringSlice3.Valid {
t.Errorf("expected NullStringSlice3.Valid to be true, got false")
}

expectedSliceInt64 := []sql.NullInt64{{Int64: 1, Valid: true}, {Int64: 2, Valid: true}, {Valid: false}}
if !reflect.DeepEqual(nullInt64Slice.SliceInt64, expectedSliceInt64) {
t.Errorf("expected NullInt64Slice.SliceInt64 to be %v, got %v", expectedSliceInt64, nullInt64Slice.SliceInt64)
}
if !nullInt64Slice.Valid {
t.Errorf("expected NullInt64Slice.Valid to be true, got false")
}

expectedSlice2Int64 := [][]sql.NullInt64{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}}
if !reflect.DeepEqual(nullInt64Slice2.Slice2Int64, expectedSlice2Int64) {
t.Errorf("expected NullInt64Slice2.Slice2Int64 to be %v, got %v", expectedSlice2Int64, nullInt64Slice2.Slice2Int64)
}
if !nullInt64Slice2.Valid {
t.Errorf("expected NullInt64Slice2.Valid to be true, got false")
}

expectedSlice3Int64 := [][][]sql.NullInt64{{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}}, {}}
if !reflect.DeepEqual(nullInt64Slice3.Slice3Int64, expectedSlice3Int64) {
t.Errorf("expected NullInt64Slice3.Slice3Int64 to be %v, got %v", expectedSlice3Int64, nullInt64Slice3.Slice3Int64)
}
if !nullInt64Slice3.Valid {
t.Errorf("expected NullInt64Slice3.Valid to be true, got false")
}

expectedSliceFloat64 := []sql.NullFloat64{{Float64: 1.0, Valid: true}, {Float64: 2.0, Valid: true}, {Valid: false}}
if !reflect.DeepEqual(nullFloat64Slice.SliceFloat64, expectedSliceFloat64) {
t.Errorf("expected NullFloat64Slice.SliceFloat64 to be %v, got %v", expectedSliceFloat64, nullFloat64Slice.SliceFloat64)
}
if !nullFloat64Slice.Valid {
t.Errorf("expected NullFloat64Slice.Valid to be true, got false")
}

expectedSlice2Float64 := [][]sql.NullFloat64{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}}
if !reflect.DeepEqual(nullFloat64Slice2.Slice2Float64, expectedSlice2Float64) {
t.Errorf("expected NullFloat64Slice2.Slice2Float64 to be %v, got %v", expectedSlice2Float64, nullFloat64Slice2.Slice2Float64)
}
if !nullFloat64Slice2.Valid {
t.Errorf("expected NullFloat64Slice2.Valid to be true, got false")
}

expectedSlice3Float64 := [][][]sql.NullFloat64{{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}}, {}}
if !reflect.DeepEqual(nullFloat64Slice3.Slice3Float64, expectedSlice3Float64) {
t.Errorf("expected NullFloat64Slice3.Slice3Float64 to be %v, got %v", expectedSlice3Float64, nullFloat64Slice3.Slice3Float64)
}
if !nullFloat64Slice3.Valid {
t.Errorf("expected NullFloat64Slice3.Valid to be true, got false")
}

expectedMap := map[string]interface{}{"a": "c", "b": "d"}
if !reflect.DeepEqual(goMap, expectedMap) {
t.Errorf("expected GoMap to be %v, got %v", expectedMap, goMap)
}

if nullMap.Valid {
t.Errorf("expected NullMap.Valid to be false, got true")
}

expectedRow := []interface{}{json.Number("1"), "a", "2017-07-10 01:02:03.004000 UTC", []interface{}{"c"}}
if !reflect.DeepEqual(goRow, expectedRow) {
t.Errorf("expected GoRow to be %v, got %v", expectedRow, goRow)
}
}

func TestComplexTypes(t *testing.T) {
// This test has been created to showcase some issues with parsing
// complex types. It is not intended to be a comprehensive test of
// the parsing logic, but rather to provide a reference for future
// changes to the parsing logic.
//
// The current implementation of the parsing logic reads the value
// in the same format as the JSON response from Trino. This means
// that we don't go further to parse values as their structured types.
// For example, a row like `ROW(1, X'0000')` is read as
// a list of a `json.Number(1)` and a base64-encoded string.
t.Skip("skipping failing test")

dsn := *integrationServerFlag
db := integrationOpen(t, dsn)

for _, tt := range []struct {
name string
query string
expected interface{}
}{
{
name: "row containing scalar values",
query: `SELECT ROW(1, 'a', X'0000')`,
expected: []interface{}{1, "a", []byte{0x00, 0x00}},
},
{
name: "nested row",
query: `SELECT ROW(ROW(1, 'a'), ROW(2, 'b'))`,
expected: []interface{}{[]interface{}{1, "a"}, []interface{}{2, "b"}},
},
{
name: "map with scalar values",
query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[1, 2])`,
expected: map[string]interface{}{"a": 1, "b": 2},
},
{
name: "map with nested row",
query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[ROW(1, 'a'), ROW(2, 'b')])`,
expected: map[string]interface{}{"a": []interface{}{1, "a"}, "b": []interface{}{2, "b"}},
},
} {
t.Run(tt.name, func(t *testing.T) {
var result interface{}
err := db.QueryRow(tt.query).Scan(&result)
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestIntegrationArgsConversion(t *testing.T) {
Expand All @@ -615,8 +797,9 @@ func TestIntegrationArgsConversion(t *testing.T) {
CAST(1 AS DOUBLE),
TIMESTAMP '2017-07-10 01:02:03.004 UTC',
CAST('string' AS VARCHAR),
CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY),
ARRAY['A', 'B']
)) AS t(col_tiny, col_small, col_int, col_big, col_real, col_double, col_ts, col_varchar, col_array )
)) AS t(col_tiny, col_small, col_int, col_big, col_real, col_double, col_ts, col_varchar, col_varbinary, col_array )
WHERE 1=1
AND col_tiny = ?
AND col_small = ?
Expand All @@ -626,6 +809,7 @@ func TestIntegrationArgsConversion(t *testing.T) {
AND col_double = cast(? as double)
AND col_ts = ?
AND col_varchar = ?
AND col_varbinary = ?
AND col_array = ?`,
int16(1),
int16(1),
Expand All @@ -635,7 +819,9 @@ func TestIntegrationArgsConversion(t *testing.T) {
Numeric("1"),
time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC),
"string",
[]string{"A", "B"}).Scan(&value)
[]byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff},
[]string{"A", "B"},
).Scan(&value)
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 5 additions & 2 deletions trino/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package trino

import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
Expand Down Expand Up @@ -148,9 +149,11 @@ func Serial(v interface{}) (string, error) {
case string:
return "'" + strings.Replace(x, "'", "''", -1) + "'", nil

// TODO - []byte should probably be matched to 'VARBINARY' in trino
case []byte:
return "", UnsupportedArgError{"[]byte"}
if x == nil {
return "NULL", nil
}
return "X'" + hex.EncodeToString(x) + "'", nil

case trinoDate:
return fmt.Sprintf("DATE '%04d-%02d-%02d'", x.year, x.month, x.day), nil
Expand Down
15 changes: 15 additions & 0 deletions trino/serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ func TestSerial(t *testing.T) {
value: `hello "world"`,
expectedSerial: `'hello "world"'`,
},
{
name: "basic binary",
value: []byte{0x01, 0x02, 0x03},
expectedSerial: `X'010203'`,
},
{
name: "empty binary",
value: []byte{},
expectedSerial: `X''`,
},
{
name: "nil binary",
value: []byte(nil),
expectedSerial: `NULL`,
},
{
name: "int8",
value: int8(100),
Expand Down
29 changes: 27 additions & 2 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"crypto/x509"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -1511,8 +1512,10 @@ func getScanType(typeNames []string) (reflect.Type, error) {
switch typeNames[0] {
case "boolean":
v = sql.NullBool{}
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
v = sql.NullString{}
case "varbinary":
v = []byte{}
case "tinyint", "smallint":
v = sql.NullInt32{}
case "integer":
Expand Down Expand Up @@ -1596,12 +1599,14 @@ func (c *typeConverter) ConvertValue(v interface{}) (driver.Value, error) {
return nil, err
}
return vv.Bool, err
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "Geometry", "SphericalGeography", "unknown":
case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "Geometry", "SphericalGeography", "unknown":
vv, err := scanNullString(v)
if !vv.Valid {
return nil, err
}
return vv.String, err
case "varbinary":
return scanNullBytes(v)
case "tinyint", "smallint", "integer", "bigint":
vv, err := scanNullInt64(v)
if !vv.Valid {
Expand Down Expand Up @@ -1771,6 +1776,26 @@ func scanNullString(v interface{}) (sql.NullString, error) {
return sql.NullString{Valid: true, String: vv}, nil
}

func scanNullBytes(v interface{}) ([]byte, error) {
if v == nil {
return nil, nil
}

// VARBINARY values come back as a base64 encoded string.
vv, ok := v.(string)
if !ok {
return nil, fmt.Errorf("cannot convert %v (%T) to []byte", v, v)
}

// Decode the base64 encoded string into a []byte.
decoded, err := base64.StdEncoding.DecodeString(vv)
if err != nil {
return nil, fmt.Errorf("cannot decode base64 string into []byte: %w", err)
}

return decoded, nil
}

// NullSliceString represents a slice of string that may be null.
type NullSliceString struct {
SliceString []sql.NullString
Expand Down
2 changes: 1 addition & 1 deletion trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ func TestQueryColumns(t *testing.T) {
0,
false,
0,
reflect.TypeOf(sql.NullString{}),
reflect.TypeOf([]byte{}),
},
{
"JSON",
Expand Down
Loading