diff --git a/cassandra_test.go b/cassandra_test.go index c092184ec..01b9cf23f 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -664,21 +664,21 @@ func TestCAS(t *testing.T) { } failBatch = session.Batch(LoggedBatch) - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == nil { t.Fatal("update should have errored") } // make sure MapScanCAS does not panic when MapScan fails casMap = make(map[string]interface{}) casMap["last_modified"] = false - if _, err := session.Query(`UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`, + if _, err := session.Query(`UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`, modified).MapScanCAS(casMap); err == nil { t.Fatal("update should hvae errored", err) } // make sure MapExecuteBatchCAS does not panic when MapScan fails failBatch = session.Batch(LoggedBatch) - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) casMap = make(map[string]interface{}) casMap["last_modified"] = false if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == nil { diff --git a/frame.go b/frame.go index 5b0c4c38d..2c860704f 100644 --- a/frame.go +++ b/frame.go @@ -32,6 +32,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "strings" "time" ) @@ -913,6 +914,22 @@ func (f *framer) readTypeInfo() TypeInfo { collection.Elem = f.readTypeInfo() return collection + case TypeCustom: + if strings.HasPrefix(simple.custom, "org.apache.cassandra.db.marshal.VectorType") { + spec := strings.TrimPrefix(simple.custom, "org.apache.cassandra.db.marshal.VectorType") + spec = spec[1 : len(spec)-1] // remove parenthesis + idx := strings.LastIndex(spec, ",") + typeStr := spec[:idx] + dimStr := spec[idx+1:] + subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + vector := VectorType{ + NativeType: simple, + SubType: subType, + Dimensions: dim, + } + return vector + } } return simple diff --git a/helpers.go b/helpers.go index 30ae5f3e1..335fbd647 100644 --- a/helpers.go +++ b/helpers.go @@ -25,9 +25,11 @@ package gocql import ( + "encoding/hex" "fmt" "math/big" "reflect" + "strconv" "strings" "time" @@ -39,6 +41,46 @@ type RowData struct { Values []interface{} } +// asVectorType attempts to convert a NativeType(custom) which represents a VectorType +// into a concrete VectorType. It also works recursively (nested vectors). +func asVectorType(t TypeInfo) (VectorType, bool) { + if v, ok := t.(VectorType); ok { + return v, true + } + n, ok := t.(NativeType) + if !ok || n.Type() != TypeCustom { + return VectorType{}, false + } + const prefix = "org.apache.cassandra.db.marshal.VectorType" + if !strings.HasPrefix(n.Custom(), prefix+"(") { + return VectorType{}, false + } + + spec := strings.TrimPrefix(n.Custom(), prefix) + spec = strings.Trim(spec, "()") + // split last comma -> subtype spec , dimensions + idx := strings.LastIndex(spec, ",") + if idx <= 0 { + return VectorType{}, false + } + subStr := strings.TrimSpace(spec[:idx]) + dimStr := strings.TrimSpace(spec[idx+1:]) + dim, err := strconv.Atoi(dimStr) + if err != nil { + return VectorType{}, false + } + subType := getCassandraLongType(subStr, n.Version(), nopLogger{}) + // recurse if subtype itself is still a custom vector + if innerVec, ok := asVectorType(subType); ok { + subType = innerVec + } + return VectorType{ + NativeType: NewCustomType(n.Version(), TypeCustom, prefix), + SubType: subType, + Dimensions: dim, + }, true +} + func goType(t TypeInfo) (reflect.Type, error) { switch t.Type() { case TypeVarchar, TypeAscii, TypeInet, TypeText: @@ -95,6 +137,20 @@ func goType(t TypeInfo) (reflect.Type, error) { return reflect.TypeOf(*new(time.Time)), nil case TypeDuration: return reflect.TypeOf(*new(Duration)), nil + case TypeCustom: + // Handle VectorType encoded as custom + if vec, ok := asVectorType(t); ok { + innerPtr, err := vec.SubType.NewWithError() + if err != nil { + return nil, err + } + elemType := reflect.TypeOf(innerPtr) + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + return reflect.SliceOf(elemType), nil + } + return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) default: return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) } @@ -161,59 +217,168 @@ func getCassandraBaseType(name string) Type { } } -func getCassandraType(name string, logger StdLogger) TypeInfo { +// TODO: Cover with unit tests. +// Parses long Java-style type definition to internal data structures. +func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { + if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.SetType") { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, "org.apache.cassandra.db.marshal.SetType", '('), protoVer, logger), + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.ListType") { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, "org.apache.cassandra.db.marshal.ListType", '('), protoVer, logger), + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.MapType") { + names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.MapType") + if len(names) != 2 { + logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NewNativeType(protoVer, TypeCustom) + } + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraLongType(names[0], protoVer, logger), + Elem: getCassandraLongType(names[1], protoVer, logger), + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.TupleType") { + names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.TupleType") + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraLongType(name, protoVer, logger) + } + + return TupleTypeInfo{ + NativeType: NewNativeType(protoVer, TypeTuple), + Elems: types, + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.UserType") { + names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.UserType") + fields := make([]UDTField, len(names)-2) + + for i := 2; i < len(names); i++ { + spec := strings.Split(names[i], ":") + fieldName, _ := hex.DecodeString(spec[0]) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: getCassandraLongType(spec[1], protoVer, logger), + } + } + + udtName, _ := hex.DecodeString(names[1]) + return UDTTypeInfo{ + NativeType: NewNativeType(protoVer, TypeUDT), + KeySpace: names[0], + Name: string(udtName), + Elements: fields, + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.VectorType") { + names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.VectorType") + subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) + dim, err := strconv.Atoi(strings.TrimSpace(names[1])) + if err != nil { + logger.Printf("gocql: error parsing vector dimensions: %v\n", err) + return NewNativeType(protoVer, TypeCustom) + } + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, "org.apache.cassandra.db.marshal.VectorType"), + SubType: subType, + Dimensions: dim, + } + } else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.FrozenType") { + names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.FrozenType") + return getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) + } else { + // basic type + return NativeType{ + proto: protoVer, + typ: getApacheCassandraType(name), + } + } +} + +// Parses short CQL type representation (e.g. map) to internal data structures. +func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) + return getCassandraType(unwrapCompositeTypeDefinition(name, "frozen", '<'), protoVer, logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "set", '<'), protoVer, logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "list", '<'), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := splitCQLCompositeTypes(name, "map") if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - typ: TypeCustom, - } + return NewNativeType(protoVer, TypeCustom) } return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraType(names[0], protoVer, logger), + Elem: getCassandraType(names[1], protoVer, logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := splitCQLCompositeTypes(name, "tuple") types := make([]TypeInfo, len(names)) for i, name := range names { - types[i] = getCassandraType(name, logger) + types[i] = getCassandraType(name, protoVer, logger) } return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NewNativeType(protoVer, TypeTuple), Elems: types, } + } else if strings.HasPrefix(name, "vector<") { + names := splitCQLCompositeTypes(name, "vector") + subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, "org.apache.cassandra.db.marshal.VectorType"), + SubType: subType, + Dimensions: dim, + } } else { return NativeType{ - typ: getCassandraBaseType(name), + proto: protoVer, + typ: getCassandraBaseType(name), } } } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") +func splitCQLCompositeTypes(name string, typeName string) []string { + return splitCompositeTypes(name, typeName, '<', '>') +} + +func splitJavaCompositeTypes(name string, typeName string) []string { + return splitCompositeTypes(name, typeName, '(', ')') +} + +func unwrapCompositeTypeDefinition(name string, typeName string, typeOpen int32) string { + return strings.TrimPrefix(name[:len(name)-1], typeName+string(typeOpen)) +} + +func splitCompositeTypes(name string, typeName string, typeOpen int32, typeClose int32) []string { + def := unwrapCompositeTypeDefinition(name, typeName, typeOpen) + if !strings.Contains(def, string(typeOpen)) { + parts := strings.Split(def, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts } var parts []string lessCount := 0 segment := "" - for _, char := range name { + for _, char := range def { if char == ',' && lessCount == 0 { if segment != "" { parts = append(parts, strings.TrimSpace(segment)) @@ -222,9 +387,9 @@ func splitCompositeTypes(name string) []string { continue } segment += string(char) - if char == '<' { + if char == typeOpen { lessCount++ - } else if char == '>' { + } else if char == typeClose { lessCount-- } } @@ -282,6 +447,10 @@ func getApacheCassandraType(class string) Type { return TypeTuple case "DurationType": return TypeDuration + case "SimpleDateType": + return TypeDate + case "UserType": + return TypeUDT default: return TypeCustom } diff --git a/helpers_test.go b/helpers_test.go index b17164bb4..eaf8fe731 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -35,7 +35,7 @@ import ( func TestGetCassandraType_Set(t *testing.T) { t.Parallel() - typ := getCassandraType("set", &defaultLogger{}) + typ := getCassandraType("set", protoVersion4, &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -230,11 +230,68 @@ func TestGetCassandraType(t *testing.T) { Elem: NativeType{typ: TypeDuration}, }, }, + { + "vector", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType", + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + }, + { + "vector, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType", + }, + SubType: VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType", + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + Dimensions: 5, + }, + }, + { + "vector, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType", + }, + SubType: CollectionType{ + NativeType: NativeType{typ: TypeMap}, + Key: NativeType{typ: TypeUUID}, + Elem: NativeType{typ: TypeTimestamp}, + }, + Dimensions: 5, + }, + }, + { + "vector>, 100>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType", + }, + SubType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeFloat}, + }, + }, + Dimensions: 100, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, &defaultLogger{}) + got := getCassandraType(test.input, 0, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/integration_serialization_scylla_test.go b/integration_serialization_scylla_test.go index dd863abc1..aea328db2 100644 --- a/integration_serialization_scylla_test.go +++ b/integration_serialization_scylla_test.go @@ -8,7 +8,9 @@ import ( "fmt" "math/big" "reflect" + "strings" "testing" + "time" "unsafe" "gopkg.in/inf.v0" @@ -67,7 +69,7 @@ func checkTypeMarshal(t *testing.T, tc valcases.SimpleTypeCases) { cqlName := tc.CQLName t.Run(cqlName, func(t *testing.T) { tp := Type(tc.CQLType) - cqlType := NewNativeType(4, tp, "") + cqlType := NewNativeType(4, tp) for _, valCase := range tc.Cases { for _, langCase := range valCase.LangCases { @@ -89,7 +91,7 @@ func checkTypeUnmarshal(t *testing.T, tc valcases.SimpleTypeCases) { cqlName := tc.CQLName t.Run(cqlName, func(t *testing.T) { tp := Type(tc.CQLType) - cqlType := NewNativeType(4, tp, "") + cqlType := NewNativeType(4, tp) for _, valCase := range tc.Cases { for _, langCase := range valCase.LangCases { @@ -115,7 +117,7 @@ func checkTypeInsertSelect(t *testing.T, session *Session, insertStmt, selectStm cqlName := tc.CQLName t.Run(cqlName, func(t *testing.T) { tp := Type(tc.CQLType) - cqlType := NewNativeType(4, tp, "") + cqlType := NewNativeType(4, tp) for _, valCase := range tc.Cases { valCaseName := valCase.Name @@ -217,3 +219,692 @@ func equalVals(in1, in2 interface{}) bool { return reflect.DeepEqual(in1, in2) } } + +// SliceMapTypesTestCase defines a test case for validating SliceMap/MapScan behavior +type SliceMapTypesTestCase struct { + CQLType string + CQLValue string // Non-NULL value to insert + ExpectedValue interface{} // Expected value for non-NULL case + ExpectedNullValue interface{} // Expected value for NULL +} + +// compareCollectionValues compares collection values (lists, sets, maps) with special handling +func compareCollectionValues(t *testing.T, cqlType string, expected, actual interface{}) bool { + switch { + case strings.HasPrefix(cqlType, "set<"): + // Sets are returned as slices, but order is not guaranteed + expectedSlice := reflect.ValueOf(expected) + actualSlice := reflect.ValueOf(actual) + if expectedSlice.Kind() != reflect.Slice || actualSlice.Kind() != reflect.Slice { + return false + } + if expectedSlice.Len() != actualSlice.Len() { + return false + } + + // Convert to maps for unordered comparison + expectedSet := make(map[interface{}]bool) + for i := 0; i < expectedSlice.Len(); i++ { + expectedSet[expectedSlice.Index(i).Interface()] = true + } + + actualSet := make(map[interface{}]bool) + for i := 0; i < actualSlice.Len(); i++ { + actualSet[actualSlice.Index(i).Interface()] = true + } + + return reflect.DeepEqual(expectedSet, actualSet) + + default: + // For lists, maps, and other collections, reflect.DeepEqual works fine + return reflect.DeepEqual(expected, actual) + } +} + +// compareValues compares expected and actual values with type-specific logic +func compareValues(t *testing.T, cqlType string, expected, actual interface{}) bool { + switch cqlType { + case "varint": + // big.Int needs Cmp() for proper comparison, but handle nil pointers safely + if expectedBig, ok := expected.(*big.Int); ok { + if actualBig, ok := actual.(*big.Int); ok { + // Handle nil cases + if expectedBig == nil && actualBig == nil { + return true + } + if expectedBig == nil || actualBig == nil { + return false + } + return expectedBig.Cmp(actualBig) == 0 + } + } + return reflect.DeepEqual(expected, actual) + + case "decimal": + // inf.Dec needs Cmp() for proper comparison, but handle nil pointers safely + if expectedDec, ok := expected.(*inf.Dec); ok { + if actualDec, ok := actual.(*inf.Dec); ok { + // Handle nil cases + if expectedDec == nil && actualDec == nil { + return true + } + if expectedDec == nil || actualDec == nil { + return false + } + return expectedDec.Cmp(actualDec) == 0 + } + } + return reflect.DeepEqual(expected, actual) + + default: + // reflect.DeepEqual handles nil vs empty slice/map distinction correctly for all types + // including inet (net.IP), blob ([]byte), collections ([]T, map[K]V), etc. + // This is critical for catching zero value behavior changes in the driver + return reflect.DeepEqual(expected, actual) + } +} + +// TestSliceMapMapScanTypes tests SliceMap and MapScan with various CQL types +func TestSliceMapMapScanTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + tableCQL := ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_test ( + id int PRIMARY KEY, + tinyint_col tinyint, + smallint_col smallint, + int_col int, + bigint_col bigint, + float_col float, + double_col double, + boolean_col boolean, + text_col text, + ascii_col ascii, + varchar_col varchar, + timestamp_col timestamp, + uuid_col uuid, + timeuuid_col timeuuid, + inet_col inet, + blob_col blob, + varint_col varint, + decimal_col decimal, + date_col date, + time_col time, + duration_col duration + )` + + if err := createTable(session, tableCQL); err != nil { + t.Fatal("Failed to create test table:", err) + } + + if err := session.Query("TRUNCATE gocql_test.slicemap_test").Exec(); err != nil { + t.Fatal("Failed to truncate test table:", err) + } + + testCases := []SliceMapTypesTestCase{ + {"tinyint", "42", int8(42), int8(0)}, + {"smallint", "1234", int16(1234), int16(0)}, + {"int", "123456", int(123456), int(0)}, + {"bigint", "1234567890", int64(1234567890), int64(0)}, + {"float", "3.14", float32(3.14), float32(0)}, + {"double", "2.718281828", float64(2.718281828), float64(0)}, + {"boolean", "true", true, false}, + {"text", "'hello world'", "hello world", ""}, + {"ascii", "'hello ascii'", "hello ascii", ""}, + {"varchar", "'hello varchar'", "hello varchar", ""}, + {"timestamp", "1388534400000", time.Unix(1388534400, 0).UTC(), time.Time{}}, + {"uuid", "550e8400-e29b-41d4-a716-446655440000", mustParseUUID("550e8400-e29b-41d4-a716-446655440000"), UUID{}}, + {"timeuuid", "60d79c23-5793-11f0-8afe-bcfce78b517a", mustParseUUID("60d79c23-5793-11f0-8afe-bcfce78b517a"), UUID{}}, + {"inet", "'127.0.0.1'", "127.0.0.1", ""}, + {"blob", "0x48656c6c6f", []byte("Hello"), []byte(nil)}, + {"varint", "123456789012345678901234567890", mustParseBigInt("123456789012345678901234567890"), (*big.Int)(nil)}, + {"decimal", "123.45", mustParseDecimal("123.45"), (*inf.Dec)(nil)}, + {"date", "'2015-05-03'", time.Date(2015, 5, 3, 0, 0, 0, 0, time.UTC), time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC)}, + {"time", "'13:30:54.234'", 13*time.Hour + 30*time.Minute + 54*time.Second + 234*time.Millisecond, time.Duration(0)}, + {"duration", "1y2mo3d4h5m6s789ms", mustCreateDuration(14, 3, 4*time.Hour+5*time.Minute+6*time.Second+789*time.Millisecond), Duration{}}, + } + + for i, tc := range testCases { + t.Run(tc.CQLType, func(t *testing.T) { + testSliceMapMapScanSimple(t, session, tc, i) + }) + } +} + +// Simplified test function that tests both SliceMap and MapScan with both NULL and non-NULL values +func testSliceMapMapScanSimple(t *testing.T, session *Session, tc SliceMapTypesTestCase, id int) { + colName := tc.CQLType + "_col" + + t.Run("NonNull", func(t *testing.T) { + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_test (id, %s) VALUES (?, %s)", colName, tc.CQLValue) + if err := session.Query(insertQuery, id*2).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + result := queryAndExtractValue(t, session, colName, id*2, method) + validateResult(t, tc.CQLType, tc.ExpectedValue, result, method, "non-NULL") + }) + } + }) + + t.Run("Null", func(t *testing.T) { + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_test (id, %s) VALUES (?, NULL)", colName) + if err := session.Query(insertQuery, id*2+1).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + result := queryAndExtractValue(t, session, colName, id*2+1, method) + validateResult(t, tc.CQLType, tc.ExpectedNullValue, result, method, "NULL") + }) + } + }) +} + +func queryAndExtractValue(t *testing.T, session *Session, colName string, id int, method string) interface{} { + fmt.Println("queryAndExtractValue") + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_test WHERE id = ?", colName) + + switch method { + case "SliceMap": + iter := session.Query(selectQuery, id).Iter() + sliceResults, err := iter.SliceMap() + fmt.Println("Slice results: ", sliceResults[0][colName]) + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + return sliceResults[0][colName] + + case "MapScan": + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, id).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + return mapResult[colName] + + default: + t.Fatalf("Unknown method: %s", method) + return nil + } +} + +func validateResult(t *testing.T, cqlType string, expected, actual interface{}, method, valueType string) { + if expected != nil && actual != nil { + expectedType := reflect.TypeOf(expected) + actualType := reflect.TypeOf(actual) + if expectedType != actualType { + t.Errorf("%s %s %s: expected type %v, got %v", method, valueType, cqlType, expectedType, actualType) + } + } + + if !compareValues(t, cqlType, expected, actual) { + t.Errorf("%s %s %s: expected value %v (type %T), got %v (type %T)", + method, valueType, cqlType, expected, expected, actual, actual) + } +} + +func mustParseUUID(s string) UUID { + u, err := ParseUUID(s) + if err != nil { + panic(err) + } + return u +} + +func mustParseBigInt(s string) *big.Int { + i := new(big.Int) + if _, ok := i.SetString(s, 10); !ok { + panic("failed to parse big.Int: " + s) + } + return i +} + +func mustParseDecimal(s string) *inf.Dec { + dec := new(inf.Dec) + if _, ok := dec.SetString(s); !ok { + panic("failed to parse inf.Dec: " + s) + } + return dec +} + +func mustCreateDuration(months int32, days int32, timeDuration time.Duration) Duration { + return Duration{ + Months: months, + Days: days, + Nanoseconds: timeDuration.Nanoseconds(), + } +} + +// TestSliceMapMapScanCounterTypes tests counter types separately since they have special restrictions +// (counter columns can't be mixed with other column types in the same table) +func TestSliceMapMapScanCounterTypes(t *testing.T) { + session := createSessionFromClusterTabletsDisabled(createCluster(), t) + defer session.Close() + + // Create separate table for counter types + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test_tablets_disabled.slicemap_counter_test ( + id int PRIMARY KEY, + counter_col counter + ) + `); err != nil { + t.Fatal("Failed to create counter test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test_tablets_disabled.slicemap_counter_test").Exec(); err != nil { + t.Fatal("Failed to truncate counter test table:", err) + } + + testID := 1 + expectedValue := int64(42) + + // Increment counter (can't INSERT into counter, must UPDATE) + err := session.Query("UPDATE gocql_test_tablets_disabled.slicemap_counter_test SET counter_col = counter_col + 42 WHERE id = ?", testID).Exec() + if err != nil { + t.Fatalf("Failed to increment counter: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := "SELECT counter_col FROM gocql_test_tablets_disabled.slicemap_counter_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0]["counter_col"] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult["counter_col"] + } + + validateResult(t, "counter", expectedValue, result, method, "incremented") + }) + } +} + +// TestSliceMapMapScanTupleTypes tests tuple types separately since they have special handling +// (tuple elements get split into individual columns) +func TestSliceMapMapScanTupleTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create test table with tuple column + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_tuple_test ( + id int PRIMARY KEY, + tuple_col tuple + ) + `); err != nil { + t.Fatal("Failed to create tuple test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_tuple_test").Exec(); err != nil { + t.Fatal("Failed to truncate tuple test table:", err) + } + + // Test non-NULL tuple + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert tuple value + err := session.Query("INSERT INTO gocql_test.slicemap_tuple_test (id, tuple_col) VALUES (?, (42, 'hello'))", testID).Exec() + if err != nil { + t.Fatalf("Failed to insert tuple value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result map[string]interface{} + + selectQuery := "SELECT tuple_col FROM gocql_test.slicemap_tuple_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0] + } else { + result = make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(result); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + } + + // Check tuple elements (tuples get split into individual columns) + elem0Key := TupleColumnName("tuple_col", 0) + elem1Key := TupleColumnName("tuple_col", 1) + + if result[elem0Key] != 42 { + t.Errorf("%s tuple[0]: expected 42, got %v", method, result[elem0Key]) + } + if result[elem1Key] != "hello" { + t.Errorf("%s tuple[1]: expected 'hello', got %v", method, result[elem1Key]) + } + }) + } + }) + + // Test NULL tuple + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL tuple + err := session.Query("INSERT INTO gocql_test.slicemap_tuple_test (id, tuple_col) VALUES (?, NULL)", testID).Exec() + if err != nil { + t.Fatalf("Failed to insert NULL tuple: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result map[string]interface{} + + selectQuery := "SELECT tuple_col FROM gocql_test.slicemap_tuple_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0] + } else { + result = make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(result); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + } + + // Check tuple elements (NULL tuple gives zero values) + elem0Key := TupleColumnName("tuple_col", 0) + elem1Key := TupleColumnName("tuple_col", 1) + + if result[elem0Key] != 0 { + t.Errorf("%s NULL tuple[0]: expected 0, got %v", method, result[elem0Key]) + } + if result[elem1Key] != "" { + t.Errorf("%s NULL tuple[1]: expected '', got %v", method, result[elem1Key]) + } + }) + } + }) +} + +// TestSliceMapMapScanVectorTypes tests vector types separately since they need Cassandra 5.0+ and special table setup +// (vectors need separate tables and version checks) +func TestSliceMapMapScanVectorTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + // Create test table with vector columns + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_vector_test ( + id int PRIMARY KEY, + vector_float_col vector, + vector_text_col vector + ) + `); err != nil { + t.Fatal("Failed to create vector test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_vector_test").Exec(); err != nil { + t.Fatal("Failed to truncate vector test table:", err) + } + + testCases := []struct { + colName string + cqlValue string + expectedValue interface{} + expectedNull interface{} + }{ + {"vector_float_col", "[1.0, 2.5, -3.0]", []float32{1.0, 2.5, -3.0}, []float32(nil)}, + {"vector_text_col", "['hello', 'world']", []string{"hello", "world"}, []string(nil)}, + } + + for _, tc := range testCases { + t.Run(tc.colName, func(t *testing.T) { + // Test non-NULL value + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert non-NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_vector_test (id, %s) VALUES (?, %s)", tc.colName, tc.cqlValue) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_vector_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + validateResult(t, tc.colName, tc.expectedValue, result, method, "non-NULL") + }) + } + }) + + // Test NULL value + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_vector_test (id, %s) VALUES (?, NULL)", tc.colName) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_vector_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // Vectors should return nil slices for NULL values for consistency + validateResult(t, tc.colName, tc.expectedNull, result, method, "NULL") + }) + } + }) + }) + } +} + +// TestSliceMapMapScanCollectionTypes tests collection types separately since they have special handling +// (collections should return nil slices/maps for NULL values for consistency with other slice-based types) +func TestSliceMapMapScanCollectionTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create test table with collection columns + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_collection_test ( + id int PRIMARY KEY, + list_col list, + set_col set, + map_col map + ) + `); err != nil { + t.Fatal("Failed to create collection test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_collection_test").Exec(); err != nil { + t.Fatal("Failed to truncate collection test table:", err) + } + + testCases := []struct { + colName string + cqlValue string + expectedValue interface{} + expectedNull interface{} + }{ + {"list_col", "['a', 'b', 'c']", []string{"a", "b", "c"}, []string(nil)}, + {"set_col", "{1, 2, 3}", []int{1, 2, 3}, []int(nil)}, + {"map_col", "{'key1': 1, 'key2': 2}", map[string]int{"key1": 1, "key2": 2}, map[string]int(nil)}, + } + + for _, tc := range testCases { + t.Run(tc.colName, func(t *testing.T) { + // Test non-NULL value + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert non-NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_collection_test (id, %s) VALUES (?, %s)", tc.colName, tc.cqlValue) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_collection_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // For sets, we need special comparison since order is not guaranteed + if strings.HasPrefix(tc.colName, "set_") { + if !compareCollectionValues(t, tc.colName, tc.expectedValue, result) { + t.Errorf("%s non-NULL %s: expected %v, got %v", method, tc.colName, tc.expectedValue, result) + } + } else { + validateResult(t, tc.colName, tc.expectedValue, result, method, "non-NULL") + } + }) + } + }) + + // Test NULL value + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_collection_test (id, %s) VALUES (?, NULL)", tc.colName) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_collection_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // Collections should return nil slices/maps for NULL values for consistency + validateResult(t, tc.colName, tc.expectedNull, result, method, "NULL") + }) + } + }) + }) + } +} diff --git a/internal/tests/rand.go b/internal/tests/rand.go index e5382ebc0..efb1760dc 100644 --- a/internal/tests/rand.go +++ b/internal/tests/rand.go @@ -3,6 +3,7 @@ package tests import ( "math/rand" "sync" + "time" ) // RandInterface defines the thread-safe random number generator interface. @@ -162,3 +163,15 @@ func (r *ThreadSafeRand) Read(p []byte) (n int, err error) { defer r.mux.Unlock() return r.r.Read(p) } + +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +func RandomText(size int) string { + result := make([]byte, size) + for i := range result { + result[i] = randCharset[rand.Intn(len(randCharset))] + } + return string(result) +} diff --git a/marshal.go b/marshal.go index 860ddd363..dd8d135c7 100644 --- a/marshal.go +++ b/marshal.go @@ -29,6 +29,7 @@ import ( "errors" "fmt" "math" + "math/bits" "reflect" "unsafe" @@ -228,6 +229,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return marshalDate(value) case TypeDuration: return marshalDuration(value) + case TypeCustom: + if vector, ok := info.(VectorType); ok { + return marshalVector(vector, value) + } } // TODO(tux21b): add the remaining types @@ -335,6 +340,10 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return unmarshalDate(data, value) case TypeDuration: return unmarshalDuration(data, value) + case TypeCustom: + if vector, ok := info.(VectorType); ok { + return unmarshalVector(vector, data, value) + } } // TODO(tux21b): add the remaining types @@ -522,16 +531,6 @@ func marshalVarint(value interface{}) ([]byte, error) { return data, nil } -func decBigInt(data []byte) int64 { - if len(data) != 8 { - return 0 - } - return int64(data[0])<<56 | int64(data[1])<<48 | - int64(data[2])<<40 | int64(data[3])<<32 | - int64(data[4])<<24 | int64(data[5])<<16 | - int64(data[6])<<8 | int64(data[7]) -} - func marshalBool(value interface{}) ([]byte, error) { data, err := boolean.Marshal(value) if err != nil { @@ -815,6 +814,177 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +func marshalVector(info VectorType, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + if n != info.Dimensions { + return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n) + } + + isLengthType := isVectorVariableLengthType(info.SubType) + for i := 0; i < n; i++ { + item, err := Marshal(info.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + if isLengthType { + writeUnsignedVInt(buf, uint64(len(item))) + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array.", value, info) +} + +func unmarshalVector(info VectorType, data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + if t.Kind() == reflect.Interface { + if t.NumMethod() != 0 { + return unmarshalErrorf("can not unmarshal into non-empty interface %T", value) + } + t = reflect.TypeOf(info.Zero()) + } + + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != info.Dimensions { + return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions) + } + } else { + rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) + if rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + } + elemSize := len(data) / info.Dimensions + isLengthType := isVectorVariableLengthType(info.SubType) + for i := 0; i < info.Dimensions; i++ { + offset := 0 + if isLengthType { + m, p, err := readUnsignedVInt(data) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *slice, *array, *interface{}.", info, value) +} + +// isVectorVariableLengthType determines if a type requires explicit length serialization within a vector. +// Variable-length types need their length encoded before the actual data to allow proper deserialization. +// Fixed-length types, on the other hand, don't require this kind of length prefix. +func isVectorVariableLengthType(elemType TypeInfo) bool { + switch elemType.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText, + TypeCounter, + TypeDuration, TypeDate, TypeTime, + TypeDecimal, TypeSmallInt, TypeTinyInt, TypeVarint, + TypeInet, + TypeList, TypeSet, TypeMap, TypeUDT, TypeTuple: + return true + case TypeCustom: + if vecType, ok := elemType.(VectorType); ok { + return isVectorVariableLengthType(vecType.SubType) + } + return true + } + return false +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVInt(data []byte) (uint64, int, error) { + if len(data) <= 0 { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[0] + if firstByte&0x80 == 0 { + return uint64(firstByte), 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", numBytes+1, len(data)) + } + for i := 0; i < numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} + func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { @@ -1499,7 +1669,11 @@ type NativeType struct { proto byte } -func NewNativeType(proto byte, typ Type, custom string) NativeType { +func NewNativeType(proto byte, typ Type) NativeType { + return NativeType{proto: proto, typ: typ, custom: ""} +} + +func NewCustomType(proto byte, typ Type, custom string) NativeType { return NativeType{proto: proto, typ: typ, custom: custom} } @@ -1548,6 +1722,21 @@ type CollectionType struct { NativeType } +type VectorType struct { + SubType TypeInfo + NativeType + Dimensions int +} + +// Zero returns the zero value for the vector CQL type. +func (v VectorType) Zero() interface{} { + t, e := v.SubType.NewWithError() + if e != nil { + return nil + } + return reflect.Zero(reflect.SliceOf(reflect.TypeOf(t))).Interface() +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { diff --git a/marshal_test.go b/marshal_test.go index 036418074..f42792d8f 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -30,6 +30,7 @@ package gocql import ( "bytes" "encoding/binary" + "fmt" "math" "math/big" "net" @@ -426,23 +427,22 @@ func TestMarshalTuple(t *testing.T) { }, } - for _, tc := range testCases { + for i, tc := range testCases { t.Run(tc.name, func(t *testing.T) { data, err := Marshal(info, tc.value) if err != nil { - t.Errorf("marshalTest: %v", err) + t.Errorf("marshalTest[%d]: %v", i, err) return } - if !bytes.Equal(data, tc.expected) { - t.Errorf("marshalTest: expected %x (%v), got %x (%v)", - tc.expected, decBigInt(tc.expected), data, decBigInt(data)) + t.Errorf("marshalTest[%d]: expected %x, got %x", + i, tc.expected, data) return } err = Unmarshal(info, data, tc.checkValue) if err != nil { - t.Errorf("unmarshalTest: %v", err) + t.Errorf("unmarshalTest[%d]: %v", i, err) return } @@ -791,6 +791,37 @@ func TestReadCollectionSize(t *testing.T) { } } +func TestReadUnsignedVInt(t *testing.T) { + tests := []struct { + decodedInt uint64 + encodedVint []byte + }{ + { + decodedInt: 0, + encodedVint: []byte{0}, + }, + { + decodedInt: 100, + encodedVint: []byte{100}, + }, + { + decodedInt: 256000, + encodedVint: []byte{195, 232, 0}, + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%d", test.decodedInt), func(t *testing.T) { + actual, _, err := readUnsignedVInt(test.encodedVint) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if actual != test.decodedInt { + t.Fatalf("Expected %d, but got %d", test.decodedInt, actual) + } + }) + } +} + func BenchmarkUnmarshalUUID(b *testing.B) { b.ReportAllocs() src := make([]byte, 16) diff --git a/tests/serialization/marshal_0_unset_test.go b/tests/serialization/marshal_0_unset_test.go index 4cdb33472..4cc0e099d 100644 --- a/tests/serialization/marshal_0_unset_test.go +++ b/tests/serialization/marshal_0_unset_test.go @@ -18,37 +18,37 @@ func TestMarshalUnsetColumn(t *testing.T) { err bool } - elem := gocql.NewNativeType(3, gocql.TypeSmallInt, "") + elem := gocql.NewNativeType(3, gocql.TypeSmallInt) cases := []tCase{ - {gocql.NewNativeType(4, gocql.TypeBoolean, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeTinyInt, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeSmallInt, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeInt, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeBigInt, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeCounter, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeVarint, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeFloat, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeDouble, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeDecimal, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeVarchar, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeText, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeBlob, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeAscii, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeUUID, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeTimeUUID, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeInet, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeTime, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeTimestamp, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeDate, ""), true, false}, - {gocql.NewNativeType(4, gocql.TypeDuration, ""), true, false}, + {gocql.NewNativeType(4, gocql.TypeBoolean), true, false}, + {gocql.NewNativeType(4, gocql.TypeTinyInt), true, false}, + {gocql.NewNativeType(4, gocql.TypeSmallInt), true, false}, + {gocql.NewNativeType(4, gocql.TypeInt), true, false}, + {gocql.NewNativeType(4, gocql.TypeBigInt), true, false}, + {gocql.NewNativeType(4, gocql.TypeCounter), true, false}, + {gocql.NewNativeType(4, gocql.TypeVarint), true, false}, + {gocql.NewNativeType(4, gocql.TypeFloat), true, false}, + {gocql.NewNativeType(4, gocql.TypeDouble), true, false}, + {gocql.NewNativeType(4, gocql.TypeDecimal), true, false}, + {gocql.NewNativeType(4, gocql.TypeVarchar), true, false}, + {gocql.NewNativeType(4, gocql.TypeText), true, false}, + {gocql.NewNativeType(4, gocql.TypeBlob), true, false}, + {gocql.NewNativeType(4, gocql.TypeAscii), true, false}, + {gocql.NewNativeType(4, gocql.TypeUUID), true, false}, + {gocql.NewNativeType(4, gocql.TypeTimeUUID), true, false}, + {gocql.NewNativeType(4, gocql.TypeInet), true, false}, + {gocql.NewNativeType(4, gocql.TypeTime), true, false}, + {gocql.NewNativeType(4, gocql.TypeTimestamp), true, false}, + {gocql.NewNativeType(4, gocql.TypeDate), true, false}, + {gocql.NewNativeType(4, gocql.TypeDuration), true, false}, - {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList, ""), nil, elem), true, false}, - {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet, ""), nil, elem), true, false}, + {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList), nil, elem), true, false}, + {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet), nil, elem), true, false}, - {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap, ""), elem, elem), true, false}, + {gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap), elem, elem), true, false}, {gocql.NewUDTType(3, "udt1", "", gocql.UDTField{Name: "1", Type: elem}), true, true}, - {gocql.NewTupleType(gocql.NewNativeType(3, gocql.TypeTuple, ""), elem), true, true}, + {gocql.NewTupleType(gocql.NewNativeType(3, gocql.TypeTuple), elem), true, true}, } for _, expected := range cases { diff --git a/tests/serialization/marshal_10_decimal_corrupt_test.go b/tests/serialization/marshal_10_decimal_corrupt_test.go index ac8d577cc..8ed1555fb 100644 --- a/tests/serialization/marshal_10_decimal_corrupt_test.go +++ b/tests/serialization/marshal_10_decimal_corrupt_test.go @@ -18,7 +18,7 @@ func TestMarshalDecimalCorrupt(t *testing.T) { name string } - tType := gocql.NewNativeType(4, gocql.TypeDecimal, "") + tType := gocql.NewNativeType(4, gocql.TypeDecimal) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_10_decimal_test.go b/tests/serialization/marshal_10_decimal_test.go index 3a4e3027c..b5877790b 100644 --- a/tests/serialization/marshal_10_decimal_test.go +++ b/tests/serialization/marshal_10_decimal_test.go @@ -20,7 +20,7 @@ import ( func TestMarshalDecimal(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeDecimal, "") + tType := gocql.NewNativeType(4, gocql.TypeDecimal) type testSuite struct { name string diff --git a/tests/serialization/marshal_11_texts_test.go b/tests/serialization/marshal_11_texts_test.go index 527ef2175..31273d46a 100644 --- a/tests/serialization/marshal_11_texts_test.go +++ b/tests/serialization/marshal_11_texts_test.go @@ -42,28 +42,28 @@ func TestMarshalTexts(t *testing.T) { { name: "glob.varchar", marshal: func(i interface{}) ([]byte, error) { - return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeVarchar, ""), i) + return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeVarchar), i) }, unmarshal: func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeVarchar, ""), bytes, i) + return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeVarchar), bytes, i) }, }, { name: "glob.text", marshal: func(i interface{}) ([]byte, error) { - return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeText, ""), i) + return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeText), i) }, unmarshal: func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeText, ""), bytes, i) + return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeText), bytes, i) }, }, { name: "glob.blob", marshal: func(i interface{}) ([]byte, error) { - return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeBlob, ""), i) + return gocql.Marshal(gocql.NewNativeType(4, gocql.TypeBlob), i) }, unmarshal: func(bytes []byte, i interface{}) error { - return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeBlob, ""), bytes, i) + return gocql.Unmarshal(gocql.NewNativeType(4, gocql.TypeBlob), bytes, i) }, }, } diff --git a/tests/serialization/marshal_12_ascii_corrupt_test.go b/tests/serialization/marshal_12_ascii_corrupt_test.go index dbda677ab..9fb4831f5 100644 --- a/tests/serialization/marshal_12_ascii_corrupt_test.go +++ b/tests/serialization/marshal_12_ascii_corrupt_test.go @@ -15,7 +15,7 @@ import ( func TestMarshalAsciiMustFail(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeAscii, "") + tType := gocql.NewNativeType(4, gocql.TypeAscii) type testSuite struct { name string diff --git a/tests/serialization/marshal_12_ascii_test.go b/tests/serialization/marshal_12_ascii_test.go index d4e593f7e..87175dfea 100644 --- a/tests/serialization/marshal_12_ascii_test.go +++ b/tests/serialization/marshal_12_ascii_test.go @@ -13,7 +13,7 @@ import ( ) func TestMarshalAscii(t *testing.T) { - tType := gocql.NewNativeType(4, gocql.TypeAscii, "") + tType := gocql.NewNativeType(4, gocql.TypeAscii) type testSuite struct { name string diff --git a/tests/serialization/marshal_13_uuids_corrupt_test.go b/tests/serialization/marshal_13_uuids_corrupt_test.go index d57d0f6d3..9b2c43424 100644 --- a/tests/serialization/marshal_13_uuids_corrupt_test.go +++ b/tests/serialization/marshal_13_uuids_corrupt_test.go @@ -17,8 +17,8 @@ func TestMarshalUUIDsMustFail(t *testing.T) { t.Parallel() tTypes := []gocql.NativeType{ - gocql.NewNativeType(4, gocql.TypeUUID, ""), - gocql.NewNativeType(4, gocql.TypeTimeUUID, ""), + gocql.NewNativeType(4, gocql.TypeUUID), + gocql.NewNativeType(4, gocql.TypeTimeUUID), } type testSuite struct { diff --git a/tests/serialization/marshal_13_uuids_test.go b/tests/serialization/marshal_13_uuids_test.go index 06cd9bac5..ebd169c08 100644 --- a/tests/serialization/marshal_13_uuids_test.go +++ b/tests/serialization/marshal_13_uuids_test.go @@ -17,8 +17,8 @@ func TestMarshalUUIDs(t *testing.T) { t.Parallel() tTypes := []gocql.NativeType{ - gocql.NewNativeType(4, gocql.TypeUUID, ""), - gocql.NewNativeType(4, gocql.TypeTimeUUID, ""), + gocql.NewNativeType(4, gocql.TypeUUID), + gocql.NewNativeType(4, gocql.TypeTimeUUID), } type testSuite struct { @@ -124,7 +124,7 @@ func TestMarshalUUIDs(t *testing.T) { func TestMarshalTimeUUID(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeTimeUUID, "") + tType := gocql.NewNativeType(4, gocql.TypeTimeUUID) type testSuite struct { name string diff --git a/tests/serialization/marshal_14_inet_corrupt_test.go b/tests/serialization/marshal_14_inet_corrupt_test.go index 2a418cfb5..776591ae3 100644 --- a/tests/serialization/marshal_14_inet_corrupt_test.go +++ b/tests/serialization/marshal_14_inet_corrupt_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalsInetMustFail(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeInet, "") + tType := gocql.NewNativeType(4, gocql.TypeInet) type testSuite struct { name string diff --git a/tests/serialization/marshal_14_inet_test.go b/tests/serialization/marshal_14_inet_test.go index 5ad15b462..c0389d92e 100644 --- a/tests/serialization/marshal_14_inet_test.go +++ b/tests/serialization/marshal_14_inet_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalsInet(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeInet, "") + tType := gocql.NewNativeType(4, gocql.TypeInet) type testSuite struct { name string diff --git a/tests/serialization/marshal_15_time_corrupt_test.go b/tests/serialization/marshal_15_time_corrupt_test.go index 57dba214f..edd21cbae 100644 --- a/tests/serialization/marshal_15_time_corrupt_test.go +++ b/tests/serialization/marshal_15_time_corrupt_test.go @@ -17,7 +17,7 @@ import ( func TestMarshalTimeCorrupt(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeTime, "") + tType := gocql.NewNativeType(4, gocql.TypeTime) type testSuite struct { name string diff --git a/tests/serialization/marshal_15_time_test.go b/tests/serialization/marshal_15_time_test.go index 79c233035..42892c4c3 100644 --- a/tests/serialization/marshal_15_time_test.go +++ b/tests/serialization/marshal_15_time_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalsTime(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeTime, "") + tType := gocql.NewNativeType(4, gocql.TypeTime) type testSuite struct { name string diff --git a/tests/serialization/marshal_16_timestamp_corrupt_test.go b/tests/serialization/marshal_16_timestamp_corrupt_test.go index 0b5458bcc..2f6dda115 100644 --- a/tests/serialization/marshal_16_timestamp_corrupt_test.go +++ b/tests/serialization/marshal_16_timestamp_corrupt_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalTimestampCorrupt(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeTimestamp, "") + tType := gocql.NewNativeType(4, gocql.TypeTimestamp) type testSuite struct { name string diff --git a/tests/serialization/marshal_16_timestamp_test.go b/tests/serialization/marshal_16_timestamp_test.go index 804862cce..4fdde0a13 100644 --- a/tests/serialization/marshal_16_timestamp_test.go +++ b/tests/serialization/marshal_16_timestamp_test.go @@ -17,7 +17,7 @@ import ( func TestMarshalsTimestamp(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeTimestamp, "") + tType := gocql.NewNativeType(4, gocql.TypeTimestamp) type testSuite struct { name string diff --git a/tests/serialization/marshal_17_date_corrupt_test.go b/tests/serialization/marshal_17_date_corrupt_test.go index 614d67814..248000b77 100644 --- a/tests/serialization/marshal_17_date_corrupt_test.go +++ b/tests/serialization/marshal_17_date_corrupt_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalDateCorrupt(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeDate, "") + tType := gocql.NewNativeType(4, gocql.TypeDate) type testSuite struct { name string diff --git a/tests/serialization/marshal_17_date_test.go b/tests/serialization/marshal_17_date_test.go index cb5b800a8..ec6500822 100644 --- a/tests/serialization/marshal_17_date_test.go +++ b/tests/serialization/marshal_17_date_test.go @@ -17,7 +17,7 @@ import ( func TestMarshalsDate(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeDate, "") + tType := gocql.NewNativeType(4, gocql.TypeDate) type testSuite struct { name string diff --git a/tests/serialization/marshal_18_duration_corrupt_test.go b/tests/serialization/marshal_18_duration_corrupt_test.go index 5c0bd3b1f..8ddc28363 100644 --- a/tests/serialization/marshal_18_duration_corrupt_test.go +++ b/tests/serialization/marshal_18_duration_corrupt_test.go @@ -15,7 +15,7 @@ import ( func TestMarshalDurationCorrupt(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeDuration, "") + tType := gocql.NewNativeType(4, gocql.TypeDuration) marshal := func(i interface{}) ([]byte, error) { return gocql.Marshal(tType, i) } unmarshal := func(bytes []byte, i interface{}) error { diff --git a/tests/serialization/marshal_18_duration_test.go b/tests/serialization/marshal_18_duration_test.go index e3a29df6f..b7ccd9840 100644 --- a/tests/serialization/marshal_18_duration_test.go +++ b/tests/serialization/marshal_18_duration_test.go @@ -14,7 +14,7 @@ import ( ) func TestMarshalsDuration(t *testing.T) { - tType := gocql.NewNativeType(4, gocql.TypeDuration, "") + tType := gocql.NewNativeType(4, gocql.TypeDuration) const nanoDay = 24 * 60 * 60 * 1000 * 1000 * 1000 diff --git a/tests/serialization/marshal_19_list_set_v3_corrupt_test.go b/tests/serialization/marshal_19_list_set_v3_corrupt_test.go index 9d1c12bf8..db4cd1bbf 100644 --- a/tests/serialization/marshal_19_list_set_v3_corrupt_test.go +++ b/tests/serialization/marshal_19_list_set_v3_corrupt_test.go @@ -15,10 +15,10 @@ import ( func TestMarshalSetListV3Corrupt(t *testing.T) { t.Parallel() - elem := gocql.NewNativeType(3, gocql.TypeSmallInt, "") + elem := gocql.NewNativeType(3, gocql.TypeSmallInt) tTypes := []gocql.TypeInfo{ - gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList, ""), nil, elem), - gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet, ""), nil, elem), + gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList), nil, elem), + gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet), nil, elem), } // unmarshal data than bigger the normal data, does not return error. diff --git a/tests/serialization/marshal_19_list_set_v3_test.go b/tests/serialization/marshal_19_list_set_v3_test.go index 4ef7b4f60..d4d724d66 100644 --- a/tests/serialization/marshal_19_list_set_v3_test.go +++ b/tests/serialization/marshal_19_list_set_v3_test.go @@ -14,11 +14,11 @@ import ( func TestMarshalSetListV3(t *testing.T) { t.Parallel() - elem := gocql.NewNativeType(3, gocql.TypeSmallInt, "") + elem := gocql.NewNativeType(3, gocql.TypeSmallInt) tTypes := []gocql.TypeInfo{ - gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList, ""), nil, elem), - gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet, ""), nil, elem), + gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeList), nil, elem), + gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeSet), nil, elem), } // unmarshal `zero` data return an error diff --git a/tests/serialization/marshal_1_boolean_corrupt_test.go b/tests/serialization/marshal_1_boolean_corrupt_test.go index 8d9472cfc..57650cdf3 100644 --- a/tests/serialization/marshal_1_boolean_corrupt_test.go +++ b/tests/serialization/marshal_1_boolean_corrupt_test.go @@ -16,7 +16,7 @@ import ( func TestMarshalBooleanCorrupt(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeBoolean, "") + tType := gocql.NewNativeType(4, gocql.TypeBoolean) type testSuite struct { name string diff --git a/tests/serialization/marshal_1_boolean_test.go b/tests/serialization/marshal_1_boolean_test.go index 3fa3c83cb..9a54453fc 100644 --- a/tests/serialization/marshal_1_boolean_test.go +++ b/tests/serialization/marshal_1_boolean_test.go @@ -15,7 +15,7 @@ import ( func TestMarshalBoolean(t *testing.T) { t.Parallel() - tType := gocql.NewNativeType(4, gocql.TypeBoolean, "") + tType := gocql.NewNativeType(4, gocql.TypeBoolean) type testSuite struct { name string diff --git a/tests/serialization/marshal_20_map_v3_corrupt_test.go b/tests/serialization/marshal_20_map_v3_corrupt_test.go index cefc76daf..66d224da4 100644 --- a/tests/serialization/marshal_20_map_v3_corrupt_test.go +++ b/tests/serialization/marshal_20_map_v3_corrupt_test.go @@ -15,8 +15,8 @@ import ( func TestMarshalMapV3Corrupt(t *testing.T) { t.Parallel() - elem := gocql.NewNativeType(3, gocql.TypeSmallInt, "") - tType := gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap, ""), elem, elem) + elem := gocql.NewNativeType(3, gocql.TypeSmallInt) + tType := gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap), elem, elem) //unmarshal data than bigger the normal data, does not return error. brokenBigData := serialization.GetTypes(mod.Values{ diff --git a/tests/serialization/marshal_20_map_v3_test.go b/tests/serialization/marshal_20_map_v3_test.go index b726a1305..e1b1f2efa 100644 --- a/tests/serialization/marshal_20_map_v3_test.go +++ b/tests/serialization/marshal_20_map_v3_test.go @@ -14,8 +14,8 @@ import ( func TestMarshalMapV3(t *testing.T) { t.Parallel() - elem := gocql.NewNativeType(3, gocql.TypeSmallInt, "") - tType := gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap, ""), elem, elem) + elem := gocql.NewNativeType(3, gocql.TypeSmallInt) + tType := gocql.NewCollectionType(gocql.NewNativeType(3, gocql.TypeMap), elem, elem) refInt16 := func(v int16) *int16 { return &v } refModInt16 := func(v mod.Int16) *mod.Int16 { return &v } diff --git a/tests/serialization/marshal_2_tinyint_corrupt_test.go b/tests/serialization/marshal_2_tinyint_corrupt_test.go index fb8271294..452fcb2f3 100644 --- a/tests/serialization/marshal_2_tinyint_corrupt_test.go +++ b/tests/serialization/marshal_2_tinyint_corrupt_test.go @@ -22,7 +22,7 @@ func TestMarshalTinyintCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeTinyInt, "") + tType := gocql.NewNativeType(4, gocql.TypeTinyInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_2_tinyint_test.go b/tests/serialization/marshal_2_tinyint_test.go index 22e13e761..e5eee8ef2 100644 --- a/tests/serialization/marshal_2_tinyint_test.go +++ b/tests/serialization/marshal_2_tinyint_test.go @@ -22,7 +22,7 @@ func TestMarshalTinyint(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeTinyInt, "") + tType := gocql.NewNativeType(4, gocql.TypeTinyInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_3_smallint_corrupt_test.go b/tests/serialization/marshal_3_smallint_corrupt_test.go index 00fba89c3..b804f1935 100644 --- a/tests/serialization/marshal_3_smallint_corrupt_test.go +++ b/tests/serialization/marshal_3_smallint_corrupt_test.go @@ -22,7 +22,7 @@ func TestMarshalSmallintCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeSmallInt, "") + tType := gocql.NewNativeType(4, gocql.TypeSmallInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_3_smallint_test.go b/tests/serialization/marshal_3_smallint_test.go index a24e2896c..208a03cc8 100644 --- a/tests/serialization/marshal_3_smallint_test.go +++ b/tests/serialization/marshal_3_smallint_test.go @@ -22,7 +22,7 @@ func TestMarshalSmallint(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeSmallInt, "") + tType := gocql.NewNativeType(4, gocql.TypeSmallInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_4_int_corrupt_test.go b/tests/serialization/marshal_4_int_corrupt_test.go index 3b177cca9..474e6a62c 100644 --- a/tests/serialization/marshal_4_int_corrupt_test.go +++ b/tests/serialization/marshal_4_int_corrupt_test.go @@ -22,7 +22,7 @@ func TestMarshalIntCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeInt, "") + tType := gocql.NewNativeType(4, gocql.TypeInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_4_int_test.go b/tests/serialization/marshal_4_int_test.go index e860752ca..f0a8b8190 100644 --- a/tests/serialization/marshal_4_int_test.go +++ b/tests/serialization/marshal_4_int_test.go @@ -22,7 +22,7 @@ func TestMarshalInt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeInt, "") + tType := gocql.NewNativeType(4, gocql.TypeInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_5_bigint_corrupt_test.go b/tests/serialization/marshal_5_bigint_corrupt_test.go index 3b0c77587..05ee2199a 100644 --- a/tests/serialization/marshal_5_bigint_corrupt_test.go +++ b/tests/serialization/marshal_5_bigint_corrupt_test.go @@ -22,7 +22,7 @@ func TestMarshalBigIntCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeBigInt, "") + tType := gocql.NewNativeType(4, gocql.TypeBigInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_5_bigint_test.go b/tests/serialization/marshal_5_bigint_test.go index a864630a0..6e46f2728 100644 --- a/tests/serialization/marshal_5_bigint_test.go +++ b/tests/serialization/marshal_5_bigint_test.go @@ -21,7 +21,7 @@ func TestMarshalBigInt(t *testing.T) { marshal func(interface{}) ([]byte, error) unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeBigInt, "") + tType := gocql.NewNativeType(4, gocql.TypeBigInt) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_6_counter_corrupt_test.go b/tests/serialization/marshal_6_counter_corrupt_test.go index b82ca891b..4b014e60a 100644 --- a/tests/serialization/marshal_6_counter_corrupt_test.go +++ b/tests/serialization/marshal_6_counter_corrupt_test.go @@ -22,7 +22,7 @@ func TestMarshalCounterCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeCounter, "") + tType := gocql.NewNativeType(4, gocql.TypeCounter) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_6_counter_test.go b/tests/serialization/marshal_6_counter_test.go index ee9c7d498..00ee84877 100644 --- a/tests/serialization/marshal_6_counter_test.go +++ b/tests/serialization/marshal_6_counter_test.go @@ -21,7 +21,7 @@ func TestMarshalCounter(t *testing.T) { marshal func(interface{}) ([]byte, error) unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeCounter, "") + tType := gocql.NewNativeType(4, gocql.TypeCounter) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_7_varint_corrupt_test.go b/tests/serialization/marshal_7_varint_corrupt_test.go index 6e86010af..3f709df47 100644 --- a/tests/serialization/marshal_7_varint_corrupt_test.go +++ b/tests/serialization/marshal_7_varint_corrupt_test.go @@ -19,7 +19,7 @@ func TestMarshalVarIntCorrupt(t *testing.T) { name string } - tType := gocql.NewNativeType(4, gocql.TypeVarint, "") + tType := gocql.NewNativeType(4, gocql.TypeVarint) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_7_varint_test.go b/tests/serialization/marshal_7_varint_test.go index 460a8e85e..0f823df5f 100644 --- a/tests/serialization/marshal_7_varint_test.go +++ b/tests/serialization/marshal_7_varint_test.go @@ -22,7 +22,7 @@ func TestMarshalVarIntNew(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeVarint, "") + tType := gocql.NewNativeType(4, gocql.TypeVarint) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_8_float_corrupt_test.go b/tests/serialization/marshal_8_float_corrupt_test.go index 5a19bfe43..94c19318e 100644 --- a/tests/serialization/marshal_8_float_corrupt_test.go +++ b/tests/serialization/marshal_8_float_corrupt_test.go @@ -21,7 +21,7 @@ func TestMarshalFloatCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeFloat, "") + tType := gocql.NewNativeType(4, gocql.TypeFloat) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_8_float_test.go b/tests/serialization/marshal_8_float_test.go index 3560ac1e0..2deb4d59e 100644 --- a/tests/serialization/marshal_8_float_test.go +++ b/tests/serialization/marshal_8_float_test.go @@ -22,7 +22,7 @@ func TestMarshalFloat(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeFloat, "") + tType := gocql.NewNativeType(4, gocql.TypeFloat) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_9_double_corrupt_test.go b/tests/serialization/marshal_9_double_corrupt_test.go index 35599ebc9..ea40b15eb 100644 --- a/tests/serialization/marshal_9_double_corrupt_test.go +++ b/tests/serialization/marshal_9_double_corrupt_test.go @@ -21,7 +21,7 @@ func TestMarshalDoubleCorrupt(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeDouble, "") + tType := gocql.NewNativeType(4, gocql.TypeDouble) testSuites := [2]testSuite{ { diff --git a/tests/serialization/marshal_9_duble_test.go b/tests/serialization/marshal_9_duble_test.go index 94f4f6fde..336c9c429 100644 --- a/tests/serialization/marshal_9_duble_test.go +++ b/tests/serialization/marshal_9_duble_test.go @@ -22,7 +22,7 @@ func TestMarshalDouble(t *testing.T) { unmarshal func(bytes []byte, i interface{}) error } - tType := gocql.NewNativeType(4, gocql.TypeDouble, "") + tType := gocql.NewNativeType(4, gocql.TypeDouble) testSuites := [2]testSuite{ { diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 000000000..b4214ce0e --- /dev/null +++ b/vector_test.go @@ -0,0 +1,426 @@ +//go:build integration +// +build integration + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "fmt" + "net" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" + + "github.com/gocql/gocql/internal/tests" +) + +type person struct { + FirstName string `cql:"first_name"` + LastName string `cql:"last_name"` + Age int `cql:"age"` +} + +func (p person) String() string { + return fmt.Sprintf("Person{firstName: %s, lastName: %s, Age: %d}", p.FirstName, p.LastName, p.Age) +} + +func TestVector_Marshaler(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + insertFixVec := []float32{8, 2.5, -5.0} + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + tests.AssertDeepEqual(t, "fixed size element vector", insertFixVec, selectFixVec) + + longText := tests.RandomText(500) + insertVarVec := []string{"apache", "cassandra", longText, "gocql"} + err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + tests.AssertDeepEqual(t, "variable size element vector", insertVarVec, selectVarVec) +} + +func TestVector_Types(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + timestamp1, _ := time.Parse("2006-01-02", "2000-01-01") + timestamp2, _ := time.Parse("2006-01-02 15:04:05", "2024-01-01 10:31:45") + timestamp3, _ := time.Parse("2006-01-02 15:04:05.000", "2024-05-01 10:31:45.987") + + date1, _ := time.Parse("2006-01-02", "2000-01-01") + date2, _ := time.Parse("2006-01-02", "2022-03-14") + date3, _ := time.Parse("2006-01-02", "2024-12-31") + + time1, _ := time.Parse("15:04:05", "01:00:00") + time2, _ := time.Parse("15:04:05", "15:23:59") + time3, _ := time.Parse("15:04:05.000", "10:31:45.987") + + duration1 := Duration{0, 1, 1920000000000} + duration2 := Duration{1, 1, 1920000000000} + duration3 := Duration{31, 0, 60000000000} + + // map1 := make(map[string]int) + // map1["a"] = 1 + // map1["b"] = 2 + // map1["c"] = 3 + // map2 := make(map[string]int) + // map2["abc"] = 123 + // map3 := make(map[string]int) + + testCases := []struct { + name string + cqlType string + value interface{} + comparator func(interface{}, interface{}) + }{ + {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}}, + {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}}, + {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}}, + {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}}, + {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + { + name: "inet", + cqlType: TypeInet.String(), + value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, + comparator: func(e interface{}, a interface{}) { + expected := e.([]net.IP) + actual := a.([]net.IP) + tests.AssertEqual(t, "vector size", len(expected), len(actual)) + for i, _ := range expected { + tests.AssertTrue(t, "vector", expected[i].Equal(actual[i])) + } + }, + }, + {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: TypeTimestamp.String(), value: []time.Time{time1, time2, time3}}, + {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, + {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, + {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, + // {name: "vector_vector_set_float", cqlType: "vector, 5>", value: [][][]float32{ + // {{1, 2}, {2, -1}, {3}, {0}, {-1.3}}, + // {{2, 3}, {2, -1}, {3}, {0}, {-1.3}}, + // {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}}, + // }}, // disable until INSERTing Vector is fixed on scylladb side + {name: "vector_tuple_text_int_float", cqlType: "tuple", value: [][]interface{}{{"a", 1, float32(0.5)}, {"b", 2, float32(-1.2)}, {"c", 3, float32(0)}}}, + {name: "vector_tuple_text_list_text", cqlType: "tuple>", value: [][]interface{}{{"a", []string{"b", "c"}}, {"d", []string{"e", "f", "g"}}, {"h", []string{"i"}}}}, + // {name: "vector_set_text", cqlType: "set", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, // disable until INSERTing Vector is fixed on scylladb side + {name: "vector_list_int", cqlType: "list", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, + // {name: "vector_map_text_int", cqlType: "map", value: []map[string]int{map1, map2, map3}}, // disable until INSERTing Vector is fixed on scylladb side + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + tableName := fmt.Sprintf("vector_%s", test.name) + err := createTable(session, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS gocql_test.%s(id int primary key, vec vector<%s, 3>);`, tableName, test.cqlType)) + if err != nil { + t.Fatal(err) + } + + err = session.Query(fmt.Sprintf("INSERT INTO %s(id, vec) VALUES(?, ?)", tableName), 1, test.value).Exec() + if err != nil { + t.Fatal(err) + } + + v := reflect.New(reflect.TypeOf(test.value)) + err = session.Query(fmt.Sprintf("SELECT vec FROM %s WHERE id = ?", tableName), 1).Scan(v.Interface()) + if err != nil { + t.Fatal(err) + } + if test.comparator != nil { + test.comparator(test.value, v.Elem().Interface()) + } else { + tests.AssertDeepEqual(t, "vector", test.value, v.Elem().Interface()) + } + }) + } +} + +func TestVector_MarshalerUDT(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + err := createTable(session, `CREATE TYPE gocql_test.person( + first_name text, + last_name text, + age int);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_relatives( + id int, + couple vector, + primary key(id) + );`) + if err != nil { + t.Fatal(err) + } + + p1 := person{"Johny", "Bravo", 25} + p2 := person{"Capitan", "Planet", 5} + insVec := []person{p1, p2} + + err = session.Query("INSERT INTO vector_relatives(id, couple) VALUES(?, ?)", 1, insVec).Exec() + if err != nil { + t.Fatal(err) + } + + var selVec []person + + err = session.Query("SELECT couple FROM vector_relatives WHERE id = ?", 1).Scan(&selVec) + if err != nil { + t.Fatal(err) + } + + tests.AssertDeepEqual(t, "udt", &insVec, &selVec) +} + +func TestVector_Empty(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + tests.AssertTrue(t, "fixed size element vector is empty", selectFixVec == nil) + + err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + tests.AssertTrue(t, "variable size element vector is empty", selectVarVec == nil) +} + +func TestVector_MissingDimension(t *testing.T) { + session := createSession(t) + defer session.Close() + + if *flagDistribution == "cassandra" && flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + if *flagDistribution == "scylla" && flagCassVersion.Before(2025, 3, 0) { + t.Skip("Vector types have been introduced in ScyllaDB 2025.3") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 2") + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 4") +} + +func TestVector_SubTypeParsing(t *testing.T) { + testCases := []struct { + name string + custom string + expected TypeInfo + }{ + {name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}}, + // disable until INSERTing Vector is fixed on scylladb side + // {name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}}, + { + name: "udt", + custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)", + expected: UDTTypeInfo{ + NativeType: NativeType{typ: TypeUDT}, + KeySpace: "gocql_test", + Name: "person", + Elements: []UDTField{ + UDTField{Name: "first_name", Type: NativeType{typ: TypeVarchar}}, + UDTField{Name: "last_name", Type: NativeType{typ: TypeVarchar}}, + UDTField{Name: "age", Type: NativeType{typ: TypeInt}}, + }, + }, + }, + { + name: "tuple", + custom: "org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)", + expected: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeVarchar}, + NativeType{typ: TypeInt}, + NativeType{typ: TypeVarchar}, + }, + }, + }, + { + name: "vector_vector_inet", + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", + expected: VectorType{ + NativeType: NativeType{typ: TypeCustom, custom: "org.apache.cassandra.db.marshal.VectorType"}, + SubType: VectorType{ + NativeType: NativeType{typ: TypeCustom, custom: "org.apache.cassandra.db.marshal.VectorType"}, + SubType: NativeType{typ: TypeInet}, + Dimensions: 2, + }, + Dimensions: 3, + }, + }, + { + name: "map_int_vector_text", + custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))", + expected: CollectionType{ + NativeType: NativeType{typ: TypeMap}, + Key: NativeType{typ: TypeInt}, + Elem: VectorType{ + NativeType: NativeType{typ: TypeCustom, custom: "org.apache.cassandra.db.marshal.VectorType"}, + SubType: NativeType{typ: TypeVarchar}, + Dimensions: 10, + }, + }, + }, + // disable until INSERTing Vector is fixed on scylladb side + // { + // name: "set_map_vector_text_text", + // custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 10),org.apache.cassandra.db.marshal.UTF8Type))", + // expected: CollectionType{ + // NativeType: NativeType{typ: TypeSet}, + // Key: nil, + // Elem: CollectionType{ + // NativeType{typ: TypeMap}, + // VectorType{ + // NativeType: NativeType{typ: TypeCustom, custom: "org.apache.cassandra.db.marshal.VectorType"}, + // SubType: NativeType{typ: TypeInt}, + // Dimensions: 10, + // }, + // NativeType{typ: TypeVarchar}, + // }, + // }, + // }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + f := newFramer(nil, 0) + f.writeShort(0) + f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) + parsedType := f.readTypeInfo() + require.IsType(t, parsedType, VectorType{}) + vectorType := parsedType.(VectorType) + tests.AssertEqual(t, "dimensions", 2, vectorType.Dimensions) + tests.AssertDeepEqual(t, "vector", test.expected, vectorType.SubType) + }) + } +}