Skip to content

Commit 907e69e

Browse files
Alternative to parse custom string
1 parent 4d2f080 commit 907e69e

File tree

5 files changed

+169
-30
lines changed

5 files changed

+169
-30
lines changed

frame.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -931,16 +931,24 @@ func (f *framer) readTypeInfo() TypeInfo {
931931
return collection
932932
case TypeCustom:
933933
if strings.HasPrefix(simple.custom, VECTOR_TYPE) {
934-
spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
935-
spec = spec[1 : len(spec)-1] // remove parenthesis
936-
idx := strings.LastIndex(spec, ",")
937-
typeStr := spec[:idx]
938-
dimStr := spec[idx+1:]
939-
subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{})
940-
dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
934+
// TODO(lantoniak): There are currently two ways of parsing types in the driver.
935+
// a) using getTypeInfo()
936+
// b) using parseType()
937+
// I think we could agree to use getTypeInfo() when parsing binary type definition
938+
// and parseType() would be responsible for parsing "custom" string definition.
939+
//spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
940+
//spec = spec[1 : len(spec)-1] // remove parenthesis
941+
//idx := strings.LastIndex(spec, ",")
942+
//typeStr := spec[:idx]
943+
//dimStr := spec[idx+1:]
944+
//subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{})
945+
//dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
946+
result := parseType(simple.custom, simple.proto, nopLogger{})
947+
dim, _ := strconv.Atoi(result.types[1].Custom())
941948
vector := VectorType{
942949
NativeType: simple,
943-
SubType: subType,
950+
//SubType: subType,
951+
SubType: result.types[0],
944952
Dimensions: dim,
945953
}
946954
return vector

marshal.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,7 +1746,7 @@ func marshalVector(info VectorType, value interface{}) ([]byte, error) {
17461746
if err != nil {
17471747
return nil, err
17481748
}
1749-
if isVectorVariableLengthType(info.SubType.Type()) {
1749+
if isVectorVariableLengthType(info.SubType) {
17501750
writeUnsignedVInt(buf, uint64(len(item)))
17511751
}
17521752
buf.Write(item)
@@ -1786,7 +1786,7 @@ func unmarshalVector(info VectorType, data []byte, value interface{}) error {
17861786
elemSize := len(data) / info.Dimensions
17871787
for i := 0; i < info.Dimensions; i++ {
17881788
offset := 0
1789-
if isVectorVariableLengthType(info.SubType.Type()) {
1789+
if isVectorVariableLengthType(info.SubType) {
17901790
m, p, err := readUnsignedVint(data, 0)
17911791
if err != nil {
17921792
return err
@@ -1815,8 +1815,8 @@ func unmarshalVector(info VectorType, data []byte, value interface{}) error {
18151815
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
18161816
}
18171817

1818-
func isVectorVariableLengthType(elemType Type) bool {
1819-
switch elemType {
1818+
func isVectorVariableLengthType(elemType TypeInfo) bool {
1819+
switch elemType.Type() {
18201820
case TypeVarchar, TypeAscii, TypeBlob, TypeText:
18211821
return true
18221822
case TypeCounter:
@@ -1829,6 +1829,13 @@ func isVectorVariableLengthType(elemType Type) bool {
18291829
return true
18301830
case TypeList, TypeSet, TypeMap, TypeUDT:
18311831
return true
1832+
case TypeCustom:
1833+
switch elemType.(type) {
1834+
case VectorType:
1835+
vecType := elemType.(VectorType)
1836+
return isVectorVariableLengthType(vecType.SubType)
1837+
}
1838+
return true
18321839
}
18331840
return false
18341841
}

metadata.go

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ func compileMetadata(
389389
col.Order = DESC
390390
}
391391
} else {
392-
validatorParsed := parseType(col.Validator, logger)
392+
validatorParsed := parseType(col.Validator, byte(protoVersion), logger)
393393
col.Type = validatorParsed.types[0]
394394
col.Order = ASC
395395
if validatorParsed.reversed[0] {
@@ -411,9 +411,9 @@ func compileMetadata(
411411
}
412412

413413
if protoVersion == protoVersion1 {
414-
compileV1Metadata(tables, logger)
414+
compileV1Metadata(tables, protoVersion, logger)
415415
} else {
416-
compileV2Metadata(tables, logger)
416+
compileV2Metadata(tables, protoVersion, logger)
417417
}
418418
}
419419

@@ -422,14 +422,14 @@ func compileMetadata(
422422
// column metadata as V2+ (because V1 doesn't support the "type" column in the
423423
// system.schema_columns table) so determining PartitionKey and ClusterColumns
424424
// is more complex.
425-
func compileV1Metadata(tables []TableMetadata, logger StdLogger) {
425+
func compileV1Metadata(tables []TableMetadata, protoVer int, logger StdLogger) {
426426
for i := range tables {
427427
table := &tables[i]
428428

429429
// decode the key validator
430-
keyValidatorParsed := parseType(table.KeyValidator, logger)
430+
keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger)
431431
// decode the comparator
432-
comparatorParsed := parseType(table.Comparator, logger)
432+
comparatorParsed := parseType(table.Comparator, byte(protoVer), logger)
433433

434434
// the partition key length is the same as the number of types in the
435435
// key validator
@@ -515,7 +515,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) {
515515
alias = table.ValueAlias
516516
}
517517
// decode the default validator
518-
defaultValidatorParsed := parseType(table.DefaultValidator, logger)
518+
defaultValidatorParsed := parseType(table.DefaultValidator, byte(protoVer), logger)
519519
column := &ColumnMetadata{
520520
Keyspace: table.Keyspace,
521521
Table: table.Name,
@@ -529,15 +529,15 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) {
529529
}
530530

531531
// The simpler compile case for V2+ protocol
532-
func compileV2Metadata(tables []TableMetadata, logger StdLogger) {
532+
func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) {
533533
for i := range tables {
534534
table := &tables[i]
535535

536536
clusteringColumnCount := componentColumnCountOfType(table.Columns, ColumnClusteringKey)
537537
table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount)
538538

539539
if table.KeyValidator != "" {
540-
keyValidatorParsed := parseType(table.KeyValidator, logger)
540+
keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger)
541541
table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types))
542542
} else { // Cassandra 3.x+
543543
partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey)
@@ -1186,6 +1186,7 @@ type typeParser struct {
11861186
input string
11871187
index int
11881188
logger StdLogger
1189+
proto byte
11891190
}
11901191

11911192
// the type definition parser result
@@ -1197,8 +1198,8 @@ type typeParserResult struct {
11971198
}
11981199

11991200
// Parse the type definition used for validator and comparator schema data
1200-
func parseType(def string, logger StdLogger) typeParserResult {
1201-
parser := &typeParser{input: def, logger: logger}
1201+
func parseType(def string, protoVer byte, logger StdLogger) typeParserResult {
1202+
parser := &typeParser{input: def, proto: protoVer, logger: logger}
12021203
return parser.parse()
12031204
}
12041205

@@ -1209,6 +1210,7 @@ const (
12091210
LIST_TYPE = "org.apache.cassandra.db.marshal.ListType"
12101211
SET_TYPE = "org.apache.cassandra.db.marshal.SetType"
12111212
MAP_TYPE = "org.apache.cassandra.db.marshal.MapType"
1213+
UDT_TYPE = "org.apache.cassandra.db.marshal.UserType"
12121214
VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType"
12131215
)
12141216

@@ -1218,6 +1220,7 @@ type typeParserClassNode struct {
12181220
params []typeParserParamNode
12191221
// this is the segment of the input string that defined this node
12201222
input string
1223+
proto byte
12211224
}
12221225

12231226
// represents a class parameter in the type def AST
@@ -1237,6 +1240,7 @@ func (t *typeParser) parse() typeParserResult {
12371240
NativeType{
12381241
typ: TypeCustom,
12391242
custom: t.input,
1243+
proto: t.proto,
12401244
},
12411245
},
12421246
reversed: []bool{false},
@@ -1292,6 +1296,26 @@ func (t *typeParser) parse() typeParserResult {
12921296
reversed: reversed,
12931297
collections: collections,
12941298
}
1299+
} else if strings.HasPrefix(ast.name, VECTOR_TYPE) {
1300+
count := len(ast.params)
1301+
1302+
types := make([]TypeInfo, count)
1303+
reversed := make([]bool, count)
1304+
1305+
for i, param := range ast.params[:count] {
1306+
class := param.class
1307+
reversed[i] = strings.HasPrefix(class.name, REVERSED_TYPE)
1308+
if reversed[i] {
1309+
class = class.params[0].class
1310+
}
1311+
types[i] = class.asTypeInfo()
1312+
}
1313+
1314+
return typeParserResult{
1315+
isComposite: true,
1316+
types: types,
1317+
reversed: reversed,
1318+
}
12951319
} else {
12961320
// not composite, so one type
12971321
class := *ast
@@ -1314,7 +1338,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
13141338
elem := class.params[0].class.asTypeInfo()
13151339
return CollectionType{
13161340
NativeType: NativeType{
1317-
typ: TypeList,
1341+
typ: TypeList,
1342+
proto: class.proto,
13181343
},
13191344
Elem: elem,
13201345
}
@@ -1323,7 +1348,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
13231348
elem := class.params[0].class.asTypeInfo()
13241349
return CollectionType{
13251350
NativeType: NativeType{
1326-
typ: TypeSet,
1351+
typ: TypeSet,
1352+
proto: class.proto,
13271353
},
13281354
Elem: elem,
13291355
}
@@ -1333,15 +1359,47 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
13331359
elem := class.params[1].class.asTypeInfo()
13341360
return CollectionType{
13351361
NativeType: NativeType{
1336-
typ: TypeMap,
1362+
typ: TypeMap,
1363+
proto: class.proto,
13371364
},
13381365
Key: key,
13391366
Elem: elem,
13401367
}
13411368
}
1369+
if strings.HasPrefix(class.name, UDT_TYPE) {
1370+
udtName, _ := hex.DecodeString(class.params[1].class.name)
1371+
fields := make([]UDTField, len(class.params)-2)
1372+
for i := 2; i < len(class.params); i++ {
1373+
fieldName, _ := hex.DecodeString(*class.params[i].name)
1374+
fields[i-2] = UDTField{
1375+
Name: string(fieldName),
1376+
Type: class.params[i].class.asTypeInfo(),
1377+
}
1378+
}
1379+
return UDTTypeInfo{
1380+
NativeType: NativeType{
1381+
typ: TypeUDT,
1382+
proto: class.proto,
1383+
},
1384+
KeySpace: class.params[0].class.name,
1385+
Name: string(udtName),
1386+
Elements: fields,
1387+
}
1388+
}
1389+
if strings.HasPrefix(class.name, VECTOR_TYPE) {
1390+
dim, _ := strconv.Atoi(class.params[1].class.name)
1391+
return VectorType{
1392+
NativeType: NativeType{
1393+
typ: TypeCustom,
1394+
proto: class.proto,
1395+
},
1396+
SubType: class.params[0].class.asTypeInfo(),
1397+
Dimensions: dim,
1398+
}
1399+
}
13421400

13431401
// must be a simple type or custom type
1344-
info := NativeType{typ: getApacheCassandraType(class.name)}
1402+
info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto}
13451403
if info.typ == TypeCustom {
13461404
// add the entire class definition
13471405
info.custom = class.input
@@ -1371,6 +1429,7 @@ func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) {
13711429
name: name,
13721430
params: params,
13731431
input: t.input[startIndex:endIndex],
1432+
proto: t.proto,
13741433
}
13751434
return node, true
13761435
}

metadata_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,12 +636,14 @@ func TestTypeParser(t *testing.T) {
636636
},
637637
)
638638

639-
// custom
639+
// udt
640640
assertParseNonCompositeType(
641641
t,
642642
"org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)",
643643
assertTypeInfo{Type: TypeUDT, Custom: ""},
644644
)
645+
646+
// custom
645647
assertParseNonCompositeType(
646648
t,
647649
"org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)",
@@ -700,7 +702,7 @@ func assertParseNonCompositeType(
700702
) {
701703

702704
log := &defaultLogger{}
703-
result := parseType(def, log)
705+
result := parseType(def, 4, log)
704706
if len(result.reversed) != 1 {
705707
t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed))
706708
}
@@ -731,7 +733,7 @@ func assertParseCompositeType(
731733
) {
732734

733735
log := &defaultLogger{}
734-
result := parseType(def, log)
736+
result := parseType(def, 4, log)
735737
if len(result.reversed) != len(typesExpected) {
736738
t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed))
737739
}

vector_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ func TestVector_Types(t *testing.T) {
165165
{name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}},
166166
// TODO(lantonia): Test vector of custom types
167167
{name: "vector_vector_float", cqlType: "vector<float, 5>", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}},
168+
{name: "vector_vector_set_float", cqlType: "vector<set<float>, 5>", value: [][][]float32{
169+
{{1, 2}, {2, -1}, {3}, {0}, {-1.3}},
170+
{{2, 3}, {2, -1}, {3}, {0}, {-1.3}},
171+
{{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}},
172+
}},
168173
{name: "vector_set_text", cqlType: "set<text>", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}},
169174
{name: "vector_list_int", cqlType: "list<int>", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}},
170175
{name: "vector_map_text_int", cqlType: "map<text, int>", value: []map[string]int{map1, map2, map3}},
@@ -301,3 +306,61 @@ func TestVector_MissingDimension(t *testing.T) {
301306
err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec()
302307
require.Error(t, err, "expected vector with 3 dimensions, received 4")
303308
}
309+
310+
func TestVector_SubTypeParsing(t *testing.T) {
311+
tests := []struct {
312+
name string
313+
custom string
314+
expected TypeInfo
315+
}{
316+
{name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}},
317+
{name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}},
318+
{
319+
name: "udt",
320+
custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)",
321+
expected: UDTTypeInfo{
322+
NativeType{typ: TypeUDT},
323+
"gocql_test",
324+
"person",
325+
[]UDTField{
326+
UDTField{"first_name", NativeType{typ: TypeVarchar}},
327+
UDTField{"last_name", NativeType{typ: TypeVarchar}},
328+
UDTField{"age", NativeType{typ: TypeInt}},
329+
},
330+
},
331+
},
332+
{
333+
name: "vector_vector_inet",
334+
custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)",
335+
expected: VectorType{
336+
NativeType{typ: TypeCustom},
337+
VectorType{
338+
NativeType{typ: TypeCustom},
339+
NativeType{typ: TypeInet},
340+
2,
341+
},
342+
3,
343+
},
344+
},
345+
{
346+
name: "map_int_vector_text",
347+
custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))",
348+
expected: CollectionType{
349+
NativeType{typ: TypeMap},
350+
NativeType{typ: TypeInt},
351+
VectorType{
352+
NativeType{typ: TypeCustom},
353+
NativeType{typ: TypeVarchar},
354+
10,
355+
},
356+
},
357+
},
358+
}
359+
360+
for _, test := range tests {
361+
t.Run(test.name, func(t *testing.T) {
362+
subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{})
363+
assertDeepEqual(t, "vector", test.expected, subType.types[0])
364+
})
365+
}
366+
}

0 commit comments

Comments
 (0)