Skip to content

Commit 61387c3

Browse files
committed
CASSGO-43: externally-defined type registration
The new RegisterType function can be used to register externally-defined types. You'll need to define your own marshalling and unmarshalling code as well as a TypeInfo implementation. The name and id MUST not collide with existing and future native CQL types. Additionally, a lot of the type handling was refactored to use the new format for native types. Performance should be slightly improved thanks to some simplification. Benchmarks are coming soon. Patch by James Hartig for CASSGO-43
1 parent 91cbf12 commit 61387c3

File tree

10 files changed

+885
-540
lines changed

10 files changed

+885
-540
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010

11+
- Externally-defined type registration (CASSGO-43)
12+
1113
### Changed
1214

1315
- Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19)

cassandra_test.go

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,18 +2415,19 @@ func TestAggregateMetadata(t *testing.T) {
24152415
t.Fatal("expected two aggregates")
24162416
}
24172417

2418+
protoVer := byte(session.cfg.ProtoVersion)
24182419
expectedAggregrate := AggregateMetadata{
24192420
Keyspace: "gocql_test",
24202421
Name: "average",
2421-
ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}},
2422+
ArgumentTypes: []TypeInfo{NativeType{proto: protoVer, typ: TypeInt}},
24222423
InitCond: "(0, 0)",
2423-
ReturnType: NativeType{typ: TypeDouble},
2424+
ReturnType: NativeType{proto: protoVer, typ: TypeDouble},
24242425
StateType: TupleTypeInfo{
2425-
NativeType: NativeType{typ: TypeTuple},
2426+
NativeType: NativeType{proto: protoVer, typ: TypeTuple},
24262427

24272428
Elems: []TypeInfo{
2428-
NativeType{typ: TypeInt},
2429-
NativeType{typ: TypeBigInt},
2429+
NativeType{proto: protoVer, typ: TypeInt},
2430+
NativeType{proto: protoVer, typ: TypeBigInt},
24302431
},
24312432
},
24322433
stateFunc: "avgstate",
@@ -2439,11 +2440,11 @@ func TestAggregateMetadata(t *testing.T) {
24392440
}
24402441

24412442
if !reflect.DeepEqual(aggregates[0], expectedAggregrate) {
2442-
t.Fatalf("aggregate 'average' is %+v, but expected %+v", aggregates[0], expectedAggregrate)
2443+
t.Fatalf("aggregate 'average' is %#v, but expected %#v", aggregates[0], expectedAggregrate)
24432444
}
24442445
expectedAggregrate.Name = "average2"
24452446
if !reflect.DeepEqual(aggregates[1], expectedAggregrate) {
2446-
t.Fatalf("aggregate 'average2' is %+v, but expected %+v", aggregates[1], expectedAggregrate)
2447+
t.Fatalf("aggregate 'average2' is %#v, but expected %#v", aggregates[1], expectedAggregrate)
24472448
}
24482449
}
24492450

@@ -2465,28 +2466,29 @@ func TestFunctionMetadata(t *testing.T) {
24652466
avgState := functions[1]
24662467
avgFinal := functions[0]
24672468

2469+
protoVer := byte(session.cfg.ProtoVersion)
24682470
avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;"
24692471
expectedAvgState := FunctionMetadata{
24702472
Keyspace: "gocql_test",
24712473
Name: "avgstate",
24722474
ArgumentTypes: []TypeInfo{
24732475
TupleTypeInfo{
2474-
NativeType: NativeType{typ: TypeTuple},
2476+
NativeType: NativeType{proto: protoVer, typ: TypeTuple},
24752477

24762478
Elems: []TypeInfo{
2477-
NativeType{typ: TypeInt},
2478-
NativeType{typ: TypeBigInt},
2479+
NativeType{proto: protoVer, typ: TypeInt},
2480+
NativeType{proto: protoVer, typ: TypeBigInt},
24792481
},
24802482
},
2481-
NativeType{typ: TypeInt},
2483+
NativeType{proto: protoVer, typ: TypeInt},
24822484
},
24832485
ArgumentNames: []string{"state", "val"},
24842486
ReturnType: TupleTypeInfo{
2485-
NativeType: NativeType{typ: TypeTuple},
2487+
NativeType: NativeType{proto: protoVer, typ: TypeTuple},
24862488

24872489
Elems: []TypeInfo{
2488-
NativeType{typ: TypeInt},
2489-
NativeType{typ: TypeBigInt},
2490+
NativeType{proto: protoVer, typ: TypeInt},
2491+
NativeType{proto: protoVer, typ: TypeBigInt},
24902492
},
24912493
},
24922494
CalledOnNullInput: true,
@@ -2503,22 +2505,22 @@ func TestFunctionMetadata(t *testing.T) {
25032505
Name: "avgfinal",
25042506
ArgumentTypes: []TypeInfo{
25052507
TupleTypeInfo{
2506-
NativeType: NativeType{typ: TypeTuple},
2508+
NativeType: NativeType{proto: protoVer, typ: TypeTuple},
25072509

25082510
Elems: []TypeInfo{
2509-
NativeType{typ: TypeInt},
2510-
NativeType{typ: TypeBigInt},
2511+
NativeType{proto: protoVer, typ: TypeInt},
2512+
NativeType{proto: protoVer, typ: TypeBigInt},
25112513
},
25122514
},
25132515
},
25142516
ArgumentNames: []string{"state"},
2515-
ReturnType: NativeType{typ: TypeDouble},
2517+
ReturnType: NativeType{proto: protoVer, typ: TypeDouble},
25162518
CalledOnNullInput: true,
25172519
Language: "java",
25182520
Body: finalStateBody,
25192521
}
25202522
if !reflect.DeepEqual(avgFinal, expectedAvgFinal) {
2521-
t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal)
2523+
t.Fatalf("function is %#v, but expected %#v", avgFinal, expectedAvgFinal)
25222524
}
25232525
}
25242526

@@ -2616,19 +2618,20 @@ func TestKeyspaceMetadata(t *testing.T) {
26162618
if flagCassVersion.Before(3, 0, 0) {
26172619
textType = TypeVarchar
26182620
}
2621+
protoVer := byte(session.cfg.ProtoVersion)
26192622
expectedType := UserTypeMetadata{
26202623
Keyspace: "gocql_test",
26212624
Name: "basicview",
26222625
FieldNames: []string{"birthday", "nationality", "weight", "height"},
26232626
FieldTypes: []TypeInfo{
2624-
NativeType{typ: TypeTimestamp},
2625-
NativeType{typ: textType},
2626-
NativeType{typ: textType},
2627-
NativeType{typ: textType},
2627+
NativeType{proto: protoVer, typ: TypeTimestamp},
2628+
NativeType{proto: protoVer, typ: textType},
2629+
NativeType{proto: protoVer, typ: textType},
2630+
NativeType{proto: protoVer, typ: textType},
26282631
},
26292632
}
26302633
if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) {
2631-
t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType)
2634+
t.Fatalf("type is %#v, but expected %#v", keyspaceMetadata.UserTypes["basicview"], expectedType)
26322635
}
26332636
if flagCassVersion.Major >= 3 {
26342637
materializedView, found := keyspaceMetadata.MaterializedViews["view_view"]

frame.go

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"io"
3232
"io/ioutil"
3333
"net"
34+
"reflect"
3435
"runtime"
3536
"strings"
3637
"time"
@@ -858,68 +859,98 @@ func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error {
858859
return f.finish()
859860
}
860861

861-
func (f *framer) readTypeInfo() TypeInfo {
862-
// TODO: factor this out so the same code paths can be used to parse custom
863-
// types and other types, as much of the logic will be duplicated.
864-
id := f.readShort()
865-
866-
simple := NativeType{
867-
proto: f.proto,
868-
typ: Type(id),
869-
}
870-
871-
if simple.typ == TypeCustom {
872-
simple.custom = f.readString()
873-
if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom {
874-
simple.typ = cassType
875-
}
876-
}
862+
var (
863+
typeInfoType = reflect.TypeOf((*TypeInfo)(nil)).Elem()
864+
typeType = reflect.TypeOf(Type(0))
865+
typeInfoListType = reflect.TypeOf([]TypeInfo(nil))
866+
stringListType = reflect.TypeOf([]string(nil))
867+
udtFieldListType = reflect.TypeOf([]UDTField(nil))
868+
stringType = reflect.TypeOf("")
869+
shortType = reflect.TypeOf(uint16(0))
870+
byteType = reflect.TypeOf(byte(0))
871+
intType = reflect.TypeOf(int(0))
872+
)
877873

878-
switch simple.typ {
879-
case TypeTuple:
874+
func (f *framer) readForType(typ reflect.Type) interface{} {
875+
// check simple equality first
876+
switch typ {
877+
case stringType:
878+
return f.readString()
879+
case shortType:
880+
return f.readShort()
881+
case byteType:
882+
return f.readByte()
883+
case intType:
884+
return f.readInt()
885+
case stringListType:
886+
return f.readStringList()
887+
case udtFieldListType:
880888
n := f.readShort()
881-
tuple := TupleTypeInfo{
882-
NativeType: simple,
883-
Elems: make([]TypeInfo, n),
884-
}
885-
889+
fields := make([]UDTField, n)
886890
for i := 0; i < int(n); i++ {
887-
tuple.Elems[i] = f.readTypeInfo()
888-
}
889-
890-
return tuple
891-
892-
case TypeUDT:
893-
udt := UDTTypeInfo{
894-
NativeType: simple,
891+
fields[i] = UDTField{
892+
Name: f.readString(),
893+
Type: f.readTypeInfo(),
894+
}
895895
}
896-
udt.KeySpace = f.readString()
897-
udt.Name = f.readString()
898-
896+
return fields
897+
case typeInfoType:
898+
return f.readTypeInfo()
899+
case typeType:
900+
return Type(f.readShort())
901+
case typeInfoListType:
899902
n := f.readShort()
900-
udt.Elements = make([]UDTField, n)
903+
types := make([]TypeInfo, n)
901904
for i := 0; i < int(n); i++ {
902-
field := &udt.Elements[i]
903-
field.Name = f.readString()
904-
field.Type = f.readTypeInfo()
905+
types[i] = f.readTypeInfo()
905906
}
907+
return types
908+
}
906909

907-
return udt
908-
case TypeMap, TypeList, TypeSet:
909-
collection := CollectionType{
910-
NativeType: simple,
910+
// then check the kind and try to convert
911+
switch typ.Kind() {
912+
case reflect.String:
913+
return reflect.ValueOf(f.readString()).Convert(typ).Interface()
914+
case reflect.Int:
915+
return reflect.ValueOf(f.readInt()).Convert(typ).Interface()
916+
case reflect.Slice:
917+
n := f.readShort()
918+
slice := reflect.MakeSlice(typ, int(n), int(n))
919+
for i := 0; i < int(n); i++ {
920+
slice.Index(i).Set(reflect.ValueOf(f.readForType(typ.Elem())))
911921
}
922+
return slice.Interface()
923+
}
924+
panic(fmt.Errorf("unsupported type for reading from frame: %s", typ.String()))
925+
}
912926

913-
if simple.typ == TypeMap {
914-
collection.Key = f.readTypeInfo()
927+
func (f *framer) readTypeInfo() TypeInfo {
928+
typ := Type(f.readShort())
929+
930+
cqlct, ok := registeredCompositeTypes[typ]
931+
if ok {
932+
paramsTypes := cqlct.Params(int(f.proto))
933+
var params []interface{}
934+
if len(paramsTypes) > 0 {
935+
params = make([]interface{}, len(paramsTypes))
936+
for i, paramType := range paramsTypes {
937+
params[i] = f.readForType(paramType)
938+
}
915939
}
940+
return cqlct.TypeInfoParams(int(f.proto), params)
941+
}
916942

917-
collection.Elem = f.readTypeInfo()
918-
919-
return collection
943+
// custom is a special case, we need to read the name then get the type info
944+
if typ == TypeCustom {
945+
name := f.readString()
946+
return (customCQLType{}).TypeInfoParams(int(f.proto), []interface{}{name})
920947
}
921948

922-
return simple
949+
cqlt, ok := registeredTypes[typ]
950+
if !ok {
951+
panic(fmt.Errorf("unknown type id: %d", typ))
952+
}
953+
return cqlt.TypeInfo(int(f.proto))
923954
}
924955

925956
type preparedMetadata struct {

0 commit comments

Comments
 (0)