Skip to content

Commit 57c99b9

Browse files
Apply review comments
1 parent 9205eff commit 57c99b9

File tree

5 files changed

+95
-70
lines changed

5 files changed

+95
-70
lines changed

frame.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -936,19 +936,19 @@ func (f *framer) readTypeInfo() TypeInfo {
936936
// b) using parseType()
937937
// I think we could agree to use getTypeInfo() when parsing binary type definition
938938
// and parseType() would be responsible for parsing "custom" string definition.
939-
//spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
940-
//spec = spec[1 : len(spec)-1] // remove parenthesis
941-
//idx := strings.LastIndex(spec, ",")
942-
//typeStr := spec[:idx]
943-
//dimStr := spec[idx+1:]
944-
//subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{})
945-
//dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
946-
result := parseType(simple.custom, simple.proto, nopLogger{})
947-
dim, _ := strconv.Atoi(result.types[1].Custom())
939+
spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
940+
spec = spec[1 : len(spec)-1] // remove parenthesis
941+
idx := strings.LastIndex(spec, ",")
942+
typeStr := spec[:idx]
943+
dimStr := spec[idx+1:]
944+
subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{})
945+
dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
946+
//result := parseType(simple.custom, simple.proto, nopLogger{})
947+
//dim, _ := strconv.Atoi(result.types[1].Custom())
948948
vector := VectorType{
949949
NativeType: simple,
950-
//SubType: subType,
951-
SubType: result.types[0],
950+
SubType: subType,
951+
//SubType: result.types[0],
952952
Dimensions: dim,
953953
}
954954
return vector

helpers.go

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -164,30 +164,30 @@ func getCassandraBaseType(name string) Type {
164164
}
165165
}
166166

167+
// Parses short CQL type representation to internal data structures.
168+
// Mapping of long Java-style type definition into short format is performed in
169+
// apacheToCassandraType function.
167170
func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo {
168171
if strings.HasPrefix(name, "frozen<") {
169172
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), protoVer, logger)
170173
} else if strings.HasPrefix(name, "set<") {
171174
return CollectionType{
172-
NativeType: NativeType{typ: TypeSet, proto: protoVer},
175+
NativeType: NewNativeType(protoVer, TypeSet),
173176
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), protoVer, logger),
174177
}
175178
} else if strings.HasPrefix(name, "list<") {
176179
return CollectionType{
177-
NativeType: NativeType{typ: TypeList, proto: protoVer},
180+
NativeType: NewNativeType(protoVer, TypeList),
178181
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), protoVer, logger),
179182
}
180183
} else if strings.HasPrefix(name, "map<") {
181184
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
182185
if len(names) != 2 {
183186
logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
184-
return NativeType{
185-
proto: protoVer,
186-
typ: TypeCustom,
187-
}
187+
return NewNativeType(protoVer, TypeCustom)
188188
}
189189
return CollectionType{
190-
NativeType: NativeType{typ: TypeMap, proto: protoVer},
190+
NativeType: NewNativeType(protoVer, TypeMap),
191191
Key: getCassandraType(names[0], protoVer, logger),
192192
Elem: getCassandraType(names[1], protoVer, logger),
193193
}
@@ -200,11 +200,29 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo {
200200
}
201201

202202
return TupleTypeInfo{
203-
NativeType: NativeType{typ: TypeTuple, proto: protoVer},
203+
NativeType: NewNativeType(protoVer, TypeTuple),
204204
Elems: types,
205205
}
206-
} else if strings.HasPrefix(name, "udt<") {
207-
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "udt<"))
206+
} else if strings.HasPrefix(name, "vector<") {
207+
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<"))
208+
subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger)
209+
dim, _ := strconv.Atoi(strings.TrimSpace(names[1]))
210+
211+
return VectorType{
212+
NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE),
213+
SubType: subType,
214+
Dimensions: dim,
215+
}
216+
} else if strings.Index(name, "<") == -1 {
217+
// basic type
218+
return NativeType{
219+
proto: protoVer,
220+
typ: getCassandraBaseType(name),
221+
}
222+
} else {
223+
// udt
224+
idx := strings.Index(name, "<")
225+
names := splitCompositeTypes(name[idx+1 : len(name)-1])
208226
fields := make([]UDTField, len(names)-2)
209227

210228
for i := 2; i < len(names); i++ {
@@ -218,30 +236,11 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo {
218236

219237
udtName, _ := hex.DecodeString(names[1])
220238
return UDTTypeInfo{
221-
NativeType: NativeType{typ: TypeUDT, proto: protoVer},
239+
NativeType: NewNativeType(protoVer, TypeUDT),
222240
KeySpace: names[0],
223241
Name: string(udtName),
224242
Elements: fields,
225243
}
226-
} else if strings.HasPrefix(name, "vector<") {
227-
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<"))
228-
subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger)
229-
dim, _ := strconv.Atoi(strings.TrimSpace(names[1]))
230-
231-
return VectorType{
232-
NativeType: NativeType{
233-
proto: protoVer,
234-
typ: TypeCustom,
235-
custom: VECTOR_TYPE,
236-
},
237-
SubType: subType,
238-
Dimensions: dim,
239-
}
240-
} else {
241-
return NativeType{
242-
proto: protoVer,
243-
typ: getCassandraBaseType(name),
244-
}
245244
}
246245
}
247246

@@ -273,35 +272,34 @@ func splitCompositeTypes(name string) []string {
273272
return parts
274273
}
275274

275+
// Convert long Java style type definition into the short CQL type names.
276276
func apacheToCassandraType(t string) string {
277277
t = strings.Replace(t, "(", "<", -1)
278278
t = strings.Replace(t, ")", ">", -1)
279279
types := strings.FieldsFunc(t, func(r rune) bool {
280280
return r == '<' || r == '>' || r == ','
281281
})
282-
skip := 0
283-
for _, class := range types {
284-
class = strings.TrimSpace(class)
285-
if !isDigitsOnly(class) {
286-
// vector types include dimension (digits) as second type parameter
287-
// UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type
288-
if skip > 0 {
289-
skip -= 1
290-
continue
291-
}
292-
idx := strings.Index(class, ":")
293-
class = class[idx+1:]
294-
act := getApacheCassandraType(class)
295-
val := act.String()
296-
switch act {
297-
case TypeUDT:
298-
val = "udt"
299-
skip = 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type
300-
case TypeCustom:
282+
for i := 0; i < len(types); i++ {
283+
class := strings.TrimSpace(types[i])
284+
// UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type
285+
// Do not override hex encoded field names
286+
idx := strings.Index(class, ":")
287+
class = class[idx+1:]
288+
act := getApacheCassandraType(class)
289+
val := act.String()
290+
switch act {
291+
case TypeUDT:
292+
i += 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type
293+
case TypeCustom:
294+
if isDigitsOnly(class) {
295+
// vector types include dimension (digits) as second type parameter
296+
// getApacheCassandraType() returns "custom" by default, but we need to leave digits intact
297+
val = class
298+
} else {
301299
val = getApacheCassandraCustomSubType(class)
302300
}
303-
t = strings.Replace(t, class, val, -1)
304301
}
302+
t = strings.Replace(t, class, val, -1)
305303
}
306304
// This is done so it exactly matches what Cassandra returns
307305
return strings.Replace(t, ",", ", ", -1)
@@ -373,6 +371,8 @@ func getApacheCassandraType(class string) Type {
373371
}
374372
}
375373

374+
// Dedicated function parsing known special subtypes of CQL custom type.
375+
// Currently, only vectors are implemented as special custom subtype.
376376
func getApacheCassandraCustomSubType(class string) string {
377377
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
378378
case "VectorType":

marshal.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2649,7 +2649,11 @@ type NativeType struct {
26492649
custom string // only used for TypeCustom
26502650
}
26512651

2652-
func NewNativeType(proto byte, typ Type, custom string) NativeType {
2652+
func NewNativeType(proto byte, typ Type) NativeType {
2653+
return NativeType{proto, typ, ""}
2654+
}
2655+
2656+
func NewCustomType(proto byte, typ Type, custom string) NativeType {
26532657
return NativeType{proto, typ, custom}
26542658
}
26552659

metadata_test.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,17 @@ func TestTypeParser(t *testing.T) {
643643
assertTypeInfo{Type: TypeUDT, Custom: ""},
644644
)
645645

646+
// vector
647+
assertParseCompositeType(
648+
t,
649+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)",
650+
[]assertTypeInfo{
651+
{Type: TypeFloat},
652+
{Type: TypeCustom, Custom: "3"},
653+
},
654+
nil,
655+
)
656+
646657
// custom
647658
assertParseNonCompositeType(
648659
t,
@@ -702,7 +713,7 @@ func assertParseNonCompositeType(
702713
) {
703714

704715
log := &defaultLogger{}
705-
result := parseType(def, 4, log)
716+
result := parseType(def, protoVersion4, log)
706717
if len(result.reversed) != 1 {
707718
t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed))
708719
}
@@ -733,7 +744,7 @@ func assertParseCompositeType(
733744
) {
734745

735746
log := &defaultLogger{}
736-
result := parseType(def, 4, log)
747+
result := parseType(def, protoVersion4, log)
737748
if len(result.reversed) != len(typesExpected) {
738749
t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed))
739750
}
@@ -749,7 +760,7 @@ func assertParseCompositeType(
749760
if !result.isComposite {
750761
t.Errorf("%s: Expected composite", def)
751762
}
752-
if result.collections == nil {
763+
if result.collections == nil && collectionsExpected != nil {
753764
t.Errorf("%s: Expected non-nil collections: %v", def, result.collections)
754765
}
755766

vector_test.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ func TestVector_SubTypeParsing(t *testing.T) {
347347
name: "vector_vector_inet",
348348
custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)",
349349
expected: VectorType{
350-
NativeType{typ: TypeCustom},
350+
NativeType{typ: TypeCustom, custom: VECTOR_TYPE},
351351
VectorType{
352-
NativeType{typ: TypeCustom},
352+
NativeType{typ: TypeCustom, custom: VECTOR_TYPE},
353353
NativeType{typ: TypeInet},
354354
2,
355355
},
@@ -363,7 +363,7 @@ func TestVector_SubTypeParsing(t *testing.T) {
363363
NativeType{typ: TypeMap},
364364
NativeType{typ: TypeInt},
365365
VectorType{
366-
NativeType{typ: TypeCustom},
366+
NativeType{typ: TypeCustom, custom: VECTOR_TYPE},
367367
NativeType{typ: TypeVarchar},
368368
10,
369369
},
@@ -373,8 +373,18 @@ func TestVector_SubTypeParsing(t *testing.T) {
373373

374374
for _, test := range tests {
375375
t.Run(test.name, func(t *testing.T) {
376-
subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{})
377-
assertDeepEqual(t, "vector", test.expected, subType.types[0])
376+
f := newFramer(nil, 0)
377+
f.writeShort(0)
378+
f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom))
379+
parsedType := f.readTypeInfo()
380+
require.IsType(t, parsedType, VectorType{})
381+
382+
// test first parsing method
383+
vectorType := parsedType.(VectorType)
384+
assertEqual(t, "dimensions", 2, vectorType.Dimensions)
385+
assertDeepEqual(t, "vector", test.expected, vectorType.SubType)
386+
//subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{})
387+
//assertDeepEqual(t, "vector", test.expected, subType.types[0])
378388
})
379389
}
380390
}

0 commit comments

Comments
 (0)