diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..f24bcb4b8 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,6 +32,7 @@ import ( "context" "errors" "fmt" + "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -943,6 +944,190 @@ func TestMapScan(t *testing.T) { assertEqual(t, "address", "10.0.0.1", row["address"]) } +type expec struct { + Id int + Col_ascii interface{} + Col_bigint interface{} + Col_blob interface{} + Col_boolean interface{} + Col_date interface{} + Col_decimal interface{} + Col_double interface{} + Col_duration interface{} + Col_float interface{} + Col_inet interface{} + Col_int interface{} + Col_smallint interface{} + Col_text interface{} + Col_time interface{} + Col_timestamp interface{} + Col_timeuuid interface{} + Col_tinyint interface{} + Col_uuid interface{} + Col_varchar interface{} + Col_varint interface{} +} + +func TestMapScanWithNullbleValue(t *testing.T) { + timeUUID := TimeUUID() + date := time.Date(2009, time.November, 10, 0, 0, 0, 0, time.UTC) + timestamp := time.Time{}.Add(time.Duration(200)) + + testCases := []struct { + name string + query string + keys []string + values []interface{} + expectations expec + id int64 + }{ + { + name: "with values", + query: `INSERT INTO gocql_test.scan_map_with_nullable_value_table + (Id, Col_ascii, Col_bigint, Col_blob, Col_boolean, Col_date, Col_decimal, Col_double, + Col_duration, Col_float, Col_inet, Col_int, Col_smallint, Col_text, Col_time, Col_timestamp, + Col_timeuuid, Col_tinyint, Col_uuid, Col_varchar, Col_varint) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + keys: []string{"Id", "Col_ascii", "Col_bigint", "Col_blob", "Col_boolean", "Col_date", "Col_decimal", "Col_double", "Col_duration", "Col_float", "Col_inet", "Col_int", "Col_smallint", "Col_text", "Col_time", "Col_timestamp", "Col_timeuuid", "Col_tinyint", "Col_uuid", "Col_varchar", "Col_varint"}, + values: []interface{}{1, "test_ascii", int64(123456789), []byte{0x01, 0x02, 0x03}, true, + date, *inf.NewDec(12345, 0), 123.45, Duration{ + Months: 250, + Days: 500, + Nanoseconds: 300010001, + }, float32(3.14), "127.0.0.1", + 123, int16(1000), "test_text", time.Duration(200), timestamp, timeUUID, + int8(5), timeUUID, "test_varchar", *big.NewInt(99999)}, + expectations: expec{ + Id: 1, + Col_ascii: "test_ascii", + Col_bigint: int64(123456789), + Col_blob: []byte{0x01, 0x02, 0x03}, + Col_boolean: true, + Col_date: date, + Col_decimal: *inf.NewDec(12345, 0), + Col_double: 123.45, + Col_duration: Duration{ + Months: 250, + Days: 500, + Nanoseconds: 300010001, + }, + Col_float: float32(3.14), + Col_inet: "127.0.0.1", + Col_int: 123, + Col_smallint: int16(1000), + Col_text: "test_text", + Col_time: time.Duration(200), + Col_timestamp: timestamp, + Col_timeuuid: timeUUID, + Col_tinyint: int8(5), + Col_uuid: timeUUID, + Col_varchar: "test_varchar", + Col_varint: *big.NewInt(99999), + }, + id: 1, + }, + + { + name: "without values", + query: `INSERT INTO gocql_test.scan_map_with_nullable_value_table (Id) VALUES (?)`, + keys: []string{"Id", "Col_ascii", "Col_bigint", "Col_blob", "Col_boolean", "Col_date", "Col_decimal", "Col_double", "Col_duration", "Col_float", "Col_inet", "Col_int", "Col_smallint", "Col_text", "Col_time", "Col_timestamp", "Col_timeuuid", "Col_tinyint", "Col_uuid", "Col_varchar", "Col_varint"}, + values: []interface{}{2}, + expectations: expec{ + Id: 2, + Col_ascii: nil, + Col_bigint: nil, + Col_blob: nil, + Col_boolean: nil, + Col_date: nil, + Col_decimal: nil, + Col_double: nil, + Col_duration: nil, + Col_float: nil, + Col_inet: nil, + Col_int: nil, + Col_smallint: nil, + Col_text: nil, + Col_time: nil, + Col_timestamp: nil, + Col_timeuuid: nil, + Col_tinyint: nil, + Col_uuid: nil, + Col_varchar: nil, + Col_varint: nil, + }, + id: 2, + }, + } + session := createSession(t) + defer session.Close() + + createTableQuery := ` + CREATE TABLE IF NOT EXISTS gocql_test.scan_map_with_nullable_value_table ( + Id INT PRIMARY KEY, + Col_ascii ASCII, + Col_bigint BIGINT, + Col_blob BLOB, + Col_boolean BOOLEAN, + Col_date DATE, + Col_decimal DECIMAL, + Col_double DOUBLE, + Col_duration DURATION, + Col_float FLOAT, + Col_inet INET, + Col_int INT, + Col_smallint SMALLINT, + Col_text TEXT, + Col_time TIME, + Col_timestamp TIMESTAMP, + Col_timeuuid TIMEUUID, + Col_tinyint TINYINT, + Col_uuid UUID, + Col_varchar VARCHAR, + Col_varint VARINT + ); + ` + + err := session.Query(createTableQuery).Exec() + if err != nil { + t.Fatal("Failed to create the table:", err) + } + + t.Log("Table created successfully!") + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + err = session.Query(testCase.query, testCase.values...).Exec() + if err != nil { + t.Fatal("Failed to execute query:", err) + } + + iter := session.Query(`SELECT * FROM gocql_test.scan_map_with_nullable_value_table WHERE Id = ? LIMIT 1`, testCase.id).Iter() + row := make(map[string]interface{}) + + if !iter.MapScanWithNullableValues(row) { + t.Fatal("select:", iter.Close()) + } + + v := reflect.ValueOf(testCase.expectations) + for _, key := range testCase.keys { + if testCase.id == 1 { + col := row[strings.ToLower(key)] + if !reflect.ValueOf(col).Elem().IsZero() { + got := reflect.ValueOf(col).Elem().Interface() + + require.Equal(t, v.FieldByName(key).Interface(), got, key) + } + } else { + if key != "Id" && !reflect.ValueOf(row[strings.ToLower(key)]).IsZero() { + t.Fatalf("Failed on:%v,\nExpected %v to be %v,\n Got: %v", key, key, v.FieldByName(key).Interface(), row[strings.ToLower(key)]) + } + } + } + }) + } +} + func TestSliceMap(t *testing.T) { session := createSession(t) defer session.Close() diff --git a/common_test.go b/common_test.go index a5edb03c6..7117fc662 100644 --- a/common_test.go +++ b/common_test.go @@ -286,6 +286,14 @@ func assertEqual(t *testing.T, description string, expected, actual interface{}) func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() + rv1 := reflect.ValueOf(expected) + rv2 := reflect.ValueOf(actual) + if rv1.Kind() == reflect.Ptr { + expected = rv1.Elem().Interface() + } + if rv2.Kind() == reflect.Ptr { + actual = rv2.Elem().Interface() + } if !reflect.DeepEqual(expected, actual) { t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) } diff --git a/helpers.go b/helpers.go index f2faee9e0..3e7e7a5dd 100644 --- a/helpers.go +++ b/helpers.go @@ -101,6 +101,66 @@ func goType(t TypeInfo) (reflect.Type, error) { } } +func nullableGoType(t TypeInfo) (reflect.Type, error) { + switch t.Type() { + case TypeVarchar, TypeAscii, TypeInet, TypeText: + return reflect.TypeOf(new(string)), nil + case TypeBigInt, TypeCounter: + return reflect.TypeOf(new(int64)), nil + case TypeTime: + return reflect.TypeOf(new(time.Duration)), nil + case TypeTimestamp: + return reflect.TypeOf(new(time.Time)), nil + case TypeBlob: + return reflect.TypeOf(new([]byte)), nil + case TypeBoolean: + return reflect.TypeOf(new(bool)), nil + case TypeFloat: + return reflect.TypeOf(new(float32)), nil + case TypeDouble: + return reflect.TypeOf(new(float64)), nil + case TypeInt: + return reflect.TypeOf(new(int)), nil + case TypeSmallInt: + return reflect.TypeOf(new(int16)), nil + case TypeTinyInt: + return reflect.TypeOf(new(int8)), nil + case TypeDecimal: + return reflect.TypeOf(*new(*inf.Dec)), nil + case TypeUUID, TypeTimeUUID: + return reflect.TypeOf(new(UUID)), nil + case TypeList, TypeSet: + elemType, err := nullableGoType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil + case TypeMap: + keyType, err := nullableGoType(t.(CollectionType).Key) + if err != nil { + return nil, err + } + valueType, err := nullableGoType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valueType), nil + case TypeVarint: + return reflect.TypeOf(*new(*big.Int)), nil + case TypeTuple: + tuple := t.(TupleTypeInfo) + return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil + case TypeUDT: + return reflect.TypeOf(make(map[string]interface{})), nil + case TypeDate: + return reflect.TypeOf(new(time.Time)), nil + case TypeDuration: + return reflect.TypeOf(new(Duration)), nil + default: + return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) + } +} + func dereference(i interface{}) interface{} { return reflect.Indirect(reflect.ValueOf(i)).Interface() } @@ -323,6 +383,8 @@ func TupleColumnName(c string, n int) string { } func (iter *Iter) RowData() (RowData, error) { + var err error + var val interface{} if iter.err != nil { return RowData{}, iter.err } @@ -332,7 +394,12 @@ func (iter *Iter) RowData() (RowData, error) { for _, column := range iter.Columns() { if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val, err := column.TypeInfo.NewWithError() + if !iter.isNullableScan { + val, err = column.TypeInfo.NewWithError() + } else { + val, err = column.TypeInfo.NewWithNullable() + } + if err != nil { return RowData{}, err } @@ -342,10 +409,11 @@ func (iter *Iter) RowData() (RowData, error) { for i, elem := range c.Elems { columns = append(columns, TupleColumnName(column.Name, i)) val, err := elem.NewWithError() + if err != nil { return RowData{}, err } - values = append(values, val) + values = append(values, &val) } } } @@ -451,6 +519,16 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } +// MapScanWithNullableValues takes a map[string]interface{} and populates it with a row +// that is returned from cassandra. +// +// Each call to MapScanWithNullableValues() must be called with a new map object. +func (iter *Iter) MapScanWithNullableValues(m map[string]interface{}) bool { + iter.setNullableScan(true) + scan := iter.MapScan(m) + return scan +} + func copyBytes(p []byte) []byte { b := make([]byte, len(p)) copy(b, p) diff --git a/marshal.go b/marshal.go index 4d0adb923..9f6b5e142 100644 --- a/marshal.go +++ b/marshal.go @@ -2468,6 +2468,12 @@ type TypeInfo interface { // // If there is no corresponding Go type for the CQL type, NewWithError returns an error. NewWithError() (interface{}, error) + + // NewWithNullable creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // Works similarly to NewWithError, but returns nullable values instead of default go type values. + NewWithNullable() (interface{}, error) } type NativeType struct { @@ -2476,6 +2482,14 @@ type NativeType struct { custom string // only used for TypeCustom } +func (t NativeType) NewWithNullable() (interface{}, error) { + typ, err := nullableGoType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func NewNativeType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } @@ -2523,6 +2537,14 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +func (t CollectionType) NewWithNullable() (interface{}, error) { + typ, err := nullableGoType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { @@ -2557,6 +2579,14 @@ type TupleTypeInfo struct { Elems []TypeInfo } +func (t TupleTypeInfo) NewWithNullable() (interface{}, error) { + typ, err := nullableGoType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (t TupleTypeInfo) String() string { var buf bytes.Buffer buf.WriteString(fmt.Sprintf("%s(", t.typ)) @@ -2596,6 +2626,14 @@ type UDTTypeInfo struct { Elements []UDTField } +func (u UDTTypeInfo) NewWithNullable() (interface{}, error) { + typ, err := nullableGoType(u) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + func (u UDTTypeInfo) NewWithError() (interface{}, error) { typ, err := goType(u) if err != nil { diff --git a/session.go b/session.go index a600b95f3..85d786ccb 100644 --- a/session.go +++ b/session.go @@ -1434,8 +1434,17 @@ type Iter struct { next *nextIter host *HostInfo - framer *framer - closed int32 + framer *framer + closed int32 + isNullableScan bool +} + +func (iter *Iter) getNullableScan() bool { + return iter.isNullableScan +} + +func (iter *Iter) setNullableScan(v bool) { + iter.isNullableScan = v } // Host returns the host which the query was sent to.