diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c6c7cc29..b17c5f008 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support vector type (CASSGO-11) - Allow SERIAL and LOCAL_SERIAL on SELECT statements (CASSGO-26) - Support of sending queries to the specific node with Query.SetHostID() (CASSGO-4) -- Support for Native Protocol 5 (CASSGO-1) +- Support for Native Protocol 5. Following protocol changes exposed new API + Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1) +- Externally-defined type registration (CASSGO-43) ### Changed @@ -36,12 +38,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactor HostInfo creation and ConnectAddress() method (CASSGO-45) - gocql.Compressor interface changes to follow append-like design. Bumped Go version to 1.19 (CASSGO-1) - Refactoring hostpool package test and Expose HostInfo creation (CASSGO-59) - - Move "execute batch" methods to Batch type (CASSGO-57) - - Make `Session` immutable by removing setters and associated mutex (CASSGO-23) +- inet columns default to net.IP when using MapScan or SliceMap (CASSGO-43) +- NativeType removed (CASSGO-43) +- `New` and `NewWithError` removed and replaced with `Zero` (CASSGO-43) ### Fixed + - Cassandra version unmarshal fix (CASSGO-49) - Retry policy now takes into account query idempotency (CASSGO-27) - Don't return error to caller with RetryType Ignore (CASSGO-28) @@ -50,7 +54,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Skip metadata only if the prepared result includes metadata (CASSGO-40) - Don't panic in MapExecuteBatchCAS if no `[applied]` column is returned (CASSGO-42) - Fix deadlock in refresh debouncer stop (CASSGO-41) - - Endless query execution fix (CASSGO-50) ## [1.7.0] - 2024-09-23 diff --git a/cassandra_test.go b/cassandra_test.go index 1a9b69fef..9fa2a0d1a 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -645,7 +645,7 @@ func TestDurationType(t *testing.T) { defer session.Close() if session.cfg.ProtoVersion < 5 { - t.Skip("Duration type is not supported. Please use protocol version >= 4 and cassandra version >= 3.11") + t.Skip("Duration type is not supported. Please use protocol version > 4") } if err := createTable(session, `CREATE TABLE gocql_test.duration_table ( @@ -1068,7 +1068,7 @@ func TestMapScan(t *testing.T) { } assertEqual(t, "fullname", "Ada Lovelace", row["fullname"]) assertEqual(t, "age", 30, row["age"]) - assertEqual(t, "address", "10.0.0.2", row["address"]) + assertDeepEqual(t, "address", net.ParseIP("10.0.0.2").To4(), row["address"]) assertDeepEqual(t, "data", []byte(`{"foo": "bar"}`), row["data"]) // Second iteration using a new map @@ -1078,7 +1078,7 @@ func TestMapScan(t *testing.T) { } assertEqual(t, "fullname", "Grace Hopper", row["fullname"]) assertEqual(t, "age", 31, row["age"]) - assertEqual(t, "address", "10.0.0.1", row["address"]) + assertDeepEqual(t, "address", net.ParseIP("10.0.0.1").To4(), row["address"]) assertDeepEqual(t, "data", []byte(nil), row["data"]) } @@ -1125,7 +1125,7 @@ func TestSliceMap(t *testing.T) { m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9} m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"} m["testvarint"] = bigInt - m["testinet"] = "213.212.2.19" + m["testinet"] = net.ParseIP("213.212.2.19").To4() sliceMap := []map[string]interface{}{m} if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"], m["testinet"]).Exec(); err != nil { @@ -1157,51 +1157,105 @@ func TestSliceMap(t *testing.T) { } func matchSliceMap(t *testing.T, sliceMap []map[string]interface{}, testMap map[string]interface{}) { if sliceMap[0]["testuuid"] != testMap["testuuid"] { - t.Fatal("returned testuuid did not match") + t.Fatalf("returned testuuid %#v did not match %#v", sliceMap[0]["testuuid"], testMap["testuuid"]) } if sliceMap[0]["testtimestamp"] != testMap["testtimestamp"] { - t.Fatal("returned testtimestamp did not match") + t.Fatalf("returned testtimestamp %#v did not match %#v", sliceMap[0]["testtimestamp"], testMap["testtimestamp"]) } if sliceMap[0]["testvarchar"] != testMap["testvarchar"] { - t.Fatal("returned testvarchar did not match") + t.Fatalf("returned testvarchar %#v did not match %#v", sliceMap[0]["testvarchar"], testMap["testvarchar"]) } if sliceMap[0]["testbigint"] != testMap["testbigint"] { - t.Fatal("returned testbigint did not match") + t.Fatalf("returned testbigint %#v did not match %#v", sliceMap[0]["testbigint"], testMap["testbigint"]) } if !reflect.DeepEqual(sliceMap[0]["testblob"], testMap["testblob"]) { - t.Fatal("returned testblob did not match") + t.Fatalf("returned testblob %#v did not match %#v", sliceMap[0]["testblob"], testMap["testblob"]) } if sliceMap[0]["testbool"] != testMap["testbool"] { - t.Fatal("returned testbool did not match") + t.Fatalf("returned testbool %#v did not match %#v", sliceMap[0]["testbool"], testMap["testbool"]) } if sliceMap[0]["testfloat"] != testMap["testfloat"] { - t.Fatal("returned testfloat did not match") + t.Fatalf("returned testfloat %#v did not match %#v", sliceMap[0]["testfloat"], testMap["testfloat"]) } if sliceMap[0]["testdouble"] != testMap["testdouble"] { - t.Fatal("returned testdouble did not match") + t.Fatalf("returned testdouble %#v did not match %#v", sliceMap[0]["testdouble"], testMap["testdouble"]) } - if sliceMap[0]["testinet"] != testMap["testinet"] { - t.Fatal("returned testinet did not match") + if !reflect.DeepEqual(sliceMap[0]["testinet"], testMap["testinet"]) { + t.Fatalf("returned testinet %#v did not match %#v", sliceMap[0]["testinet"], testMap["testinet"]) } expectedDecimal := sliceMap[0]["testdecimal"].(*inf.Dec) returnedDecimal := testMap["testdecimal"].(*inf.Dec) if expectedDecimal.Cmp(returnedDecimal) != 0 { - t.Fatal("returned testdecimal did not match") + t.Fatalf("returned testdecimal %#v did not match %#v", sliceMap[0]["testdecimal"], testMap["testdecimal"]) } if !reflect.DeepEqual(sliceMap[0]["testlist"], testMap["testlist"]) { - t.Fatal("returned testlist did not match") + t.Fatalf("returned testlist %#v did not match %#v", sliceMap[0]["testlist"], testMap["testlist"]) } if !reflect.DeepEqual(sliceMap[0]["testset"], testMap["testset"]) { - t.Fatal("returned testset did not match") + t.Fatalf("returned testset %#v did not match %#v", sliceMap[0]["testset"], testMap["testset"]) } if !reflect.DeepEqual(sliceMap[0]["testmap"], testMap["testmap"]) { - t.Fatal("returned testmap did not match") + t.Fatalf("returned testmap %#v did not match %#v", sliceMap[0]["testmap"], testMap["testmap"]) } if sliceMap[0]["testint"] != testMap["testint"] { - t.Fatal("returned testint did not match") + t.Fatalf("returned testint %#v did not match %#v", sliceMap[0]["testint"], testMap["testint"]) + } +} + +func TestSliceMap_CopySlices(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocql_test.slice_map_copy_table ( + t text, + u timeuuid, + l list, + PRIMARY KEY (t, u) + )`); err != nil { + t.Fatal("create table:", err) + } + + err := session.Query( + `INSERT INTO slice_map_copy_table (t, u, l) VALUES ('test', ?, ?)`, + TimeUUID(), []string{"1", "2"}, + ).Exec() + if err != nil { + t.Fatal("insert:", err) + } + + err = session.Query( + `INSERT INTO slice_map_copy_table (t, u, l) VALUES ('test', ?, ?)`, + TimeUUID(), []string{"3", "4"}, + ).Exec() + if err != nil { + t.Fatal("insert:", err) + } + + err = session.Query( + `INSERT INTO slice_map_copy_table (t, u, l) VALUES ('test', ?, ?)`, + TimeUUID(), []string{"5", "6"}, + ).Exec() + if err != nil { + t.Fatal("insert:", err) + } + + if returned, retErr := session.Query(`SELECT * FROM slice_map_copy_table WHERE t = 'test'`).Iter().SliceMap(); retErr != nil { + t.Fatal("select:", retErr) + } else { + if len(returned) != 3 { + t.Fatal("expected 3 rows, got", len(returned)) + } + if !reflect.DeepEqual(returned[0]["l"], []string{"1", "2"}) { + t.Fatal("expected [1, 2], got", returned[0]["l"]) + } + if !reflect.DeepEqual(returned[1]["l"], []string{"3", "4"}) { + t.Fatal("expected [3, 4], got", returned[1]["l"]) + } + if !reflect.DeepEqual(returned[2]["l"], []string{"5", "6"}) { + t.Fatal("expected [5, 6], got", returned[2]["l"]) + } } } @@ -1278,7 +1332,7 @@ func TestSmallInt(t *testing.T) { t.Fatal("select:", retErr) } else { if sliceMap[0]["testsmallint"] != returned[0]["testsmallint"] { - t.Fatal("returned testsmallint did not match") + t.Fatalf("returned testsmallint %#v did not match %#v", returned[0]["testsmallint"], sliceMap[0]["testsmallint"]) } } } @@ -1598,7 +1652,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string Keyspace: "gocql_test", Table: table, Name: "foo", - TypeInfo: NativeType{ + TypeInfo: varcharLikeTypeInfo{ typ: TypeVarchar, }, }, @@ -2497,19 +2551,18 @@ func TestAggregateMetadata(t *testing.T) { t.Fatal("expected two aggregates") } - protoVer := byte(session.cfg.ProtoVersion) expectedAggregrate := AggregateMetadata{ Keyspace: "gocql_test", Name: "average", - ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt, proto: protoVer}}, + ArgumentTypes: []TypeInfo{intTypeInfo{}}, InitCond: "(0, 0)", - ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, + ReturnType: doubleTypeInfo{}, StateType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple, proto: protoVer}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt, proto: protoVer}, - NativeType{typ: TypeBigInt, proto: protoVer}, + intTypeInfo{}, + bigIntLikeTypeInfo{ + typ: TypeBigInt, + }, }, }, stateFunc: "avgstate", @@ -2522,11 +2575,11 @@ func TestAggregateMetadata(t *testing.T) { } if !reflect.DeepEqual(aggregates[0], expectedAggregrate) { - t.Fatalf("aggregate 'average' is %+v, but expected %+v", aggregates[0], expectedAggregrate) + t.Fatalf("aggregate 'average' is %#v, but expected %#v", aggregates[0], expectedAggregrate) } expectedAggregrate.Name = "average2" if !reflect.DeepEqual(aggregates[1], expectedAggregrate) { - t.Fatalf("aggregate 'average2' is %+v, but expected %+v", aggregates[1], expectedAggregrate) + t.Fatalf("aggregate 'average2' is %#v, but expected %#v", aggregates[1], expectedAggregrate) } } @@ -2548,29 +2601,28 @@ func TestFunctionMetadata(t *testing.T) { avgState := functions[1] avgFinal := functions[0] - protoVer := byte(session.cfg.ProtoVersion) avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" expectedAvgState := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgstate", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple, proto: protoVer}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt, proto: protoVer}, - NativeType{typ: TypeBigInt, proto: protoVer}, + intTypeInfo{}, + bigIntLikeTypeInfo{ + typ: TypeBigInt, + }, }, }, - NativeType{typ: TypeInt, proto: protoVer}, + intTypeInfo{}, }, ArgumentNames: []string{"state", "val"}, ReturnType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple, proto: protoVer}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt, proto: protoVer}, - NativeType{typ: TypeBigInt, proto: protoVer}, + intTypeInfo{}, + bigIntLikeTypeInfo{ + typ: TypeBigInt, + }, }, }, CalledOnNullInput: true, @@ -2587,22 +2639,22 @@ func TestFunctionMetadata(t *testing.T) { Name: "avgfinal", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple, proto: protoVer}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt, proto: protoVer}, - NativeType{typ: TypeBigInt, proto: protoVer}, + intTypeInfo{}, + bigIntLikeTypeInfo{ + typ: TypeBigInt, + }, }, }, }, ArgumentNames: []string{"state"}, - ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, + ReturnType: doubleTypeInfo{}, CalledOnNullInput: true, Language: "java", Body: finalStateBody, } if !reflect.DeepEqual(avgFinal, expectedAvgFinal) { - t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal) + t.Fatalf("function is %#v, but expected %#v", avgFinal, expectedAvgFinal) } } @@ -2700,20 +2752,25 @@ func TestKeyspaceMetadata(t *testing.T) { if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } - protoVer := byte(session.cfg.ProtoVersion) expectedType := UserTypeMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp, proto: protoVer}, - NativeType{typ: textType, proto: protoVer}, - NativeType{typ: textType, proto: protoVer}, - NativeType{typ: textType, proto: protoVer}, + timestampTypeInfo{}, + varcharLikeTypeInfo{ + typ: textType, + }, + varcharLikeTypeInfo{ + typ: textType, + }, + varcharLikeTypeInfo{ + typ: textType, + }, }, } if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { - t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType) + t.Fatalf("type is %#v, but expected %#v", keyspaceMetadata.UserTypes["basicview"], expectedType) } if flagCassVersion.Major >= 3 { materializedView, found := keyspaceMetadata.MaterializedViews["view_view"] @@ -3520,9 +3577,9 @@ func TestQuery_SetKeyspace(t *testing.T) { } const keyspaceStmt = ` - CREATE KEYSPACE IF NOT EXISTS gocql_query_keyspace_override_test + CREATE KEYSPACE IF NOT EXISTS gocql_query_keyspace_override_test WITH replication = { - 'class': 'SimpleStrategy', + 'class': 'SimpleStrategy', 'replication_factor': '1' }; ` diff --git a/cluster.go b/cluster.go index 06a271e95..f18f72978 100644 --- a/cluster.go +++ b/cluster.go @@ -274,6 +274,10 @@ type ClusterConfig struct { // default: 0.25. NextPagePrefetch float64 + // RegisteredTypes will be copied for all sessions created from this Cluster. + // If not provided, a copy of GlobalTypes will be used. + RegisteredTypes *RegisteredTypes + // internal config for testing disableControlConn bool } diff --git a/common_test.go b/common_test.go index 5e4f0cc0b..0ef0acad9 100644 --- a/common_test.go +++ b/common_test.go @@ -245,7 +245,7 @@ func createFunctions(t *testing.T, session *Session) { CALLED ON NULL INPUT RETURNS double LANGUAGE java AS - $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ + $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ `).Exec(); err != nil { t.Fatalf("failed to create function with err: %v", err) } @@ -297,20 +297,20 @@ func randomText(size int) string { func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { - t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) + t.Fatalf("expected %s to be (%#v) but was (%#v) instead", description, expected, actual) } } func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if !reflect.DeepEqual(expected, actual) { - t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) + t.Fatalf("expected %s to be (%#v) but was (%#v) instead", description, expected, actual) } } func assertNil(t *testing.T, description string, actual interface{}) { t.Helper() if actual != nil { - t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual) + t.Fatalf("expected %s to be (nil) but was (%#v) instead", description, actual) } } diff --git a/conn.go b/conn.go index b7a3ed27d..c9efe723b 100644 --- a/conn.go +++ b/conn.go @@ -249,6 +249,14 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * writeTimeout = cfg.WriteTimeout } + logger := cfg.Logger + if logger == nil { + logger = s.logger + if logger == nil { + logger = &defaultLogger{} + } + } + ctx, cancel := context.WithCancel(ctx) c := &Conn{ r: &connReader{ @@ -274,7 +282,7 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg * }, ctx: ctx, cancel: cancel, - logger: cfg.logger(), + logger: logger, streamObserver: s.streamObserver, writeTimeout: writeTimeout, } @@ -674,7 +682,7 @@ func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently - framer := newFramer(c.compressor, c.version) + framer := newFramer(c.compressor, c.version, c.session.types) if err := framer.readFrame(r, &head); err != nil { return err } @@ -683,7 +691,7 @@ func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { } else if head.stream <= 0 { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. - framer := newFramer(c.compressor, c.version) + framer := newFramer(c.compressor, c.version, c.session.types) if err := framer.readFrame(r, &head); err != nil { return err } @@ -713,7 +721,7 @@ func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } - framer := newFramer(c.compressor, c.version) + framer := newFramer(c.compressor, c.version, c.session.types) err = framer.readFrame(r, &head) if err != nil { @@ -1188,7 +1196,7 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer } // resp is basically a waiting semaphore protecting the framer - framer := newFramer(c.compressor, c.version) + framer := newFramer(c.compressor, c.version, c.session.types) call := &callReq{ timeout: make(chan struct{}), diff --git a/conn_test.go b/conn_test.go index cb76d90c1..4b7ed732d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -701,7 +701,7 @@ func TestStream0(t *testing.T) { const expErr = "gocql: received unexpected frame on stream 0" var buf bytes.Buffer - f := newFramer(nil, protoVersion4) + f := newFramer(nil, protoVersion4, GlobalTypes) f.writeHeader(0, opResult, 0) f.writeInt(resultKindVoid) f.buf[0] |= 0x80 @@ -717,7 +717,10 @@ func TestStream0(t *testing.T) { r: bufio.NewReader(&buf), }, streams: streams.New(protoVersion4), - logger: &defaultLogger{}, + session: &Session{ + types: GlobalTypes, + }, + logger: &defaultLogger{}, } err := conn.recv(context.Background(), false) @@ -1210,7 +1213,7 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { srv.errorLocked("process frame with a nil header") return } - respFrame := newFramer(nil, reqFrame.proto) + respFrame := newFramer(nil, reqFrame.proto, GlobalTypes) switch head.op { case opStartup: @@ -1227,7 +1230,11 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { respFrame.writeHeader(0, opSupported, head.stream) respFrame.writeShort(0) case opQuery: - query := reqFrame.readLongString() + query, err := reqFrame.readLongString() + if err != nil { + srv.errorLocked(err) + return + } first := query if n := strings.Index(query, " "); n > 0 { first = first[:n] @@ -1282,7 +1289,11 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { respFrame.writeHeader(0, opError, head.stream) respFrame.buf = append(respFrame.buf, reqFrame.buf...) case opPrepare: - query := reqFrame.readLongString() + query, err := reqFrame.readLongString() + if err != nil { + srv.errorLocked(err) + return + } name := strings.TrimPrefix(query, "select ") if n := strings.Index(name, " "); n > 0 { name = name[:n] @@ -1328,16 +1339,29 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { respFrame.writeString("unsupported query: " + name) } case opExecute: - b := reqFrame.readShortBytes() + b, err := reqFrame.readShortBytes() + if err != nil { + srv.errorLocked(err) + return + } id := binary.BigEndian.Uint64(b) // reqFrame.readConsistency() // var flags uint32 if srv.protocol > protoVersion4 { - ui := reqFrame.readInt() + ui, err := reqFrame.readInt() + if err != nil { + srv.errorLocked(err) + return + } flags = uint32(ui) } else { - flags = uint32(reqFrame.readByte()) + b, err := reqFrame.readByte() + if err != nil { + srv.errorLocked(err) + return + } + flags = uint32(b) } switch id { case 1: @@ -1397,7 +1421,7 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { if err != nil { return nil, err } - framer := newFramer(nil, srv.protocol) + framer := newFramer(nil, srv.protocol, GlobalTypes) err = framer.readFrame(conn, &head) if err != nil { @@ -1435,6 +1459,8 @@ func TestConnProcessAllFramesInSingleSegment(t *testing.T) { quit: make(chan struct{}), }, writeTimeout: time.Second * 10, + session: &Session{types: GlobalTypes}, + logger: &defaultLogger{}, } call1 := &callReq{ @@ -1460,11 +1486,11 @@ func TestConnProcessAllFramesInSingleSegment(t *testing.T) { }, } - framer1 := newFramer(nil, protoVersion5) + framer1 := newFramer(nil, protoVersion5, GlobalTypes) err = req.buildFrame(framer1, 1) require.NoError(t, err) - framer2 := newFramer(nil, protoVersion5) + framer2 := newFramer(nil, protoVersion5, GlobalTypes) err = req.buildFrame(framer2, 2) require.NoError(t, err) diff --git a/frame.go b/frame.go index f2f6c7cd8..a48279d3e 100644 --- a/frame.go +++ b/frame.go @@ -33,8 +33,7 @@ import ( "io" "io/ioutil" "net" - "runtime" - "strconv" + "reflect" "strings" "time" ) @@ -379,13 +378,16 @@ type framer struct { buf []byte customPayload map[string][]byte + + types *RegisteredTypes } -func newFramer(compressor Compressor, version byte) *framer { +func newFramer(compressor Compressor, version byte, r *RegisteredTypes) *framer { buf := make([]byte, defaultBufSize) f := &framer{ buf: buf[:0], readBuffer: buf, + types: r, } var flags byte if compressor != nil { @@ -494,16 +496,7 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error { return nil } -func (f *framer) parseFrame() (frame frame, err error) { - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - err = r.(error) - } - }() - +func (f *framer) parseFrame() (frame, error) { if f.header.version.request() { return nil, NewErrProtocol("got a request frame from server: %v", f.header.version) } @@ -512,42 +505,53 @@ func (f *framer) parseFrame() (frame frame, err error) { f.readTrace() } + var err error if f.header.flags&flagWarning == flagWarning { - f.header.warnings = f.readStringList() + f.header.warnings, err = f.readStringList() + if err != nil { + return nil, err + } } if f.header.flags&flagCustomPayload == flagCustomPayload { - f.customPayload = f.readBytesMap() + f.customPayload, err = f.readBytesMap() + if err != nil { + return nil, err + } } // assumes that the frame body has been read into rbuf switch f.header.op { case opError: - frame = f.parseErrorFrame() + return f.parseErrorFrame() case opReady: - frame = f.parseReadyFrame() + return f.parseReadyFrame() case opResult: - frame, err = f.parseResultFrame() + return f.parseResultFrame() case opSupported: - frame = f.parseSupportedFrame() + return f.parseSupportedFrame() case opAuthenticate: - frame = f.parseAuthenticateFrame() + return f.parseAuthenticateFrame() case opAuthChallenge: - frame = f.parseAuthChallengeFrame() + return f.parseAuthChallengeFrame() case opAuthSuccess: - frame = f.parseAuthSuccessFrame() + return f.parseAuthSuccessFrame() case opEvent: - frame = f.parseEventFrame() + return f.parseEventFrame() default: return nil, NewErrProtocol("unknown op in frame header: %s", f.header.op) } - - return } -func (f *framer) parseErrorFrame() frame { - code := f.readInt() - msg := f.readString() +func (f *framer) parseErrorFrame() (frame, error) { + code, err := f.readInt() + if err != nil { + return nil, err + } + msg, err := f.readString() + if err != nil { + return nil, err + } errD := errorFrame{ frameHeader: *f.header, @@ -557,123 +561,228 @@ func (f *framer) parseErrorFrame() frame { switch code { case ErrCodeUnavailable: - cl := f.readConsistency() - required := f.readInt() - alive := f.readInt() + cl, err := f.readConsistency() + if err != nil { + return nil, err + } + required, err := f.readInt() + if err != nil { + return nil, err + } + alive, err := f.readInt() + if err != nil { + return nil, err + } return &RequestErrUnavailable{ errorFrame: errD, Consistency: cl, Required: required, Alive: alive, - } + }, nil case ErrCodeWriteTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - writeType := f.readString() + cl, err := f.readConsistency() + if err != nil { + return nil, err + } + received, err := f.readInt() + if err != nil { + return nil, err + } + blockfor, err := f.readInt() + if err != nil { + return nil, err + } + writeType, err := f.readString() + if err != nil { + return nil, err + } return &RequestErrWriteTimeout{ errorFrame: errD, Consistency: cl, Received: received, BlockFor: blockfor, WriteType: writeType, - } + }, nil case ErrCodeReadTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - dataPresent := f.readByte() + cl, err := f.readConsistency() + if err != nil { + return nil, err + } + received, err := f.readInt() + if err != nil { + return nil, err + } + blockfor, err := f.readInt() + if err != nil { + return nil, err + } + dataPresent, err := f.readByte() + if err != nil { + return nil, err + } return &RequestErrReadTimeout{ errorFrame: errD, Consistency: cl, Received: received, BlockFor: blockfor, DataPresent: dataPresent, - } + }, nil case ErrCodeAlreadyExists: - ks := f.readString() - table := f.readString() + ks, err := f.readString() + if err != nil { + return nil, err + } + table, err := f.readString() + if err != nil { + return nil, err + } return &RequestErrAlreadyExists{ errorFrame: errD, Keyspace: ks, Table: table, - } + }, nil case ErrCodeUnprepared: - stmtId := f.readShortBytes() + stmtId, err := f.readShortBytes() + if err != nil { + return nil, err + } return &RequestErrUnprepared{ errorFrame: errD, StatementId: copyBytes(stmtId), // defensively copy - } + }, nil case ErrCodeReadFailure: res := &RequestErrReadFailure{ errorFrame: errD, } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() + res.Consistency, err = f.readConsistency() + if err != nil { + return nil, err + } + res.Received, err = f.readInt() + if err != nil { + return nil, err + } + res.BlockFor, err = f.readInt() + if err != nil { + return nil, err + } if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() + res.ErrorMap, err = f.readErrorMap() + if err != nil { + return nil, err + } res.NumFailures = len(res.ErrorMap) } else { - res.NumFailures = f.readInt() + res.NumFailures, err = f.readInt() + if err != nil { + return nil, err + } } - res.DataPresent = f.readByte() != 0 + b, err := f.readByte() + if err != nil { + return nil, err + } + res.DataPresent = b != 0 - return res + return res, nil case ErrCodeWriteFailure: res := &RequestErrWriteFailure{ errorFrame: errD, } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() + res.Consistency, err = f.readConsistency() + if err != nil { + return nil, err + } + res.Received, err = f.readInt() + if err != nil { + return nil, err + } + res.BlockFor, err = f.readInt() + if err != nil { + return nil, err + } if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() + res.ErrorMap, err = f.readErrorMap() + if err != nil { + return nil, err + } res.NumFailures = len(res.ErrorMap) } else { - res.NumFailures = f.readInt() + res.NumFailures, err = f.readInt() + if err != nil { + return nil, err + } + } + res.WriteType, err = f.readString() + if err != nil { + return nil, err } - res.WriteType = f.readString() - return res + return res, nil case ErrCodeFunctionFailure: res := &RequestErrFunctionFailure{ errorFrame: errD, } - res.Keyspace = f.readString() - res.Function = f.readString() - res.ArgTypes = f.readStringList() - return res + res.Keyspace, err = f.readString() + if err != nil { + return nil, err + } + res.Function, err = f.readString() + if err != nil { + return nil, err + } + res.ArgTypes, err = f.readStringList() + if err != nil { + return nil, err + } + return res, nil case ErrCodeCDCWriteFailure: - res := &RequestErrCDCWriteFailure{ + return &RequestErrCDCWriteFailure{ errorFrame: errD, - } - return res + }, nil case ErrCodeCASWriteUnknown: res := &RequestErrCASWriteUnknown{ errorFrame: errD, } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - return res + res.Consistency, err = f.readConsistency() + if err != nil { + return nil, err + } + res.Received, err = f.readInt() + if err != nil { + return nil, err + } + res.BlockFor, err = f.readInt() + if err != nil { + return nil, err + } + return res, nil case ErrCodeInvalid, ErrCodeBootstrapping, ErrCodeConfig, ErrCodeCredentials, ErrCodeOverloaded, ErrCodeProtocol, ErrCodeServer, ErrCodeSyntax, ErrCodeTruncate, ErrCodeUnauthorized: // TODO(zariel): we should have some distinct types for these errors - return errD + return errD, nil default: - panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) + return nil, fmt.Errorf("unknown error code: 0x%x", errD.code) } } -func (f *framer) readErrorMap() (errMap ErrorMap) { - errMap = make(ErrorMap) - numErrs := f.readInt() +func (f *framer) readErrorMap() (ErrorMap, error) { + numErrs, err := f.readInt() + if err != nil { + return nil, err + } + errMap := make(ErrorMap, numErrs) for i := 0; i < numErrs; i++ { - ip := f.readInetAdressOnly().String() - errMap[ip] = f.readShort() + ip, err := f.readInetAdressOnly() + if err != nil { + return nil, err + } + errMap[ip.String()], err = f.readShort() + if err != nil { + return nil, err + } } - return + return errMap, nil } func (f *framer) writeHeader(flags byte, op frameOp, stream int) { @@ -737,10 +846,10 @@ type readyFrame struct { frameHeader } -func (f *framer) parseReadyFrame() frame { +func (f *framer) parseReadyFrame() (frame, error) { return &readyFrame{ frameHeader: *f.header, - } + }, nil } type supportedFrame struct { @@ -751,12 +860,15 @@ type supportedFrame struct { // TODO: if we move the body buffer onto the frameHeader then we only need a single // framer, and can move the methods onto the header. -func (f *framer) parseSupportedFrame() frame { +func (f *framer) parseSupportedFrame() (frame, error) { + s, err := f.readStringMultiMap() + if err != nil { + return nil, err + } return &supportedFrame{ frameHeader: *f.header, - - supported: f.readStringMultiMap(), - } + supported: s, + }, nil } type writeStartupFrame struct { @@ -806,84 +918,128 @@ func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { return f.finish() } -func (f *framer) readTypeInfo() TypeInfo { - // TODO: factor this out so the same code paths can be used to parse custom - // types and other types, as much of the logic will be duplicated. - id := f.readShort() - - simple := NativeType{ - proto: f.proto, - typ: Type(id), - } - - if simple.typ == TypeCustom { - simple.custom = f.readString() - if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom { - simple.typ = cassType +func (f *framer) readParam(param interface{}) (interface{}, error) { + switch p := param.(type) { + case string: + return f.readString() + case uint16: + return f.readShort() + case byte: + return f.readByte() + case []byte: + return f.readShortBytes() + case int: + return f.readInt() + case []string: + return f.readStringList() + case []UDTField: + n, err := f.readShort() + if err != nil { + return nil, err } - } - - switch simple.typ { - case TypeTuple: - n := f.readShort() - tuple := TupleTypeInfo{ - NativeType: simple, - Elems: make([]TypeInfo, n), + if len(p) < int(n) { + p = make([]UDTField, n) + } else { + p = p[:n] } - for i := 0; i < int(n); i++ { - tuple.Elems[i] = f.readTypeInfo() + p[i].Name, err = f.readString() + if err != nil { + return nil, err + } + p[i].Type, err = f.readTypeInfo() + if err != nil { + return nil, err + } } - - return tuple - - case TypeUDT: - udt := UDTTypeInfo{ - NativeType: simple, + return p, nil + case TypeInfo: + return f.readTypeInfo() + case *TypeInfo: + return f.readTypeInfo() + case []TypeInfo: + n, err := f.readShort() + if err != nil { + return nil, err + } + if len(p) < int(n) { + p = make([]TypeInfo, n) + } else { + p = p[:n] } - udt.KeySpace = f.readString() - udt.Name = f.readString() - - n := f.readShort() - udt.Elements = make([]UDTField, n) for i := 0; i < int(n); i++ { - field := &udt.Elements[i] - field.Name = f.readString() - field.Type = f.readTypeInfo() + p[i], err = f.readTypeInfo() + if err != nil { + return nil, err + } } - - return udt - case TypeMap, TypeList, TypeSet: - collection := CollectionType{ - NativeType: simple, + return p, nil + case Type: + // Type is actually an int but it's encoded as short + s, err := f.readShort() + if err != nil { + return nil, err } - - if simple.typ == TypeMap { - collection.Key = f.readTypeInfo() + return Type(s), nil + case []interface{}: + n, err := f.readShort() + if err != nil { + return nil, err + } + if len(p) != int(n) { + return nil, fmt.Errorf("wrong length for reading []interface{} from frame %d vs %d", len(p), n) } + for i := 0; i < int(n); i++ { + p[i], err = f.readParam(p[i]) + if err != nil { + return nil, err + } + } + return p, nil + } - collection.Elem = f.readTypeInfo() + // check if its a pointer + // we used to do some conversions in here for some types but that was risky + // since Type is an int but it is read with readShort so we stopped doing that + // out of caution and instead are just going to error + valueRef := reflect.ValueOf(param) + if valueRef.Kind() == reflect.Ptr && !valueRef.IsNil() { + return f.readParam(valueRef.Elem().Interface()) + } + return nil, fmt.Errorf("unsupported type for reading from frame: %T", param) +} - return collection - case TypeCustom: - if strings.HasPrefix(simple.custom, VECTOR_TYPE) { - spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) - spec = spec[1 : len(spec)-1] // remove parenthesis - idx := strings.LastIndex(spec, ",") - typeStr := spec[:idx] - dimStr := spec[idx+1:] - subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{}) - dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) - vector := VectorType{ - NativeType: simple, - SubType: subType, - Dimensions: dim, - } - return vector +func (f *framer) readTypeInfo() (TypeInfo, error) { + i, err := f.readShort() + if err != nil { + return nil, err + } + typ := Type(i) + if typ == TypeCustom { + name, err := f.readString() + if err != nil { + return nil, err } + return f.types.typeInfoFromString(int(f.proto), name) + } + + if ti := f.types.fastTypeInfoLookup(typ); ti != nil { + return ti, nil } - return simple + cqlt := f.types.fastRegisteredTypeLookup(typ) + if cqlt == nil { + return nil, unknownTypeError(fmt.Sprintf("%d", typ)) + } + + params := cqlt.Params(int(f.proto)) + for i := range params { + params[i], err = f.readParam(params[i]) + if err != nil { + return nil, err + } + } + return cqlt.TypeInfoFromParams(int(f.proto), params) } type preparedMetadata struct { @@ -901,38 +1057,62 @@ func (r preparedMetadata) String() string { return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.flags, r.pkeyColumns, r.pagingState, r.columns, r.colCount, r.actualColCount) } -func (f *framer) parsePreparedMetadata() preparedMetadata { +func (f *framer) parsePreparedMetadata() (preparedMetadata, error) { // TODO: deduplicate this from parseMetadata meta := preparedMetadata{} - meta.flags = f.readInt() - meta.colCount = f.readInt() + var err error + meta.flags, err = f.readInt() + if err != nil { + return preparedMetadata{}, err + } + meta.colCount, err = f.readInt() + if err != nil { + return preparedMetadata{}, err + } if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) + return preparedMetadata{}, fmt.Errorf("received negative column count: %d", meta.colCount) } meta.actualColCount = meta.colCount if f.proto >= protoVersion4 { - pkeyCount := f.readInt() + pkeyCount, err := f.readInt() + if err != nil { + return preparedMetadata{}, err + } pkeys := make([]int, pkeyCount) for i := 0; i < pkeyCount; i++ { - pkeys[i] = int(f.readShort()) + c, err := f.readShort() + if err != nil { + return preparedMetadata{}, err + } + pkeys[i] = int(c) } meta.pkeyColumns = pkeys } if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + b, err := f.readBytes() + if err != nil { + return preparedMetadata{}, err + } + meta.pagingState = copyBytes(b) } if meta.flags&flagNoMetaData == flagNoMetaData { - return meta + return meta, nil } globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec if globalSpec { - meta.keyspace = f.readString() - meta.table = f.readString() + meta.keyspace, err = f.readString() + if err != nil { + return preparedMetadata{}, err + } + meta.table, err = f.readString() + if err != nil { + return preparedMetadata{}, err + } } var cols []ColumnInfo @@ -940,21 +1120,27 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { // preallocate columninfo to avoid excess copying cols = make([]ColumnInfo, meta.colCount) for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + err = f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + if err != nil { + return preparedMetadata{}, err + } } } else { // use append, huge number of columns usually indicates a corrupt frame or // just a huge row. for i := 0; i < meta.colCount; i++ { var col ColumnInfo - f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + err = f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + if err != nil { + return preparedMetadata{}, err + } cols = append(cols, col) } } meta.columns = cols - return meta + return meta, nil } type resultMetadata struct { @@ -986,52 +1172,86 @@ func (r resultMetadata) String() string { return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v new_metadata_id=% X]", r.flags, r.pagingState, r.columns, r.newMetadataID) } -func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { +func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) error { + var err error if !globalSpec { - col.Keyspace = f.readString() - col.Table = f.readString() + col.Keyspace, err = f.readString() + if err != nil { + return err + } + col.Table, err = f.readString() + if err != nil { + return err + } } else { col.Keyspace = keyspace col.Table = table } - col.Name = f.readString() - col.TypeInfo = f.readTypeInfo() - switch v := col.TypeInfo.(type) { + col.Name, err = f.readString() + if err != nil { + return err + } + col.TypeInfo, err = f.readTypeInfo() + if err != nil { + return err + } // maybe also UDT - case TupleTypeInfo: + if t, ok := col.TypeInfo.(TupleTypeInfo); ok { // -1 because we already included the tuple column - meta.actualColCount += len(v.Elems) - 1 + meta.actualColCount += len(t.Elems) - 1 } + return nil } -func (f *framer) parseResultMetadata() resultMetadata { +func (f *framer) parseResultMetadata() (resultMetadata, error) { var meta resultMetadata - meta.flags = f.readInt() - meta.colCount = f.readInt() + var err error + meta.flags, err = f.readInt() + if err != nil { + return resultMetadata{}, err + } + meta.colCount, err = f.readInt() + if err != nil { + return resultMetadata{}, err + } if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) + return resultMetadata{}, fmt.Errorf("received negative column count: %d", meta.colCount) } meta.actualColCount = meta.colCount if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + b, err := f.readBytes() + if err != nil { + return resultMetadata{}, err + } + meta.pagingState = copyBytes(b) } if f.proto > protoVersion4 && meta.flags&flagMetaDataChanged == flagMetaDataChanged { - meta.newMetadataID = copyBytes(f.readShortBytes()) + b, err := f.readShortBytes() + if err != nil { + return resultMetadata{}, err + } + meta.newMetadataID = copyBytes(b) } if meta.noMetaData() { - return meta + return meta, nil } var keyspace, table string globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec if globalSpec { - keyspace = f.readString() - table = f.readString() + keyspace, err = f.readString() + if err != nil { + return resultMetadata{}, err + } + table, err = f.readString() + if err != nil { + return resultMetadata{}, err + } } var cols []ColumnInfo @@ -1039,7 +1259,10 @@ func (f *framer) parseResultMetadata() resultMetadata { // preallocate columninfo to avoid excess copying cols = make([]ColumnInfo, meta.colCount) for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta, globalSpec, keyspace, table) + err = f.readCol(&cols[i], &meta, globalSpec, keyspace, table) + if err != nil { + return resultMetadata{}, err + } } } else { @@ -1047,14 +1270,17 @@ func (f *framer) parseResultMetadata() resultMetadata { // just a huge row. for i := 0; i < meta.colCount; i++ { var col ColumnInfo - f.readCol(&col, &meta, globalSpec, keyspace, table) + err = f.readCol(&col, &meta, globalSpec, keyspace, table) + if err != nil { + return resultMetadata{}, err + } cols = append(cols, col) } } meta.columns = cols - return meta + return meta, nil } type resultVoidFrame struct { @@ -1066,19 +1292,22 @@ func (f *resultVoidFrame) String() string { } func (f *framer) parseResultFrame() (frame, error) { - kind := f.readInt() + kind, err := f.readInt() + if err != nil { + return nil, err + } switch kind { case resultKindVoid: return &resultVoidFrame{frameHeader: *f.header}, nil case resultKindRows: - return f.parseResultRows(), nil + return f.parseResultRows() case resultKindKeyspace: - return f.parseResultSetKeyspace(), nil + return f.parseResultSetKeyspace() case resultKindPrepared: - return f.parseResultPrepared(), nil + return f.parseResultPrepared() case resultKindSchemaChanged: - return f.parseResultSchemaChange(), nil + return f.parseResultSchemaChange() } return nil, NewErrProtocol("unknown result kind: %x", kind) @@ -1096,16 +1325,23 @@ func (f *resultRowsFrame) String() string { return fmt.Sprintf("[result_rows meta=%v]", f.meta) } -func (f *framer) parseResultRows() frame { +func (f *framer) parseResultRows() (frame, error) { result := &resultRowsFrame{} - result.meta = f.parseResultMetadata() + var err error + result.meta, err = f.parseResultMetadata() + if err != nil { + return nil, err + } - result.numRows = f.readInt() + result.numRows, err = f.readInt() + if err != nil { + return nil, err + } if result.numRows < 0 { - panic(fmt.Errorf("invalid row_count in result frame: %d", result.numRows)) + return nil, fmt.Errorf("invalid row_count in result frame: %d", result.numRows) } - return result + return result, nil } type resultKeyspaceFrame struct { @@ -1117,11 +1353,15 @@ func (r *resultKeyspaceFrame) String() string { return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) } -func (f *framer) parseResultSetKeyspace() frame { +func (f *framer) parseResultSetKeyspace() (frame, error) { + k, err := f.readString() + if err != nil { + return nil, err + } return &resultKeyspaceFrame{ frameHeader: *f.header, - keyspace: f.readString(), - } + keyspace: k, + }, nil } type resultPreparedFrame struct { @@ -1133,20 +1373,34 @@ type resultPreparedFrame struct { respMeta resultMetadata } -func (f *framer) parseResultPrepared() frame { +func (f *framer) parseResultPrepared() (frame, error) { + b, err := f.readShortBytes() + if err != nil { + return nil, err + } frame := &resultPreparedFrame{ frameHeader: *f.header, - preparedID: f.readShortBytes(), + preparedID: b, } if f.proto > protoVersion4 { - frame.resultMetadataID = copyBytes(f.readShortBytes()) + b, err = f.readShortBytes() + if err != nil { + return nil, err + } + frame.resultMetadataID = copyBytes(b) } - frame.reqMeta = f.parsePreparedMetadata() - frame.respMeta = f.parseResultMetadata() + frame.reqMeta, err = f.parsePreparedMetadata() + if err != nil { + return nil, err + } + frame.respMeta, err = f.parseResultMetadata() + if err != nil { + return nil, err + } - return frame + return frame, nil } type schemaChangeKeyspace struct { @@ -1198,9 +1452,15 @@ type schemaChangeAggregate struct { args []string } -func (f *framer) parseResultSchemaChange() frame { - change := f.readString() - target := f.readString() +func (f *framer) parseResultSchemaChange() (frame, error) { + change, err := f.readString() + if err != nil { + return nil, err + } + target, err := f.readString() + if err != nil { + return nil, err + } // TODO: could just use a separate type for each target switch target { @@ -1210,53 +1470,86 @@ func (f *framer) parseResultSchemaChange() frame { change: change, } - frame.keyspace = f.readString() + frame.keyspace, err = f.readString() + if err != nil { + return nil, err + } - return frame + return frame, err case "TABLE": frame := &schemaChangeTable{ frameHeader: *f.header, change: change, } - frame.keyspace = f.readString() - frame.object = f.readString() + frame.keyspace, err = f.readString() + if err != nil { + return nil, err + } + frame.object, err = f.readString() + if err != nil { + return nil, err + } - return frame + return frame, err case "TYPE": frame := &schemaChangeType{ frameHeader: *f.header, change: change, } - frame.keyspace = f.readString() - frame.object = f.readString() + frame.keyspace, err = f.readString() + if err != nil { + return nil, err + } + frame.object, err = f.readString() + if err != nil { + return nil, err + } - return frame + return frame, nil case "FUNCTION": frame := &schemaChangeFunction{ frameHeader: *f.header, change: change, } - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() + frame.keyspace, err = f.readString() + if err != nil { + return nil, err + } + frame.name, err = f.readString() + if err != nil { + return nil, err + } + frame.args, err = f.readStringList() + if err != nil { + return nil, err + } - return frame + return frame, nil case "AGGREGATE": frame := &schemaChangeAggregate{ frameHeader: *f.header, change: change, } - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() + frame.keyspace, err = f.readString() + if err != nil { + return nil, err + } + frame.name, err = f.readString() + if err != nil { + return nil, err + } + frame.args, err = f.readStringList() + if err != nil { + return nil, err + } - return frame + return frame, nil default: - panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) + return nil, fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change) } } @@ -1270,11 +1563,15 @@ func (a *authenticateFrame) String() string { return fmt.Sprintf("[authenticate class=%q]", a.class) } -func (f *framer) parseAuthenticateFrame() frame { +func (f *framer) parseAuthenticateFrame() (frame, error) { + cls, err := f.readString() + if err != nil { + return nil, err + } return &authenticateFrame{ frameHeader: *f.header, - class: f.readString(), - } + class: cls, + }, nil } type authSuccessFrame struct { @@ -1287,11 +1584,15 @@ func (a *authSuccessFrame) String() string { return fmt.Sprintf("[auth_success data=%q]", a.data) } -func (f *framer) parseAuthSuccessFrame() frame { +func (f *framer) parseAuthSuccessFrame() (frame, error) { + b, err := f.readBytes() + if err != nil { + return nil, err + } return &authSuccessFrame{ frameHeader: *f.header, - data: f.readBytes(), - } + data: b, + }, nil } type authChallengeFrame struct { @@ -1304,11 +1605,15 @@ func (a *authChallengeFrame) String() string { return fmt.Sprintf("[auth_challenge data=%q]", a.data) } -func (f *framer) parseAuthChallengeFrame() frame { +func (f *framer) parseAuthChallengeFrame() (frame, error) { + b, err := f.readBytes() + if err != nil { + return nil, err + } return &authChallengeFrame{ frameHeader: *f.header, - data: f.readBytes(), - } + data: b, + }, nil } type statusChangeEventFrame struct { @@ -1336,22 +1641,37 @@ func (t topologyChangeEventFrame) String() string { return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) } -func (f *framer) parseEventFrame() frame { - eventType := f.readString() +func (f *framer) parseEventFrame() (frame, error) { + eventType, err := f.readString() + if err != nil { + return nil, err + } switch eventType { case "TOPOLOGY_CHANGE": frame := &topologyChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() + frame.change, err = f.readString() + if err != nil { + return nil, err + } + frame.host, frame.port, err = f.readInet() + if err != nil { + return nil, err + } - return frame + return frame, nil case "STATUS_CHANGE": frame := &statusChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() + frame.change, err = f.readString() + if err != nil { + return nil, err + } + frame.host, frame.port, err = f.readInet() + if err != nil { + return nil, err + } - return frame + return frame, nil case "SCHEMA_CHANGE": // this should work for all versions return f.parseResultSchemaChange() @@ -1734,72 +2054,87 @@ func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { return f.finish() } -func (f *framer) readByte() byte { +func (f *framer) readByte() (byte, error) { if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf))) + return 0, fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf)) } b := f.buf[0] f.buf = f.buf[1:] - return b + return b, nil } -func (f *framer) readInt() (n int) { +func (f *framer) readInt() (int, error) { if len(f.buf) < 4 { - panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf))) + return 0, fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf)) } - n = int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) + n := int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) f.buf = f.buf[4:] - return + return n, nil } -func (f *framer) readShort() (n uint16) { +func (f *framer) readShort() (uint16, error) { if len(f.buf) < 2 { - panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf))) + return 0, fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf)) } - n = uint16(f.buf[0])<<8 | uint16(f.buf[1]) + n := uint16(f.buf[0])<<8 | uint16(f.buf[1]) f.buf = f.buf[2:] - return + return n, nil } -func (f *framer) readString() (s string) { - size := f.readShort() +func (f *framer) readString() (string, error) { + size, err := f.readShort() + if err != nil { + return "", err + } if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf))) + return "", fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf)) } - s = string(f.buf[:size]) + s := string(f.buf[:size]) f.buf = f.buf[size:] - return + return s, nil } -func (f *framer) readLongString() (s string) { - size := f.readInt() +func (f *framer) readLongString() (string, error) { + size, err := f.readInt() + if err != nil { + return "", err + } if len(f.buf) < size { - panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf))) + return "", fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf)) } - s = string(f.buf[:size]) + s := string(f.buf[:size]) f.buf = f.buf[size:] - return + return s, err } -func (f *framer) readStringList() []string { - size := f.readShort() +func (f *framer) readStringList() ([]string, error) { + size, err := f.readShort() + if err != nil { + return nil, err + } l := make([]string, size) for i := 0; i < int(size); i++ { - l[i] = f.readString() + l[i], err = f.readString() + if err != nil { + return nil, err + } } - return l + return l, nil } -func (f *framer) readBytesInternal() ([]byte, error) { - size := f.readInt() +func (f *framer) readBytes() ([]byte, error) { + size, err := f.readInt() + if err != nil { + return nil, err + } if size < 0 { return nil, nil } @@ -1814,81 +2149,97 @@ func (f *framer) readBytesInternal() ([]byte, error) { return l, nil } -func (f *framer) readBytes() []byte { - l, err := f.readBytesInternal() +func (f *framer) readShortBytes() ([]byte, error) { + size, err := f.readShort() if err != nil { - panic(err) + return nil, err } - - return l -} - -func (f *framer) readShortBytes() []byte { - size := f.readShort() if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf))) + return nil, fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf)) } - l := f.buf[:size] + b := f.buf[:size] f.buf = f.buf[size:] - - return l + return b, nil } -func (f *framer) readInetAdressOnly() net.IP { +func (f *framer) readInetAdressOnly() (net.IP, error) { if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf))) + return nil, fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf)) } size := f.buf[0] f.buf = f.buf[1:] - if !(size == 4 || size == 16) { - panic(fmt.Errorf("invalid IP size: %d", size)) + return nil, fmt.Errorf("invalid IP size: %d", size) } if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf))) + return nil, fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf)) } ip := make([]byte, size) copy(ip, f.buf[:size]) f.buf = f.buf[size:] - return net.IP(ip) + // TODO: should we check if IP is nil? + return net.IP(ip), nil } -func (f *framer) readInet() (net.IP, int) { - return f.readInetAdressOnly(), f.readInt() +func (f *framer) readInet() (net.IP, int, error) { + ip, err := f.readInetAdressOnly() + if err != nil { + return nil, 0, err + } + port, err := f.readShort() + if err != nil { + return nil, 0, err + } + return ip, int(port), nil } -func (f *framer) readConsistency() Consistency { - return Consistency(f.readShort()) +func (f *framer) readConsistency() (Consistency, error) { + c, err := f.readShort() + if err != nil { + return 0, err + } + return Consistency(c), err } -func (f *framer) readBytesMap() map[string][]byte { - size := f.readShort() +func (f *framer) readBytesMap() (map[string][]byte, error) { + size, err := f.readShort() + if err != nil { + return nil, err + } m := make(map[string][]byte, size) - + var k string for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readBytes() - m[k] = v + k, err = f.readString() + if err != nil { + return nil, err + } + m[k], err = f.readBytes() + if err != nil { + return nil, err + } } - - return m + return m, nil } -func (f *framer) readStringMultiMap() map[string][]string { - size := f.readShort() +func (f *framer) readStringMultiMap() (map[string][]string, error) { + size, err := f.readShort() m := make(map[string][]string, size) - + var k string for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readStringList() - m[k] = v + k, err = f.readString() + if err != nil { + return nil, err + } + m[k], err = f.readStringList() + if err != nil { + return nil, err + } } - - return m + return m, nil } func (f *framer) writeByte(b byte) { diff --git a/frame_test.go b/frame_test.go index 6bf49222c..e2270d2c2 100644 --- a/frame_test.go +++ b/frame_test.go @@ -28,6 +28,7 @@ import ( "bytes" "errors" "os" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -61,15 +62,13 @@ func TestFuzzBugs(t *testing.T) { } for i, test := range tests { - t.Logf("test %d input: %q", i, test) - r := bytes.NewReader(test) head, err := readHeader(r, make([]byte, 9)) if err != nil { continue } - framer := newFramer(nil, byte(head.version)) + framer := newFramer(nil, byte(head.version), GlobalTypes) err = framer.readFrame(r, &head) if err != nil { continue @@ -90,7 +89,7 @@ func TestFrameWriteTooLong(t *testing.T) { t.Skip("skipping test in travis due to memory pressure with the race detecor") } - framer := newFramer(nil, 2) + framer := newFramer(nil, 2, GlobalTypes) framer.writeHeader(0, opStartup, 1) framer.writeBytes(make([]byte, maxFrameSize+1)) @@ -110,7 +109,7 @@ func TestFrameReadTooLong(t *testing.T) { // write a new header right after this frame to verify that we can read it r.Write([]byte{protoVersionMask & protoVersion3, 0x00, 0x00, 0x00, byte(opReady), 0x00, 0x00, 0x00, 0x00}) - framer := newFramer(nil, 3) + framer := newFramer(nil, 3, GlobalTypes) head := frameHeader{ version: protoVersion3, @@ -133,7 +132,7 @@ func TestFrameReadTooLong(t *testing.T) { } func Test_framer_writeExecuteFrame(t *testing.T) { - framer := newFramer(nil, protoVersion5) + framer := newFramer(nil, protoVersion5, GlobalTypes) nowInSeconds := 123 frame := writeExecuteFrame{ preparedID: []byte{1, 2, 3}, @@ -155,12 +154,31 @@ func Test_framer_writeExecuteFrame(t *testing.T) { // skipping header framer.buf = framer.buf[9:] - assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) - assertDeepEqual(t, "preparedID", frame.preparedID, framer.readShortBytes()) - assertDeepEqual(t, "resultMetadataID", frame.resultMetadataID, framer.readShortBytes()) - assertDeepEqual(t, "constistency", frame.params.consistency, Consistency(framer.readShort())) + bm, err := framer.readBytesMap() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "customPayload", frame.customPayload, bm) + b, err := framer.readShortBytes() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "preparedID", frame.preparedID, b) + b, err = framer.readShortBytes() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "resultMetadataID", frame.resultMetadataID, b) + c, err := framer.readConsistency() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "constistency", frame.params.consistency, c) - flags := framer.readInt() + flags, err := framer.readInt() + if err != nil { + t.Fatal(err) + } if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { t.Fatal("expected flagNowInSeconds to be set, but it is not") } @@ -169,12 +187,20 @@ func Test_framer_writeExecuteFrame(t *testing.T) { t.Fatal("expected flagWithKeyspace to be set, but it is not") } - assertDeepEqual(t, "keyspace", frame.params.keyspace, framer.readString()) - assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt()) + k, err := framer.readString() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "keyspace", frame.params.keyspace, k) + secs, err := framer.readInt() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "nowInSeconds", nowInSeconds, secs) } func Test_framer_writeBatchFrame(t *testing.T) { - framer := newFramer(nil, protoVersion5) + framer := newFramer(nil, protoVersion5, GlobalTypes) nowInSeconds := 123 frame := writeBatchFrame{ customPayload: map[string][]byte{ @@ -191,17 +217,40 @@ func Test_framer_writeBatchFrame(t *testing.T) { // skipping header framer.buf = framer.buf[9:] - assertDeepEqual(t, "customPayload", frame.customPayload, framer.readBytesMap()) - assertDeepEqual(t, "typ", frame.typ, BatchType(framer.readByte())) - assertDeepEqual(t, "len(statements)", len(frame.statements), int(framer.readShort())) - assertDeepEqual(t, "consistency", frame.consistency, Consistency(framer.readShort())) + bm, err := framer.readBytesMap() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "customPayload", frame.customPayload, bm) + b, err := framer.readByte() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "typ", frame.typ, BatchType(b)) + l, err := framer.readShort() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "len(statements)", len(frame.statements), int(l)) + c, err := framer.readConsistency() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "consistency", frame.consistency, c) - flags := framer.readInt() + flags, err := framer.readInt() + if err != nil { + t.Fatal(err) + } if flags&int(flagWithNowInSeconds) != int(flagWithNowInSeconds) { t.Fatal("expected flagNowInSeconds to be set, but it is not") } - assertDeepEqual(t, "nowInSeconds", nowInSeconds, framer.readInt()) + secs, err := framer.readInt() + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "nowInSeconds", nowInSeconds, secs) } type testMockedCompressor struct { @@ -295,7 +344,7 @@ func Test_readUncompressedFrame(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - framer := newFramer(nil, protoVersion5) + framer := newFramer(nil, protoVersion5, GlobalTypes) req := writeQueryFrame{ statement: "SELECT * FROM system.local", params: queryParams{ @@ -406,7 +455,7 @@ func Test_readCompressedFrame(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - framer := newFramer(nil, protoVersion5) + framer := newFramer(nil, protoVersion5, GlobalTypes) req := writeQueryFrame{ statement: "SELECT * FROM system.local", params: queryParams{ @@ -441,3 +490,346 @@ func Test_readCompressedFrame(t *testing.T) { }) } } + +func TestFrameReadParam(t *testing.T) { + testCases := []struct { + Write func(*framer) + Param interface{} + Exp interface{} + }{ + { + Write: func(f *framer) { + f.writeString("foo") + }, + Param: "", + Exp: "foo", + }, + { + Write: func(f *framer) { + f.writeShort(2) + f.writeString("foo") + f.writeString("bar") + }, + Param: []interface{}{"", ""}, + Exp: []interface{}{"foo", "bar"}, + }, + { + Write: func(f *framer) { + f.writeShort(uint16(TypeBoolean)) + }, + Param: (*TypeInfo)(nil), + Exp: booleanTypeInfo{}, + }, + { + Write: func(f *framer) { + f.writeInt(5) + }, + Param: int(0), + Exp: int(5), + }, + { + Write: func(f *framer) { + f.writeInt(5) + }, + Param: new(int), + Exp: int(5), + }, + { + Write: func(f *framer) { + f.writeShort(5) + }, + Param: uint16(0), + Exp: uint16(5), + }, + { + Write: func(f *framer) { + f.writeByte(10) + }, + Param: byte(0), + Exp: byte(10), + }, + { + Write: func(f *framer) { + f.writeShort(uint16(TypeBoolean)) + }, + Param: Type(0), + Exp: TypeBoolean, + }, + { + Write: func(f *framer) { + f.writeShort(2) + f.writeShort(uint16(TypeBoolean)) + f.writeShort(uint16(TypeBoolean)) + }, + Param: []TypeInfo{}, + Exp: []TypeInfo{booleanTypeInfo{}, booleanTypeInfo{}}, + }, + { + Write: func(f *framer) { + f.writeShort(2) + f.writeShort(uint16(TypeBoolean)) + f.writeShort(uint16(TypeBoolean)) + }, + Param: []TypeInfo{nil, nil}, + Exp: []TypeInfo{booleanTypeInfo{}, booleanTypeInfo{}}, + }, + { + Write: func(f *framer) { + f.writeInt(5) + }, + Param: func() *interface{} { + var i interface{} + i = int(0) + return &i + }(), + Exp: int(5), + }, + } + for i := range testCases { + framer := newFramer(nil, 4, GlobalTypes) + testCases[i].Write(framer) + res, err := framer.readParam(testCases[i].Param) + if err != nil { + t.Errorf("[%d] unexpected error: %v", i, err) + } else if !reflect.DeepEqual(res, testCases[i].Exp) { + t.Errorf("[%d] expected %+v, got %+v", i, testCases[i].Exp, res) + } + } +} + +func TestFrameReadTypeInfo(t *testing.T) { + tests := []struct { + name string + typ Type + more func(f *framer) + custom string + expected TypeInfo + }{ + { + name: "text", + typ: TypeVarchar, + expected: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, + }, + { + name: "boolean", + typ: TypeBoolean, + expected: booleanTypeInfo{}, + }, + { + name: "set_int", + typ: TypeSet, + more: func(f *framer) { + f.writeShort(uint16(TypeInt)) + }, + expected: CollectionType{ + typ: TypeSet, + Elem: intTypeInfo{}, + }, + }, + { + name: "list_int", + typ: TypeList, + more: func(f *framer) { + f.writeShort(uint16(TypeInt)) + }, + expected: CollectionType{ + typ: TypeList, + Elem: intTypeInfo{}, + }, + }, + { + name: "list_list_int", + typ: TypeList, + more: func(f *framer) { + f.writeShort(uint16(TypeList)) + f.writeShort(uint16(TypeInt)) + }, + expected: CollectionType{ + typ: TypeList, + Elem: CollectionType{ + typ: TypeList, + Elem: intTypeInfo{}, + }, + }, + }, + { + name: "map_int_int", + typ: TypeMap, + more: func(f *framer) { + f.writeShort(uint16(TypeInt)) + f.writeShort(uint16(TypeInt)) + }, + expected: CollectionType{ + typ: TypeMap, + Key: intTypeInfo{}, + Elem: intTypeInfo{}, + }, + }, + { + name: "list_list_int", + typ: TypeUDT, + more: func(f *framer) { + f.writeString("gocql_test") + f.writeString("person") + f.writeShort(3) + f.writeString("first_name") + f.writeShort(uint16(TypeVarchar)) + f.writeString("last_name") + f.writeShort(uint16(TypeVarchar)) + f.writeString("age") + f.writeShort(uint16(TypeInt)) + }, + expected: UDTTypeInfo{ + Keyspace: "gocql_test", + Name: "person", + Elements: []UDTField{ + {Name: "first_name", Type: varcharLikeTypeInfo{typ: TypeVarchar}}, + {Name: "last_name", Type: varcharLikeTypeInfo{typ: TypeVarchar}}, + {Name: "age", Type: intTypeInfo{}}, + }, + }, + }, + { + name: "tuple_int_int", + typ: TypeTuple, + more: func(f *framer) { + f.writeShort(2) + f.writeShort(uint16(TypeInt)) + f.writeShort(uint16(TypeInt)) + }, + expected: TupleTypeInfo{ + Elems: []TypeInfo{ + intTypeInfo{}, + intTypeInfo{}, + }, + }, + }, + + // these abuse the custom type to test some cases of typeInfoFromString + { + name: "vector_text", + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 3)", + expected: VectorType{ + SubType: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, + Dimensions: 3, + }, + }, + { + name: "vector_set_int", + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type), 2)", + expected: VectorType{ + SubType: CollectionType{ + typ: TypeSet, + Elem: intTypeInfo{}, + }, + Dimensions: 2, + }, + }, + { + name: "vector_udt", + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type), 2)", + expected: VectorType{ + SubType: UDTTypeInfo{ + Keyspace: "gocql_test", + Name: "person", + Elements: []UDTField{ + {Name: "first_name", Type: varcharLikeTypeInfo{typ: TypeVarchar}}, + {Name: "last_name", Type: varcharLikeTypeInfo{typ: TypeVarchar}}, + {Name: "age", Type: intTypeInfo{}}, + }, + }, + Dimensions: 2, + }, + }, + { + name: "vector_tuple", + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type), 2)", + expected: VectorType{ + SubType: TupleTypeInfo{ + Elems: []TypeInfo{ + varcharLikeTypeInfo{typ: TypeVarchar}, + intTypeInfo{}, + varcharLikeTypeInfo{typ: TypeVarchar}, + }, + }, + Dimensions: 2, + }, + }, + { + name: "vector_vector_inet", + typ: TypeCustom, + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", + expected: VectorType{ + SubType: VectorType{ + SubType: inetType{}, + Dimensions: 2, + }, + Dimensions: 3, + }, + }, + } + + // org.apache.cassandra.db.marshal.VectorType(%s, 2) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := newFramer(nil, 4, GlobalTypes) + f.writeShort(uint16(test.typ)) + if test.typ == TypeCustom { + f.writeString(test.custom) + } else if test.more != nil { + test.more(f) + } + parsedType, err := f.readTypeInfo() + require.NoError(t, err) + if len(f.buf) != 0 { + t.Errorf("frame's buffer was not empty after readTypeInfo: %d left", len(f.buf)) + } + if !reflect.DeepEqual(test.expected, parsedType) { + t.Errorf("expected (%#v) but was (%#v) instead", test.expected, parsedType) + } + }) + } +} + +func BenchmarkFramerReadCol_Tuple(b *testing.B) { + b.ReportAllocs() + framer := newFramer(nil, 4, GlobalTypes) + framer.writeString("foo") + framer.writeShort(uint16(TypeTuple)) + framer.writeShort(uint16(2)) + framer.writeShort(uint16(TypeVarchar)) + framer.writeShort(uint16(TypeVarchar)) + buf := framer.buf + var col ColumnInfo + + b.ResetTimer() + for i := 0; i < b.N; i++ { + framer.buf = buf + _ = framer.readCol(&col, nil, true, "", "") + } +} + +func BenchmarkFramerReadCol_Set(b *testing.B) { + b.ReportAllocs() + + framer := newFramer(nil, 4, GlobalTypes) + framer.writeString("foo") + framer.writeShort(uint16(TypeSet)) + framer.writeShort(uint16(TypeInt)) + buf := framer.buf + var col ColumnInfo + + b.ResetTimer() + for i := 0; i < b.N; i++ { + framer.buf = buf + _ = framer.readCol(&col, nil, true, "", "") + } +} diff --git a/helpers.go b/helpers.go index 79842b7f0..a391a11ef 100644 --- a/helpers.go +++ b/helpers.go @@ -25,394 +25,22 @@ package gocql import ( - "encoding/hex" "fmt" - "math/big" "net" "reflect" - "strconv" - "strings" - "time" - - "gopkg.in/inf.v0" ) +// RowData contains the column names and pointers to the default values for each +// column type RowData struct { Columns []string Values []interface{} } -func goType(t TypeInfo) (reflect.Type, error) { - switch t.Type() { - case TypeVarchar, TypeAscii, TypeInet, TypeText: - return reflect.TypeOf(*new(string)), nil - case TypeBigInt, TypeCounter: - return reflect.TypeOf(*new(int64)), nil - case TypeTime: - return reflect.TypeOf(*new(time.Duration)), nil - case TypeTimestamp: - return reflect.TypeOf(*new(time.Time)), nil - case TypeBlob: - return reflect.TypeOf(*new([]byte)), nil - case TypeBoolean: - return reflect.TypeOf(*new(bool)), nil - case TypeFloat: - return reflect.TypeOf(*new(float32)), nil - case TypeDouble: - return reflect.TypeOf(*new(float64)), nil - case TypeInt: - return reflect.TypeOf(*new(int)), nil - case TypeSmallInt: - return reflect.TypeOf(*new(int16)), nil - case TypeTinyInt: - return reflect.TypeOf(*new(int8)), nil - case TypeDecimal: - return reflect.TypeOf(*new(*inf.Dec)), nil - case TypeUUID, TypeTimeUUID: - return reflect.TypeOf(*new(UUID)), nil - case TypeList, TypeSet: - elemType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, err - } - return reflect.SliceOf(elemType), nil - case TypeMap: - keyType, err := goType(t.(CollectionType).Key) - if err != nil { - return nil, err - } - valueType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, err - } - return reflect.MapOf(keyType, valueType), nil - case TypeVarint: - return reflect.TypeOf(*new(*big.Int)), nil - case TypeTuple: - // what can we do here? all there is to do is to make a list of interface{} - tuple := t.(TupleTypeInfo) - return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil - case TypeUDT: - return reflect.TypeOf(make(map[string]interface{})), nil - case TypeDate: - return reflect.TypeOf(*new(time.Time)), nil - case TypeDuration: - return reflect.TypeOf(*new(Duration)), nil - default: - return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) - } -} - func dereference(i interface{}) interface{} { return reflect.Indirect(reflect.ValueOf(i)).Interface() } -func getCassandraBaseType(name string) Type { - switch name { - case "ascii": - return TypeAscii - case "bigint": - return TypeBigInt - case "blob": - return TypeBlob - case "boolean": - return TypeBoolean - case "counter": - return TypeCounter - case "date": - return TypeDate - case "decimal": - return TypeDecimal - case "double": - return TypeDouble - case "duration": - return TypeDuration - case "float": - return TypeFloat - case "int": - return TypeInt - case "smallint": - return TypeSmallInt - case "tinyint": - return TypeTinyInt - case "time": - return TypeTime - case "timestamp": - return TypeTimestamp - case "uuid": - return TypeUUID - case "varchar": - return TypeVarchar - case "text": - return TypeText - case "varint": - return TypeVarint - case "timeuuid": - return TypeTimeUUID - case "inet": - return TypeInet - case "MapType": - return TypeMap - case "ListType": - return TypeList - case "SetType": - return TypeSet - case "TupleType": - return TypeTuple - default: - return TypeCustom - } -} - -// TODO: Cover with unit tests. -// Parses long Java-style type definition to internal data structures. -func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { - if strings.HasPrefix(name, SET_TYPE) { - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeSet), - Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, SET_TYPE, '('), protoVer, logger), - } - } else if strings.HasPrefix(name, LIST_TYPE) { - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeList), - Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, LIST_TYPE, '('), protoVer, logger), - } - } else if strings.HasPrefix(name, MAP_TYPE) { - names := splitJavaCompositeTypes(name, MAP_TYPE) - if len(names) != 2 { - logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NewNativeType(protoVer, TypeCustom) - } - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeMap), - Key: getCassandraLongType(names[0], protoVer, logger), - Elem: getCassandraLongType(names[1], protoVer, logger), - } - } else if strings.HasPrefix(name, TUPLE_TYPE) { - names := splitJavaCompositeTypes(name, TUPLE_TYPE) - types := make([]TypeInfo, len(names)) - - for i, name := range names { - types[i] = getCassandraLongType(name, protoVer, logger) - } - - return TupleTypeInfo{ - NativeType: NewNativeType(protoVer, TypeTuple), - Elems: types, - } - } else if strings.HasPrefix(name, UDT_TYPE) { - names := splitJavaCompositeTypes(name, UDT_TYPE) - fields := make([]UDTField, len(names)-2) - - for i := 2; i < len(names); i++ { - spec := strings.Split(names[i], ":") - fieldName, _ := hex.DecodeString(spec[0]) - fields[i-2] = UDTField{ - Name: string(fieldName), - Type: getCassandraLongType(spec[1], protoVer, logger), - } - } - - udtName, _ := hex.DecodeString(names[1]) - return UDTTypeInfo{ - NativeType: NewNativeType(protoVer, TypeUDT), - KeySpace: names[0], - Name: string(udtName), - Elements: fields, - } - } else if strings.HasPrefix(name, VECTOR_TYPE) { - names := splitJavaCompositeTypes(name, VECTOR_TYPE) - subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) - dim, err := strconv.Atoi(strings.TrimSpace(names[1])) - if err != nil { - logger.Printf("gocql: error parsing vector dimensions: %v\n", err) - return NewNativeType(protoVer, TypeCustom) - } - - return VectorType{ - NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), - SubType: subType, - Dimensions: dim, - } - } else { - // basic type - return NativeType{ - proto: protoVer, - typ: getApacheCassandraType(name), - } - } -} - -// Parses short CQL type representation (e.g. map) to internal data structures. -func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { - if strings.HasPrefix(name, "frozen<") { - return getCassandraType(unwrapCompositeTypeDefinition(name, "frozen", '<'), protoVer, logger) - } else if strings.HasPrefix(name, "set<") { - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeSet), - Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "set", '<'), protoVer, logger), - } - } else if strings.HasPrefix(name, "list<") { - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeList), - Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "list", '<'), protoVer, logger), - } - } else if strings.HasPrefix(name, "map<") { - names := splitCQLCompositeTypes(name, "map") - if len(names) != 2 { - logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NewNativeType(protoVer, TypeCustom) - } - return CollectionType{ - NativeType: NewNativeType(protoVer, TypeMap), - Key: getCassandraType(names[0], protoVer, logger), - Elem: getCassandraType(names[1], protoVer, logger), - } - } else if strings.HasPrefix(name, "tuple<") { - names := splitCQLCompositeTypes(name, "tuple") - types := make([]TypeInfo, len(names)) - - for i, name := range names { - types[i] = getCassandraType(name, protoVer, logger) - } - - return TupleTypeInfo{ - NativeType: NewNativeType(protoVer, TypeTuple), - Elems: types, - } - } else if strings.HasPrefix(name, "vector<") { - names := splitCQLCompositeTypes(name, "vector") - subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) - dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) - - return VectorType{ - NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), - SubType: subType, - Dimensions: dim, - } - } else { - return NativeType{ - proto: protoVer, - typ: getCassandraBaseType(name), - } - } -} - -func splitCQLCompositeTypes(name string, typeName string) []string { - return splitCompositeTypes(name, typeName, '<', '>') -} - -func splitJavaCompositeTypes(name string, typeName string) []string { - return splitCompositeTypes(name, typeName, '(', ')') -} - -func unwrapCompositeTypeDefinition(name string, typeName string, typeOpen int32) string { - return strings.TrimPrefix(name[:len(name)-1], typeName+string(typeOpen)) -} - -func splitCompositeTypes(name string, typeName string, typeOpen int32, typeClose int32) []string { - def := unwrapCompositeTypeDefinition(name, typeName, typeOpen) - if !strings.Contains(def, string(typeOpen)) { - parts := strings.Split(def, ",") - for i := range parts { - parts[i] = strings.TrimSpace(parts[i]) - } - return parts - } - var parts []string - lessCount := 0 - segment := "" - for _, char := range def { - if char == ',' && lessCount == 0 { - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - segment = "" - continue - } - segment += string(char) - if char == typeOpen { - lessCount++ - } else if char == typeClose { - lessCount-- - } - } - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - return parts -} - -func getApacheCassandraType(class string) Type { - switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { - case "AsciiType": - return TypeAscii - case "LongType": - return TypeBigInt - case "BytesType": - return TypeBlob - case "BooleanType": - return TypeBoolean - case "CounterColumnType": - return TypeCounter - case "DecimalType": - return TypeDecimal - case "DoubleType": - return TypeDouble - case "FloatType": - return TypeFloat - case "Int32Type": - return TypeInt - case "ShortType": - return TypeSmallInt - case "ByteType": - return TypeTinyInt - case "TimeType": - return TypeTime - case "DateType", "TimestampType": - return TypeTimestamp - case "UUIDType", "LexicalUUIDType": - return TypeUUID - case "UTF8Type": - return TypeVarchar - case "IntegerType": - return TypeVarint - case "TimeUUIDType": - return TypeTimeUUID - case "InetAddressType": - return TypeInet - case "MapType": - return TypeMap - case "ListType": - return TypeList - case "SetType": - return TypeSet - case "TupleType": - return TypeTuple - case "DurationType": - return TypeDuration - case "SimpleDateType": - return TypeDate - case "UserType": - return TypeUDT - default: - return TypeCustom - } -} - -func (r *RowData) rowMap(m map[string]interface{}) { - for i, column := range r.Columns { - val := dereference(r.Values[i]) - if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice && !valVal.IsNil() { - valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap()) - reflect.Copy(valCopy, valVal) - m[column] = valCopy.Interface() - } else { - m[column] = val - } - } -} - // TupeColumnName will return the column name of a tuple value in a column named // c at index n. It should be used if a specific element within a tuple is needed // to be extracted from a map returned from SliceMap or MapScan. @@ -431,22 +59,15 @@ func (iter *Iter) RowData() (RowData, error) { for _, column := range iter.Columns() { if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val, err := column.TypeInfo.NewWithError() - if err != nil { - iter.err = err - return RowData{}, err - } + val := c.Zero() columns = append(columns, column.Name) - values = append(values, val) + values = append(values, &val) } else { for i, elem := range c.Elems { columns = append(columns, TupleColumnName(column.Name, i)) - val, err := elem.NewWithError() - if err != nil { - iter.err = err - return RowData{}, err - } - values = append(values, val) + var val interface{} + val = elem.Zero() + values = append(values, &val) } } } @@ -459,22 +80,6 @@ func (iter *Iter) RowData() (RowData, error) { return rowData, nil } -// TODO(zariel): is it worth exporting this? -func (iter *Iter) rowMap() (map[string]interface{}, error) { - if iter.err != nil { - return nil, iter.err - } - - rowData, err := iter.RowData() - if err != nil { - return nil, err - } - iter.Scan(rowData.Values...) - m := make(map[string]interface{}, len(rowData.Columns)) - rowData.rowMap(m) - return m, nil -} - // SliceMap is a helper function to make the API easier to use // returns the data from the query in the form of []map[string]interface{} func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { @@ -482,15 +87,13 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { return nil, iter.err } - // Not checking for the error because we just did - rowData, err := iter.RowData() - if err != nil { - return nil, err - } - dataToReturn := make([]map[string]interface{}, 0) - for iter.Scan(rowData.Values...) { - m := make(map[string]interface{}, len(rowData.Columns)) - rowData.rowMap(m) + numCols := len(iter.Columns()) + var dataToReturn []map[string]interface{} + for { + m := make(map[string]interface{}, numCols) + if !iter.MapScan(m) { + break + } dataToReturn = append(dataToReturn, m) } if iter.err != nil { @@ -542,19 +145,43 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } - rowData, err := iter.RowData() - if err != nil { - return false - } - - for i, col := range rowData.Columns { - if dest, ok := m[col]; ok { - rowData.Values[i] = dest + cols := iter.Columns() + columnNames := make([]string, 0, len(cols)) + values := make([]interface{}, 0, len(cols)) + for _, column := range iter.Columns() { + if c, ok := column.TypeInfo.(TupleTypeInfo); ok { + for i := range c.Elems { + columnName := TupleColumnName(column.Name, i) + if dest, ok := m[columnName]; ok { + values = append(values, dest) + } else { + zero := c.Elems[i].Zero() + // technically this is a *interface{} but later we will fix it + values = append(values, &zero) + } + columnNames = append(columnNames, columnName) + } + } else { + if dest, ok := m[column.Name]; ok { + values = append(values, dest) + } else { + zero := column.TypeInfo.Zero() + // technically this is a *interface{} but later we will fix it + values = append(values, &zero) + } + columnNames = append(columnNames, column.Name) } } - - if iter.Scan(rowData.Values...) { - rowData.rowMap(m) + if iter.Scan(values...) { + for i, name := range columnNames { + if iptr, ok := values[i].(*interface{}); ok { + m[name] = *iptr + } else { + // TODO: it seems wrong to dereference the values that were passed in + // originally in the map but that's what it was doing before + m[name] = dereference(values[i]) + } + } return true } return false diff --git a/helpers_test.go b/helpers_test.go index 275752aa0..45903a695 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -29,8 +29,11 @@ import ( "testing" ) -func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set", protoVersion4, &defaultLogger{}) +func TestGetCassandraTypeInfo_Set(t *testing.T) { + typ, err := GlobalTypes.typeInfoFromString(protoVersion4, "set") + if err != nil { + t.Fatal(err) + } set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -38,64 +41,57 @@ func TestGetCassandraType_Set(t *testing.T) { t.Fatalf("expected type %v got %v", TypeSet, set.typ) } - inner, ok := set.Elem.(NativeType) + inner, ok := set.Elem.(TypeInfo) if !ok { - t.Fatalf("expected to get NativeType got %T", set.Elem) - } else if inner.typ != TypeText { + t.Fatalf("expected to get TypeInfo got %T", set.Elem) + } else if inner.Type() != TypeText { t.Fatalf("expected to get %v got %v for set value", TypeText, set.typ) } } -func TestGetCassandraType(t *testing.T) { +func TestGetCassandraTypeInfo(t *testing.T) { tests := []struct { input string exp TypeInfo }{ { "set", CollectionType{ - NativeType: NativeType{typ: TypeSet}, - - Elem: NativeType{typ: TypeText}, + typ: TypeSet, + Elem: varcharLikeTypeInfo{typ: TypeText}, }, }, { "map", CollectionType{ - NativeType: NativeType{typ: TypeMap}, - - Key: NativeType{typ: TypeText}, - Elem: NativeType{typ: TypeVarchar}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeText}, + Elem: varcharLikeTypeInfo{typ: TypeVarchar}, }, }, { "list", CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: NativeType{typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, }, { "tuple", TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, - NativeType{typ: TypeText}, + intTypeInfo{}, + intTypeInfo{}, + varcharLikeTypeInfo{typ: TypeText}, }, }, }, { "frozen>>>>>", CollectionType{ - NativeType: NativeType{typ: TypeMap}, - - Key: NativeType{typ: TypeText}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeText}, Elem: CollectionType{ - NativeType: NativeType{typ: TypeList}, + typ: TypeList, Elem: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, }, @@ -104,50 +100,44 @@ func TestGetCassandraType(t *testing.T) { { "frozen>>>>>, frozen>>>>>, frozen>>>>>>>", TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeText}, + varcharLikeTypeInfo{typ: TypeText}, CollectionType{ - NativeType: NativeType{typ: TypeList}, + typ: TypeList, Elem: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, }, }, }, TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeText}, + varcharLikeTypeInfo{typ: TypeText}, CollectionType{ - NativeType: NativeType{typ: TypeList}, + typ: TypeList, Elem: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, }, }, }, CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: NativeType{typ: TypeText}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeText}, Elem: CollectionType{ - NativeType: NativeType{typ: TypeList}, + typ: TypeList, Elem: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, }, @@ -157,24 +147,18 @@ func TestGetCassandraType(t *testing.T) { }, { "frozen>, int, frozen>>>", TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, - NativeType{typ: TypeInt}, + intTypeInfo{}, TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, }, @@ -182,69 +166,54 @@ func TestGetCassandraType(t *testing.T) { }, { "frozen>, int>>", CollectionType{ - NativeType: NativeType{typ: TypeMap}, - + typ: TypeMap, Key: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeInt}, + intTypeInfo{}, + intTypeInfo{}, }, }, - Elem: NativeType{typ: TypeInt}, + Elem: intTypeInfo{}, }, }, { "set", CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: NativeType{typ: TypeSmallInt}, + typ: TypeSet, + Elem: smallIntTypeInfo{}, }, }, { "list", CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: NativeType{typ: TypeTinyInt}, + typ: TypeList, + Elem: tinyIntTypeInfo{}, }, }, - {"smallint", NativeType{typ: TypeSmallInt}}, - {"tinyint", NativeType{typ: TypeTinyInt}}, - {"duration", NativeType{typ: TypeDuration}}, - {"date", NativeType{typ: TypeDate}}, + {"smallint", smallIntTypeInfo{}}, + {"tinyint", tinyIntTypeInfo{}}, + {"duration", durationTypeInfo{}}, + {"date", dateTypeInfo{}}, { "list", CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: NativeType{typ: TypeDate}, + typ: TypeList, + Elem: dateTypeInfo{}, }, }, { "set", CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: NativeType{typ: TypeDuration}, + typ: TypeSet, + Elem: durationTypeInfo{}, }, }, { "vector", VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - custom: VECTOR_TYPE, - }, - SubType: NativeType{typ: TypeFloat}, + SubType: floatTypeInfo{}, Dimensions: 3, }, }, { "vector, 5>", VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - custom: VECTOR_TYPE, - }, SubType: VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - custom: VECTOR_TYPE, - }, - SubType: NativeType{typ: TypeFloat}, + SubType: floatTypeInfo{}, Dimensions: 3, }, Dimensions: 5, @@ -252,29 +221,20 @@ func TestGetCassandraType(t *testing.T) { }, { "vector, 5>", VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - custom: VECTOR_TYPE, - }, SubType: CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: NativeType{typ: TypeUUID}, - Elem: NativeType{typ: TypeTimestamp}, + typ: TypeMap, + Key: uuidType{}, + Elem: timestampTypeInfo{}, }, Dimensions: 5, }, }, { "vector>, 100>", VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - custom: VECTOR_TYPE, - }, SubType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeFloat}, + intTypeInfo{}, + floatTypeInfo{}, }, }, Dimensions: 100, @@ -284,7 +244,10 @@ func TestGetCassandraType(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, 0, &defaultLogger{}) + got, err := GlobalTypes.typeInfoFromString(protoVersion4, test.input) + if err != nil { + t.Fatal(err) + } // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/host_source.go b/host_source.go index 396bcfe1e..ece5d3e5a 100644 --- a/host_source.go +++ b/host_source.go @@ -499,7 +499,7 @@ func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]inter // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { - const assertErrorMsg = "Assertion failed for %s" + const assertErrorMsg = "Assertion failed for %s, type was %T" var ok bool // Default to our connected port if the cluster doesn't have port information @@ -508,101 +508,101 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* case "data_center": host.dataCenter, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "data_center") + return nil, fmt.Errorf(assertErrorMsg, "data_center", value) } case "rack": host.rack, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "rack") + return nil, fmt.Errorf(assertErrorMsg, "rack", value) } case "host_id": hostId, ok := value.(UUID) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "host_id") + return nil, fmt.Errorf(assertErrorMsg, "host_id", value) } host.hostId = hostId.String() case "release_version": version, ok := value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "release_version") + return nil, fmt.Errorf(assertErrorMsg, "release_version", value) } host.version.Set(version) case "peer": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "peer") + return nil, fmt.Errorf(assertErrorMsg, "peer", value) } - host.peer = net.ParseIP(ip) + host.peer = ip case "cluster_name": host.clusterName, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "cluster_name") + return nil, fmt.Errorf(assertErrorMsg, "cluster_name", value) } case "partitioner": host.partitioner, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "partitioner") + return nil, fmt.Errorf(assertErrorMsg, "partitioner", value) } case "broadcast_address": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "broadcast_address") + return nil, fmt.Errorf(assertErrorMsg, "broadcast_address", value) } - host.broadcastAddress = net.ParseIP(ip) + host.broadcastAddress = ip case "preferred_ip": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "preferred_ip") + return nil, fmt.Errorf(assertErrorMsg, "preferred_ip", value) } - host.preferredIP = net.ParseIP(ip) + host.preferredIP = ip case "rpc_address": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "rpc_address") + return nil, fmt.Errorf(assertErrorMsg, "rpc_address", value) } - host.rpcAddress = net.ParseIP(ip) + host.rpcAddress = ip case "native_address": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "native_address") + return nil, fmt.Errorf(assertErrorMsg, "native_address", value) } - host.rpcAddress = net.ParseIP(ip) + host.rpcAddress = ip case "listen_address": - ip, ok := value.(string) + ip, ok := value.(net.IP) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "listen_address") + return nil, fmt.Errorf(assertErrorMsg, "listen_address", value) } - host.listenAddress = net.ParseIP(ip) + host.listenAddress = ip case "native_port": native_port, ok := value.(int) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "native_port") + return nil, fmt.Errorf(assertErrorMsg, "native_port", value) } host.port = native_port case "workload": host.workload, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "workload") + return nil, fmt.Errorf(assertErrorMsg, "workload", value) } case "graph": host.graph, ok = value.(bool) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "graph") + return nil, fmt.Errorf(assertErrorMsg, "graph", value) } case "tokens": host.tokens, ok = value.([]string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "tokens") + return nil, fmt.Errorf(assertErrorMsg, "tokens", value) } case "dse_version": host.dseVersion, ok = value.(string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "dse_version") + return nil, fmt.Errorf(assertErrorMsg, "dse_version", value) } case "schema_version": schemaVersion, ok := value.(UUID) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "schema_version") + return nil, fmt.Errorf(assertErrorMsg, "schema_version", value) } host.schemaVersion = schemaVersion.String() } diff --git a/host_source_test.go b/host_source_test.go index c8f93c78c..e2454be56 100644 --- a/host_source_test.go +++ b/host_source_test.go @@ -29,7 +29,6 @@ package gocql import ( "errors" - "fmt" "net" "sync" "sync/atomic" @@ -58,7 +57,6 @@ func TestUnmarshalCassVersion(t *testing.T) { } else if *v != test.version { t.Errorf("%d: expected %#+v got %#+v", i, test.version, *v) } - fmt.Println(v.String()) } } diff --git a/integration_test.go b/integration_test.go index ccc5939d6..18acaf4af 100644 --- a/integration_test.go +++ b/integration_test.go @@ -204,7 +204,7 @@ func TestCustomPayloadMessages(t *testing.T) { iter := query.Iter() rCustomPayload := iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { - t.Fatal("The received custom payload should match the sent") + t.Fatalf("The received custom payload %#v should match the sent %#v", rCustomPayload, customPayload) } iter.Close() @@ -213,7 +213,7 @@ func TestCustomPayloadMessages(t *testing.T) { iter = query.Iter() rCustomPayload = iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { - t.Fatal("The received custom payload should match the sent") + t.Fatalf("The received custom payload %#v should match the sent %#v", rCustomPayload, customPayload) } iter.Close() @@ -242,7 +242,7 @@ func TestCustomPayloadValues(t *testing.T) { iter := query.Iter() rCustomPayload := iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { - t.Fatal("The received custom payload should match the sent") + t.Fatalf("The received custom payload %#v should match the sent %#v", rCustomPayload, customPayload) } } } diff --git a/marshal.go b/marshal.go index 65ab656cb..5133a96d7 100644 --- a/marshal.go +++ b/marshal.go @@ -27,6 +27,7 @@ package gocql import ( "bytes" "encoding/binary" + "encoding/hex" "errors" "fmt" "math" @@ -42,8 +43,7 @@ import ( ) var ( - bigOne = big.NewInt(1) - emptyValue reflect.Value + bigOne = big.NewInt(1) ) var ( @@ -113,9 +113,6 @@ type Unmarshaler interface { // // The marshal/unmarshal error provides a list of supported types when an unsupported type is attempted. func Marshal(info TypeInfo, value interface{}) ([]byte, error) { - if info.Version() < protoVersion1 { - panic("protocol version not set") - } if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr { if valueRef.IsNil() { return nil, nil @@ -130,60 +127,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return v.MarshalCQL(info) } - switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return marshalVarchar(info, value) - case TypeBoolean: - return marshalBool(info, value) - case TypeTinyInt: - return marshalTinyInt(info, value) - case TypeSmallInt: - return marshalSmallInt(info, value) - case TypeInt: - return marshalInt(info, value) - case TypeBigInt, TypeCounter: - return marshalBigInt(info, value) - case TypeFloat: - return marshalFloat(info, value) - case TypeDouble: - return marshalDouble(info, value) - case TypeDecimal: - return marshalDecimal(info, value) - case TypeTime: - return marshalTime(info, value) - case TypeTimestamp: - return marshalTimestamp(info, value) - case TypeList, TypeSet: - return marshalList(info, value) - case TypeMap: - return marshalMap(info, value) - case TypeUUID, TypeTimeUUID: - return marshalUUID(info, value) - case TypeVarint: - return marshalVarint(info, value) - case TypeInet: - return marshalInet(info, value) - case TypeTuple: - return marshalTuple(info, value) - case TypeUDT: - return marshalUDT(info, value) - case TypeDate: - return marshalDate(info, value) - case TypeDuration: - return marshalDuration(info, value) - case TypeCustom: - if vector, ok := info.(VectorType); ok { - return marshalVector(vector, value) - } - } - - // detect protocol 2 UDT - if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { - return nil, ErrorUDTUnavailable - } - - // TODO(tux21b): add the remaining types - return nil, fmt.Errorf("can not marshal %T into %s", value, info) + return info.Marshal(value) } // Unmarshal parses the CQL encoded data based on the info parameter that @@ -232,95 +176,97 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return v.UnmarshalCQL(info, data) } - if isNullableValue(value) { - return unmarshalNullable(info, data, value) - } - - switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return unmarshalVarchar(info, data, value) - case TypeBoolean: - return unmarshalBool(info, data, value) - case TypeInt: - return unmarshalInt(info, data, value) - case TypeBigInt, TypeCounter: - return unmarshalBigInt(info, data, value) - case TypeVarint: - return unmarshalVarint(info, data, value) - case TypeSmallInt: - return unmarshalSmallInt(info, data, value) - case TypeTinyInt: - return unmarshalTinyInt(info, data, value) - case TypeFloat: - return unmarshalFloat(info, data, value) - case TypeDouble: - return unmarshalDouble(info, data, value) - case TypeDecimal: - return unmarshalDecimal(info, data, value) - case TypeTime: - return unmarshalTime(info, data, value) - case TypeTimestamp: - return unmarshalTimestamp(info, data, value) - case TypeList, TypeSet: - return unmarshalList(info, data, value) - case TypeMap: - return unmarshalMap(info, data, value) - case TypeTimeUUID: - return unmarshalTimeUUID(info, data, value) - case TypeUUID: - return unmarshalUUID(info, data, value) - case TypeInet: - return unmarshalInet(info, data, value) - case TypeTuple: - return unmarshalTuple(info, data, value) - case TypeUDT: - return unmarshalUDT(info, data, value) - case TypeDate: - return unmarshalDate(info, data, value) - case TypeDuration: - return unmarshalDuration(info, data, value) - case TypeCustom: - if vector, ok := info.(VectorType); ok { - return unmarshalVector(vector, data, value) + // check for pointer + // we don't error for non-pointers because certain types support unmarshalling + // into maps/slices + valueRef := reflect.ValueOf(value) + if valueRef.Kind() == reflect.Ptr { + // handle pointers and nil data + valueElemRef := valueRef.Elem() + switch valueElemRef.Kind() { + case reflect.Ptr: + if data == nil { + if valueElemRef.IsNil() { + return nil + } + valueRef.Elem().Set(reflect.Zero(valueElemRef.Type())) + return nil + } + // we discussed wrapping this in valueElemRef.IsNil() since we don't need + // to re-allocate if its non-nil but this was safer and what it was doing + // before and we didn't want to surprise anyone that relies on this + // in case the pointer is nil, we call type first then elem to get the type + // of the underlying value regardless if the pointer is nil or not + newValue := reflect.New(valueElemRef.Type().Elem()) + valueElemRef.Set(newValue) + // call Unmarshal again to unwrap the value + return Unmarshal(info, data, valueElemRef.Interface()) + case reflect.Slice, reflect.Map: + if data == nil { + if valueElemRef.IsNil() { + return nil + } + valueRef.Elem().Set(reflect.Zero(valueElemRef.Type())) + return nil + } + case reflect.Interface: + // set to zero value of the the empty interface value + if valueElemRef.NumMethod() == 0 && data == nil { + // once we have a reflect.Type of interface{} we lose the underlying type + // inside the interface, so we need to call Elem() on the value itself + // first before calling Type() but first we make sure that it's not + // an empty interface + if valueElemRef.IsValid() { + valueElemRef = valueElemRef.Elem() + } + valueRef.Elem().Set(reflect.Zero(valueElemRef.Type())) + return nil + } + if valueElemRef.IsValid() && valueElemRef.Elem().Kind() == reflect.Ptr { + // call Unmarshal again to unwrap the value + return Unmarshal(info, data, valueElemRef.Interface()) + } } } - // detect protocol 2 UDT - if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { - return ErrorUDTUnavailable - } - - // TODO(tux21b): add the remaining types - return fmt.Errorf("can not unmarshal %s into %T", info, value) + return info.Unmarshal(data, value) } -func isNullableValue(value interface{}) bool { - v := reflect.ValueOf(value) - return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr +type varcharLikeTypeInfo struct { + typ Type } -func isNullData(info TypeInfo, data []byte) bool { - return data == nil +// Type returns the underlying type itself. +func (v varcharLikeTypeInfo) Type() Type { + return v.typ } -func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { - valueRef := reflect.ValueOf(value) - - if isNullData(info, data) { - nilValue := reflect.Zero(valueRef.Type().Elem()) - valueRef.Elem().Set(nilValue) - return nil +// Zero returns the zero value for the varchar-like CQL type. +func (v varcharLikeTypeInfo) Zero() interface{} { + if v.typ == TypeBlob { + return []byte(nil) } + return "" +} - newValue := reflect.New(valueRef.Type().Elem().Elem()) - valueRef.Elem().Set(newValue) - return Unmarshal(info, data, newValue.Interface()) +func (v varcharLikeTypeInfo) typeString() string { + switch v.typ { + case TypeVarchar: + return "varchar" + case TypeAscii: + return "ascii" + case TypeBlob: + return "blob" + case TypeText: + return "text" + default: + return "unknown" + } } -func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (vt varcharLikeTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case string: @@ -342,13 +288,12 @@ func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, string, []byte, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, string, []byte, UnsetValue.", value, vt.typeString()) } -func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (vt varcharLikeTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *string: *v = string(data) return nil @@ -359,6 +304,18 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { *v = nil } return nil + case *interface{}: + if data == nil { + *v = nil + return nil + } + if vt.typ == TypeBlob { + *v = make([]byte, len(data)) + copy((*v).([]byte), data) + } else { + *v = string(data) + } + return nil } rv := reflect.ValueOf(value) @@ -381,13 +338,24 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { rv.SetBytes(dataCopy) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *string, *[]byte", info, value) + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *string, *[]byte", vt.typeString(), value) +} + +type smallIntTypeInfo struct{} + +// Type returns the type itself. +func (smallIntTypeInfo) Type() Type { + return TypeSmallInt +} + +// Zero returns the zero value for the smallint CQL type. +func (smallIntTypeInfo) Zero() interface{} { + return int16(0) } -func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (smallIntTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int16: @@ -431,7 +399,7 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { case string: n, err := strconv.ParseInt(v, 10, 16) if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + return nil, marshalErrorf("can not marshal %T into smallint: %v", value, err) } return encShort(int16(n)), nil } @@ -459,13 +427,37 @@ func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { } } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int16, uint16, int8, uint8, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into smallint. Accepted types: Marshaler, int16, uint16, int8, uint8, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value) +} + +// Unmarshal unmarshals the byte slice into the value. +func (s smallIntTypeInfo) Unmarshal(data []byte, value interface{}) error { + if iptr, ok := value.(*interface{}); ok && iptr != nil { + var v int16 + if err := unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, &v); err != nil { + return err + } + *iptr = v + return nil + } + return unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, value) +} + +type tinyIntTypeInfo struct{} + +// Type returns the type itself. +func (tinyIntTypeInfo) Type() Type { + return TypeTinyInt +} + +// Zero returns the zero value for the tinyint CQL type. +func (tinyIntTypeInfo) Zero() interface{} { + return int8(0) } -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (tinyIntTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int8: @@ -515,7 +507,7 @@ func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { case string: n, err := strconv.ParseInt(v, 10, 8) if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + return nil, marshalErrorf("can not marshal %T into tinyint: %v", value, err) } return []byte{byte(n)}, nil } @@ -543,13 +535,37 @@ func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { } } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into tinyint. Accepted types: int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value) +} + +// Unmarshal unmarshals the byte slice into the value. +func (t tinyIntTypeInfo) Unmarshal(data []byte, value interface{}) error { + if iptr, ok := value.(*interface{}); ok && iptr != nil { + var v int8 + if err := unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, &v); err != nil { + return err + } + *iptr = v + return nil + } + return unmarshalIntlike(TypeTinyInt, int64(decTiny(data)), data, value) +} + +type intTypeInfo struct{} + +// Type returns the type itself. +func (intTypeInfo) Type() Type { + return TypeInt } -func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { +// Zero returns the zero value for the int CQL type. +func (intTypeInfo) Zero() interface{} { + return int(0) +} + +// Marshal marshals the value into a byte slice. +func (intTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int: @@ -615,7 +631,20 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { } } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into int. Accepted types: int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value) +} + +// Unmarshal unmarshals the byte slice into the value. +func (i intTypeInfo) Unmarshal(data []byte, value interface{}) error { + if iptr, ok := value.(*interface{}); ok && iptr != nil { + var v int + if err := unmarshalIntlike(TypeInt, int64(decInt(data)), data, &v); err != nil { + return err + } + *iptr = v + return nil + } + return unmarshalIntlike(TypeInt, int64(decInt(data)), data, value) } func encInt(x int32) []byte { @@ -650,10 +679,23 @@ func decTiny(p []byte) int8 { return int8(p[0]) } -func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { +type bigIntLikeTypeInfo struct { + typ Type +} + +// Type returns the underlying type itself. +func (b bigIntLikeTypeInfo) Type() Type { + return b.typ +} + +// Zero returns the zero value for the bigint-like CQL type. +func (bigIntLikeTypeInfo) Zero() interface{} { + return int64(0) +} + +// Marshal marshals the value into a byte slice. +func (bigIntLikeTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int: @@ -708,7 +750,7 @@ func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { } return encBigInt(int64(v)), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: big.Int, Marshaler, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into bigint. Accepted types: big.Int, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, UnsetValue.", value) } func encBigInt(x int64) []byte { @@ -730,45 +772,33 @@ func bytesToUint64(data []byte) (ret uint64) { return ret } -func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, decBigInt(data), data, value) -} - -func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decInt(data)), data, value) +// Unmarshal unmarshals the byte slice into the value. +func (b bigIntLikeTypeInfo) Unmarshal(data []byte, value interface{}) error { + if iptr, ok := value.(*interface{}); ok && iptr != nil { + var v int64 + if err := unmarshalIntlike(b.typ, decBigInt(data), data, &v); err != nil { + return err + } + *iptr = v + return nil + } + return unmarshalIntlike(b.typ, decBigInt(data), data, value) } -func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decShort(data)), data, value) -} +type varintTypeInfo struct{} -func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decTiny(data)), data, value) +// Type returns the type itself. +func (varintTypeInfo) Type() Type { + return TypeVarint } -func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case *big.Int: - return unmarshalIntlike(info, 0, data, value) - case *uint64: - if len(data) == 9 && data[0] == 0 { - *v = bytesToUint64(data[1:]) - return nil - } - } - - if len(data) > 8 { - return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) - } - - int64Val := bytesToInt64(data) - if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { - int64Val -= (1 << uint(len(data)*8)) - } - return unmarshalIntlike(info, int64Val, data, value) +// Zero returns the zero value for the varint CQL type. +func (varintTypeInfo) Zero() interface{} { + return new(big.Int) } -func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (varintTypeInfo) Marshal(value interface{}) ([]byte, error) { var ( retBytes []byte err error @@ -788,7 +818,7 @@ func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { case big.Int: retBytes = encBigInt2C(&v) default: - retBytes, err = marshalBigInt(info, value) + retBytes, err = (bigIntLikeTypeInfo{}).Marshal(value) } if err == nil { @@ -821,7 +851,37 @@ func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { return retBytes, err } -func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (varintTypeInfo) Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case *big.Int: + return unmarshalIntlike(TypeVarint, 0, data, value) + case *uint64: + if len(data) == 9 && data[0] == 0 { + *v = bytesToUint64(data[1:]) + return nil + } + case *interface{}: + var bi big.Int + if err := unmarshalIntlike(TypeVarint, 0, data, &bi); err != nil { + return err + } + *v = &bi + return nil + } + + if len(data) > 8 { + return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) + } + + int64Val := bytesToInt64(data) + if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { + int64Val -= (1 << uint(len(data)*8)) + } + return unmarshalIntlike(TypeVarint, int64Val, data, value) +} + +func unmarshalIntlike(typ Type, int64Val int64, data []byte, value interface{}) error { switch v := value.(type) { case *int: if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { @@ -831,7 +891,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return nil case *uint: unitVal := uint64(int64Val) - switch info.Type() { + switch typ { case TypeInt: *v = uint(unitVal) & 0xFFFFFFFF case TypeSmallInt: @@ -849,7 +909,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = int64Val return nil case *uint64: - switch info.Type() { + switch typ { case TypeInt: *v = uint64(int64Val) & 0xFFFFFFFF case TypeSmallInt: @@ -867,7 +927,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = int32(int64Val) return nil case *uint32: - switch info.Type() { + switch typ { case TypeInt: *v = uint32(int64Val) & 0xFFFFFFFF case TypeSmallInt: @@ -888,7 +948,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = int16(int64Val) return nil case *uint16: - switch info.Type() { + switch typ { case TypeSmallInt: *v = uint16(int64Val) & 0xFFFF case TypeTinyInt: @@ -907,7 +967,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = int8(int64Val) return nil case *uint8: - if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + if typ != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) } *v = uint8(int64Val) & 0xFF @@ -956,7 +1016,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return nil case reflect.Uint: unitVal := uint64(int64Val) - switch info.Type() { + switch typ { case TypeInt: rv.SetUint(unitVal & 0xFFFFFFFF) case TypeSmallInt: @@ -972,7 +1032,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return nil case reflect.Uint64: unitVal := uint64(int64Val) - switch info.Type() { + switch typ { case TypeInt: rv.SetUint(unitVal & 0xFFFFFFFF) case TypeSmallInt: @@ -985,7 +1045,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return nil case reflect.Uint32: unitVal := uint64(int64Val) - switch info.Type() { + switch typ { case TypeInt: rv.SetUint(unitVal & 0xFFFFFFFF) case TypeSmallInt: @@ -1001,7 +1061,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return nil case reflect.Uint16: unitVal := uint64(int64Val) - switch info.Type() { + switch typ { case TypeSmallInt: rv.SetUint(unitVal & 0xFFFF) case TypeTinyInt: @@ -1014,13 +1074,13 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac } return nil case reflect.Uint8: - if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + if typ != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) } rv.SetUint(uint64(int64Val) & 0xff) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: big.Int, Marshaler, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string.", info, value) + return unmarshalErrorf("can not unmarshal int-like into %T. Accepted types: big.Int, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, *interface{}.", value) } func decBigInt(data []byte) int64 { @@ -1033,10 +1093,23 @@ func decBigInt(data []byte) int64 { int64(data[6])<<8 | int64(data[7]) } -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { +type booleanTypeInfo struct{} + +// Type returns the type itself. +func (booleanTypeInfo) Type() Type { + return TypeBoolean +} + +// Zero returns the zero value for the boolean CQL type. +func (booleanTypeInfo) Zero() interface{} { + return false +} + +// Marshal marshals the value into a byte slice. +func (b booleanTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: - return v.MarshalCQL(info) + return v.MarshalCQL(b) case unsetColumn: return nil, nil case bool: @@ -1052,23 +1125,18 @@ func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { case reflect.Bool: return encBool(rv.Bool()), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, bool, UnsetValue.", value, info) -} - -func encBool(v bool) []byte { - if v { - return []byte{1} - } - return []byte{0} + return nil, marshalErrorf("can not marshal %T into boolean. Accepted types: bool, UnsetValue.", value) } -func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (b booleanTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *bool: *v = decBool(data) return nil + case *interface{}: + *v = decBool(data) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1080,7 +1148,14 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { rv.SetBool(decBool(data)) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *bool.", info, value) + return unmarshalErrorf("can not unmarshal boolean into %T. Accepted types: *bool, *interface{}.", value) +} + +func encBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} } func decBool(v []byte) bool { @@ -1090,10 +1165,21 @@ func decBool(v []byte) bool { return v[0] != 0 } -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { +type floatTypeInfo struct{} + +// Type returns the type itself. +func (floatTypeInfo) Type() Type { + return TypeFloat +} + +// Zero returns the zero value for the float CQL type. +func (floatTypeInfo) Zero() interface{} { + return float32(0) +} + +// Marshal marshals the value into a byte slice. +func (floatTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case float32: @@ -1109,16 +1195,18 @@ func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { case reflect.Float32: return encInt(int32(math.Float32bits(float32(rv.Float())))), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, float32, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into float. Accepted types: Marshaler, float32, UnsetValue.", value) } -func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (floatTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *float32: *v = math.Float32frombits(uint32(decInt(data))) return nil + case *interface{}: + *v = math.Float32frombits(uint32(decInt(data))) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1130,13 +1218,24 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *float32, UnsetValue.", info, value) + return unmarshalErrorf("can not unmarshal float into %T. Accepted types: *float32, *interface{}, UnsetValue.", value) +} + +type doubleTypeInfo struct{} + +// Type returns the type itself. +func (doubleTypeInfo) Type() Type { + return TypeDouble +} + +// Zero returns the zero value for the double CQL type. +func (doubleTypeInfo) Zero() interface{} { + return float64(0) } -func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (doubleTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case float64: @@ -1150,16 +1249,18 @@ func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { case reflect.Float64: return encBigInt(int64(math.Float64bits(rv.Float()))), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, float64, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into double. Accepted types: Marshaler, float64, UnsetValue.", value) } -func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (doubleTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *float64: *v = math.Float64frombits(uint64(decBigInt(data))) return nil + case *interface{}: + *v = math.Float64frombits(uint64(decBigInt(data))) + return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { @@ -1171,23 +1272,34 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *float64.", info, value) + return unmarshalErrorf("can not unmarshal double into %T. Accepted types: *float64, *interface{}.", value) } -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { +type decimalTypeInfo struct{} + +// Type returns the type itself. +func (decimalTypeInfo) Type() Type { + return TypeDecimal +} + +// Zero returns the zero value for the decimal CQL type. +func (decimalTypeInfo) Zero() interface{} { + return new(inf.Dec) +} + +// Marshal marshals the value into a byte slice. +func (decimalTypeInfo) Marshal(value interface{}) ([]byte, error) { if value == nil { return nil, nil } switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case inf.Dec: unscaled := encBigInt2C(v.UnscaledBig()) if unscaled == nil { - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return nil, marshalErrorf("can not marshal %T into decimal", value) } buf := make([]byte, 4+len(unscaled)) @@ -1195,13 +1307,12 @@ func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { copy(buf[4:], unscaled) return buf, nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, inf.Dec, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into decimal. Accepted types: inf.Dec, UnsetValue.", value) } -func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (decimalTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *inf.Dec: if len(data) < 4 { return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) @@ -1210,8 +1321,16 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { unscaled := decBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) return nil + case *interface{}: + if len(data) < 4 { + return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) + } + scale := decInt(data[0:4]) + unscaled := decBigInt2C(data[4:], nil) + *v = inf.NewDecBig(unscaled, inf.Scale(scale)) + return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *inf.Dec.", info, value) + return unmarshalErrorf("can not unmarshal decimal into %T. Accepted types: *inf.Dec, *interface{}.", value) } // decBigInt2C sets the value of n to the big-endian two's complement @@ -1254,34 +1373,21 @@ func encBigInt2C(n *big.Int) []byte { return nil } -func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Duration: - return encBigInt(v.Nanoseconds()), nil - } +type timestampTypeInfo struct{} - if value == nil { - return nil, nil - } +// Type returns the type itself. +func (timestampTypeInfo) Type() Type { + return TypeTimestamp +} - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int64, time.Duration, UnsetValue.", value, info) +// Zero returns the zero value for the timestamp CQL type. +func (timestampTypeInfo) Zero() interface{} { + return time.Time{} } -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (timestampTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int64: @@ -1303,38 +1409,12 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { case reflect.Int64: return encBigInt(rv.Int()), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int64, time.Time, UnsetValue.", value, info) -} - -func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *int64: - *v = decBigInt(data) - return nil - case *time.Duration: - *v = time.Duration(decBigInt(data)) - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Int64: - rv.SetInt(decBigInt(data)) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *int64, *time.Duration.", info, value) + return nil, marshalErrorf("can not marshal %T into timestamp. Accepted types: int64, time.Time, UnsetValue.", value) } -func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (timestampTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *int64: *v = decBigInt(data) return nil @@ -1348,6 +1428,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) return nil + case *interface{}: + if len(data) == 0 { + *v = time.Time{} + return nil + } + x := decBigInt(data) + sec := x / 1000 + nsec := (x - sec*1000) * 1000000 + *v = time.Unix(sec, nsec).In(time.UTC) + return nil } rv := reflect.ValueOf(value) @@ -1360,16 +1450,89 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { rv.SetInt(decBigInt(data)) return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *int64, *time.Time.", info, value) + return unmarshalErrorf("can not unmarshal timestamp into %T. Accepted types: *int64, *time.Time, *interface{}.", value) } -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 +type timeTypeInfo struct{} -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) +// Type returns the type itself. +func (timeTypeInfo) Type() Type { + return TypeTime +} + +// Zero returns the zero value for the time CQL type. +func (timeTypeInfo) Zero() interface{} { + return time.Duration(0) +} + +// Marshal marshals the value into a byte slice. +func (timeTypeInfo) Marshal(value interface{}) ([]byte, error) { + switch v := value.(type) { + case unsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Duration: + return encBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into time. Accepted types: int64, time.Duration, UnsetValue.", value) +} + +// Unmarshal unmarshals the byte slice into the value. +func (timeTypeInfo) Unmarshal(data []byte, value interface{}) error { + switch v := value.(type) { + case *int64: + *v = decBigInt(data) + return nil + case *time.Duration: + *v = time.Duration(decBigInt(data)) + return nil + case *interface{}: + *v = time.Duration(decBigInt(data)) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return unmarshalErrorf("can not unmarshal time into %T. Accepted types: *int64, *time.Duration, *interface{}.", value) +} + +type dateTypeInfo struct{} + +// Type returns the type itself. +func (dateTypeInfo) Type() Type { + return TypeDate +} + +// Zero returns the zero value for the date CQL type. +func (dateTypeInfo) Zero() interface{} { + return time.Time{} +} + +const millisecondsInADay int64 = 24 * 60 * 60 * 1000 + +// Marshal marshals the value into a byte slice. +func (dateTypeInfo) Marshal(value interface{}) ([]byte, error) { + var timestamp int64 + switch v := value.(type) { case unsetColumn: return nil, nil case int64: @@ -1396,7 +1559,7 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { } t, err := time.Parse("2006-01-02", v) if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) + return nil, marshalErrorf("can not marshal %T into date, date layout must be '2006-01-02'", value) } timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) x := timestamp/millisecondsInADay + int64(1<<31) @@ -1406,13 +1569,12 @@ func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int64, time.Time, *time.Time, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into date. Accepted types: int64, time.Time, *time.Time, string, UnsetValue.", value) } -func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (dateTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *time.Time: if len(data) == 0 { *v = time.Time{} @@ -1423,6 +1585,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { timestamp := (int64(current) - int64(origin)) * millisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC) return nil + case *interface{}: + if len(data) == 0 { + *v = time.Time{} + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC) + return nil case *string: if len(data) == 0 { *v = "" @@ -1434,13 +1606,24 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *time.Time, *string.", info, value) + return unmarshalErrorf("can not unmarshal date into %T. Accepted types: *time.Time, *interface{}, *string.", value) +} + +type durationTypeInfo struct{} + +// Type returns the type itself. +func (durationTypeInfo) Type() Type { + return TypeDuration +} + +// Zero returns the zero value for the duration CQL type. +func (durationTypeInfo) Zero() interface{} { + return Duration{} } -func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { +// Marshal marshals the value into a byte slice. +func (durationTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, nil case int64: @@ -1466,13 +1649,12 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { case reflect.Int64: return encBigInt(rv.Int()), nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: Marshaler, int64, time.Duration, string, Duration, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into duration. Accepted types: int64, time.Duration, string, Duration, UnsetValue.", value) } -func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (durationTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *Duration: if len(data) == 0 { *v = Duration{ @@ -1484,7 +1666,26 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { } months, days, nanos, err := decVints(data) if err != nil { - return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) + return unmarshalErrorf("failed to unmarshal duration into %T: %s", value, err.Error()) + } + *v = Duration{ + Months: months, + Days: days, + Nanoseconds: nanos, + } + return nil + case *interface{}: + if len(data) == 0 { + *v = Duration{ + Months: 0, + Days: 0, + Nanoseconds: 0, + } + return nil + } + months, days, nanos, err := decVints(data) + if err != nil { + return unmarshalErrorf("failed to unmarshal duration into %T: %s", value, err.Error()) } *v = Duration{ Months: months, @@ -1493,7 +1694,7 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { } return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, *Duration.", info, value) + return unmarshalErrorf("can not unmarshal duration into %T. Accepted types: *Duration, *interface{}.", value) } func decVints(data []byte) (int32, int32, int64, error) { @@ -1564,24 +1765,127 @@ func encVint(v int64) []byte { return buf } +type listSetCQLType struct { + typ Type + types *RegisteredTypes +} + +// Params returns the types to build the slice of params for TypeInfoFromParams. +func (listSetCQLType) Params(proto int) []interface{} { + return []interface{}{ + (*TypeInfo)(nil), + } +} + +// TypeInfoFromParams builds a TypeInfo implementation for the composite type with +// the given parameters. +func (t listSetCQLType) TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) { + if len(params) != 1 { + return nil, fmt.Errorf("expected 1 param for list/set, got %d", len(params)) + } + elem, ok := params[0].(TypeInfo) + if !ok { + return nil, fmt.Errorf("expected TypeInfo for list/set, got %T", params[0]) + } + return CollectionType{ + typ: t.typ, + Elem: elem, + }, nil +} + +// TypeInfoFromString builds a TypeInfo implementation for the composite type with +// the given names/classes. Only the portion within the parantheses or arrows +// are passed to this function. +func (t listSetCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + elem, err := t.types.typeInfoFromString(proto, name) + if err != nil { + return nil, err + } + return CollectionType{ + typ: t.typ, + Elem: elem, + }, nil +} + +type CollectionType struct { + typ Type + Key TypeInfo // only used for TypeMap + Elem TypeInfo // only used for TypeMap, TypeList and TypeSet +} + +// Type returns the type of the collection. +func (c CollectionType) Type() Type { + return c.typ +} + +func (c CollectionType) zeroType() reflect.Type { + switch c.typ { + case TypeMap: + return reflect.MapOf(reflect.TypeOf(c.Key.Zero()), reflect.TypeOf(c.Elem.Zero())) + case TypeList, TypeSet: + return reflect.SliceOf(reflect.TypeOf(c.Elem.Zero())) + default: + // we should never have any other types + panic(fmt.Errorf("unsupported type for CollectionType: %d", c.typ)) + } +} + +// Zero returns the zero value for the collection CQL type. +func (c CollectionType) Zero() interface{} { + return reflect.Zero(c.zeroType()).Interface() +} + +// String returns the string representation of the collection. +func (c CollectionType) String() string { + switch c.typ { + case TypeMap: + return fmt.Sprintf("map(%s, %s)", c.Key, c.Elem) + case TypeList: + return fmt.Sprintf("list(%s)", c.Elem) + case TypeSet: + return fmt.Sprintf("set(%s)", c.Elem) + default: + return "unknown" + } +} + +// Marshal marshals the value into a byte slice. +func (c CollectionType) Marshal(value interface{}) ([]byte, error) { + switch c.typ { + case TypeMap: + return c.marshalMap(value) + case TypeList, TypeSet: + return c.marshalListSet(value) + } + return nil, marshalErrorf("unsupported collection type: %s. Accepted types: map, list, set.", c.String()) +} + +// Unmarshal unmarshals the byte slice into the value. +func (c CollectionType) Unmarshal(data []byte, value interface{}) error { + switch c.typ { + case TypeMap: + return c.unmarshalMap(data, value) + case TypeList, TypeSet: + return c.unmarshalListSet(data, value) + } + return unmarshalErrorf("unsupported collection type: %s. Accepted types: map, list, set.", c.String()) +} + func writeCollectionSize(n int, buf *bytes.Buffer) error { if n > math.MaxInt32 { return marshalErrorf("marshal: collection too large") } - buf.WriteByte(byte(n >> 24)) - buf.WriteByte(byte(n >> 16)) - buf.WriteByte(byte(n >> 8)) - buf.WriteByte(byte(n)) - return nil + _, err := buf.Write([]byte{ + byte(n >> 24), + byte(n >> 16), + byte(n >> 8), + byte(n), + }) + return err } -func marshalList(info TypeInfo, value interface{}) ([]byte, error) { - listInfo, ok := info.(CollectionType) - if !ok { - return nil, marshalErrorf("marshal: can not marshal non collection type into list") - } - +func (l CollectionType) marshalListSet(value interface{}) ([]byte, error) { if value == nil { return nil, nil } else if _, ok := value.(unsetColumn); ok { @@ -1605,7 +1909,7 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { } for i := 0; i < n; i++ { - item, err := Marshal(listInfo.Elem, rv.Index(i).Interface()) + item, err := Marshal(l.Elem, rv.Index(i).Interface()) if err != nil { return nil, err } @@ -1628,35 +1932,36 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { for i := 0; i < len(keys); i++ { keys[i] = rkeys[i].Interface() } - return marshalList(listInfo, keys) + return l.Marshal(keys) } } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array, map[]struct.", value, info) + return nil, marshalErrorf("can not marshal %T into collection. Accepted types: slice, array, map[]struct.", value) } -func readCollectionSize(data []byte) (size, read int, err error) { +func readCollectionSize(data []byte) (int, int, error) { if len(data) < 4 { return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") } - size = int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])) - read = 4 - return + return int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])), + 4, + nil } -func unmarshalList(info TypeInfo, data []byte, value interface{}) error { - listInfo, ok := info.(CollectionType) - if !ok { - return unmarshalErrorf("unmarshal: can not unmarshal none collection type into list") - } - +func (c CollectionType) unmarshalListSet(data []byte, value interface{}) error { rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) } rv = rv.Elem() t := rv.Type() - k := t.Kind() + if t.Kind() == reflect.Interface { + if t.NumMethod() != 0 { + return unmarshalErrorf("can not unmarshal into non-empty interface %T", value) + } + t = c.zeroType() + } + k := t.Kind() switch k { case reflect.Slice, reflect.Array: if data == nil { @@ -1680,6 +1985,9 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { } } else { rv.Set(reflect.MakeSlice(t, n, n)) + if rv.Kind() == reflect.Interface { + rv = rv.Elem() + } } for i := 0; i < n; i++ { m, p, err := readCollectionSize(data) @@ -1696,180 +2004,72 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { unmarshalData = data[:m] data = data[m:] } - if err := Unmarshal(listInfo.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { + if err := Unmarshal(c.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { return err } } return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *slice, *array.", info, value) + return unmarshalErrorf("can not unmarshal collection into %T. Accepted types: *slice, *array.", value) } -func marshalVector(info VectorType, value interface{}) ([]byte, error) { - if value == nil { - return nil, nil - } else if _, ok := value.(unsetColumn); ok { - return nil, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - if k == reflect.Slice && rv.IsNil() { - return nil, nil - } - - switch k { - case reflect.Slice, reflect.Array: - buf := &bytes.Buffer{} - n := rv.Len() - if n != info.Dimensions { - return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n) - } - - for i := 0; i < n; i++ { - item, err := Marshal(info.SubType, rv.Index(i).Interface()) - if err != nil { - return nil, err - } - if isVectorVariableLengthType(info.SubType) { - writeUnsignedVInt(buf, uint64(len(item))) - } - buf.Write(item) - } - return buf.Bytes(), nil - } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array.", value, info) +type mapCQLType struct { + types *RegisteredTypes } -func unmarshalVector(info VectorType, data []byte, value interface{}) error { - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) +// Params returns the types to build the slice of params for TypeInfoFromParams. +func (mapCQLType) Params(proto int) []interface{} { + return []interface{}{ + (*TypeInfo)(nil), + (*TypeInfo)(nil), } - rv = rv.Elem() - t := rv.Type() - k := t.Kind() - switch k { - case reflect.Slice, reflect.Array: - if data == nil { - if k == reflect.Array { - return unmarshalErrorf("unmarshal vector: can not store nil in array value") - } - if rv.IsNil() { - return nil - } - rv.Set(reflect.Zero(t)) - return nil - } - if k == reflect.Array { - if rv.Len() != info.Dimensions { - return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions) - } - } else { - rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) - } - elemSize := len(data) / info.Dimensions - for i := 0; i < info.Dimensions; i++ { - offset := 0 - if isVectorVariableLengthType(info.SubType) { - m, p, err := readUnsignedVInt(data) - if err != nil { - return err - } - elemSize = int(m) - offset = p - } - if offset > 0 { - data = data[offset:] - } - var unmarshalData []byte - if elemSize >= 0 { - if len(data) < elemSize { - return unmarshalErrorf("unmarshal vector: unexpected eof") - } - unmarshalData = data[:elemSize] - data = data[elemSize:] - } - err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) - if err != nil { - return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) - } - } - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: slice, array.", info, value) } -// isVectorVariableLengthType determines if a type requires explicit length serialization within a vector. -// Variable-length types need their length encoded before the actual data to allow proper deserialization. -// Fixed-length types, on the other hand, don't require this kind of length prefix. -func isVectorVariableLengthType(elemType TypeInfo) bool { - switch elemType.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText, - TypeCounter, - TypeDuration, TypeDate, TypeTime, - TypeDecimal, TypeSmallInt, TypeTinyInt, TypeVarint, - TypeInet, - TypeList, TypeSet, TypeMap, TypeUDT, TypeTuple: - return true - case TypeCustom: - if vecType, ok := elemType.(VectorType); ok { - return isVectorVariableLengthType(vecType.SubType) - } - return true +// TypeInfoFromParams builds a TypeInfo implementation for the composite type with +// the given parameters. +func (mapCQLType) TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) { + if len(params) != 2 { + return nil, fmt.Errorf("expected 2 param for map, got %d", len(params)) } - return false -} - -func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { - numBytes := computeUnsignedVIntSize(v) - if numBytes <= 1 { - buf.WriteByte(byte(v)) - return + key, ok := params[0].(TypeInfo) + if !ok { + return nil, fmt.Errorf("expected TypeInfo for map, got %T", params[0]) } - - extraBytes := numBytes - 1 - var tmp = make([]byte, numBytes) - for i := extraBytes; i >= 0; i-- { - tmp[i] = byte(v) - v >>= 8 + elem, ok := params[1].(TypeInfo) + if !ok { + return nil, fmt.Errorf("expected TypeInfo for map, got %T", params[1]) } - tmp[0] |= byte(^(0xff >> uint(extraBytes))) - buf.Write(tmp) + return CollectionType{ + typ: TypeMap, + Key: key, + Elem: elem, + }, nil } -func readUnsignedVInt(data []byte) (uint64, int, error) { - if len(data) <= 0 { - return 0, 0, errors.New("unexpected eof") - } - firstByte := data[0] - if firstByte&0x80 == 0 { - return uint64(firstByte), 1, nil +// TypeInfoFromString builds a TypeInfo implementation for the composite type with +// the given names/classes. Only the portion within the parantheses or arrows +// are passed to this function. +func (m mapCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + names := splitCompositeTypes(name) + if len(names) != 2 { + return nil, fmt.Errorf("expected 2 elements for map, got %v", names) } - numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 - ret := uint64(firstByte & (0xff >> uint(numBytes))) - if len(data) < numBytes+1 { - return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", numBytes+1, len(data)) + kt, err := m.types.typeInfoFromString(proto, names[0]) + if err != nil { + return nil, err } - for i := 0; i < numBytes; i++ { - ret <<= 8 - ret |= uint64(data[i+1] & 0xff) + et, err := m.types.typeInfoFromString(proto, names[1]) + if err != nil { + return nil, err } - return ret, numBytes + 1, nil + return CollectionType{ + typ: TypeMap, + Key: kt, + Elem: et, + }, nil } -func computeUnsignedVIntSize(v uint64) int { - lead0 := bits.LeadingZeros64(v) - return (639 - lead0*9) >> 6 -} - -func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { - mapInfo, ok := info.(CollectionType) - if !ok { - return nil, marshalErrorf("marshal: can not marshal none collection type into map") - } - +func (c CollectionType) marshalMap(value interface{}) ([]byte, error) { if value == nil { return nil, nil } else if _, ok := value.(unsetColumn); ok { @@ -1880,7 +2080,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { t := rv.Type() if t.Kind() != reflect.Map { - return nil, marshalErrorf("can not marshal %T into %s", value, info) + return nil, marshalErrorf("can not marshal %T into map", value) } if rv.IsNil() { @@ -1895,8 +2095,8 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { } keys := rv.MapKeys() - for _, key := range keys { - item, err := Marshal(mapInfo.Key, key.Interface()) + for i := range keys { + item, err := Marshal(c.Key, keys[i].Interface()) if err != nil { return nil, err } @@ -1910,7 +2110,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { } buf.Write(item) - item, err = Marshal(mapInfo.Elem, rv.MapIndex(key).Interface()) + item, err = Marshal(c.Elem, rv.MapIndex(keys[i]).Interface()) if err != nil { return nil, err } @@ -1927,20 +2127,20 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { return buf.Bytes(), nil } -func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { - mapInfo, ok := info.(CollectionType) - if !ok { - return unmarshalErrorf("unmarshal: can not unmarshal none collection type into map") - } - +func (c CollectionType) unmarshalMap(data []byte, value interface{}) error { rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + return unmarshalErrorf("can not unmarshal map into non-pointer %T", value) } rv = rv.Elem() t := rv.Type() - if t.Kind() != reflect.Map { - return unmarshalErrorf("can not unmarshal %s into %T", info, value) + if t.Kind() == reflect.Interface { + if t.NumMethod() != 0 { + return unmarshalErrorf("can not unmarshal map into non-empty interface %T", value) + } + t = c.zeroType() + } else if t.Kind() != reflect.Map { + return unmarshalErrorf("can not unmarshal map into %T", value) } if data == nil { rv.Set(reflect.Zero(t)) @@ -1954,6 +2154,9 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("negative map size %d", n) } rv.Set(reflect.MakeMapWithSize(t, n)) + if rv.Kind() == reflect.Interface { + rv = rv.Elem() + } data = data[p:] for i := 0; i < n; i++ { m, p, err := readCollectionSize(data) @@ -1971,7 +2174,7 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { unmarshalData = data[:m] data = data[m:] } - if err := Unmarshal(mapInfo.Key, unmarshalData, key.Interface()); err != nil { + if err := Unmarshal(c.Key, unmarshalData, key.Interface()); err != nil { return err } @@ -1991,7 +2194,7 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { unmarshalData = data[:m] data = data[m:] } - if err := Unmarshal(mapInfo.Elem, unmarshalData, val.Interface()); err != nil { + if err := Unmarshal(c.Elem, unmarshalData, val.Interface()); err != nil { return err } @@ -2000,7 +2203,24 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { return nil } -func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { +type uuidType struct{} + +// Type returns the type itself. +func (uuidType) Type() Type { + return TypeUUID +} + +// Zero returns the zero value for the uuid CQL type. +func (uuidType) Zero() interface{} { + return UUID{} +} + +// Marshal marshals the value into a byte slice. +func (uuidType) Marshal(value interface{}) ([]byte, error) { + return uuidMarshal("UUID", value) +} + +func uuidMarshal(kind string, value interface{}) ([]byte, error) { switch val := value.(type) { case unsetColumn: return nil, nil @@ -2010,7 +2230,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { return val[:], nil case []byte: if len(val) != 16 { - return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) + return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), kind) } return val, nil case string: @@ -2025,10 +2245,15 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { return nil, nil } - return nil, marshalErrorf("can not marshal %T into %s. Accepted types: UUID, [16]byte, string, UnsetValue.", value, info) + return nil, marshalErrorf("can not marshal %T into %s. Accepted types: UUID, [16]byte, string, UnsetValue.", value, kind) +} + +func (uuidType) Unmarshal(data []byte, value interface{}) error { + return uuidUnmarshal("UUID", data, value) } -func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func uuidUnmarshal(kind string, data []byte, value interface{}) error { if len(data) == 0 { switch v := value.(type) { case *string: @@ -2037,15 +2262,17 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { *v = nil case *UUID: *v = UUID{} + case *interface{}: + *v = UUID{} default: - return unmarshalErrorf("can not unmarshal X %s into %T. Accepted types: *UUID, *[]byte, *string.", info, value) + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *UUID, *[]byte, *string, *interface{}.", kind, value) } return nil } if len(data) != 16 { - return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long") + return unmarshalErrorf("unable to parse %s: UUIDs must be exactly 16 bytes long", kind) } switch v := value.(type) { @@ -2055,11 +2282,16 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { case *UUID: copy((*v)[:], data) return nil + case *interface{}: + var u UUID + copy(u[:], data) + *v = u + return nil } u, err := UUIDFromBytes(data) if err != nil { - return unmarshalErrorf("unable to parse UUID: %s", err) + return unmarshalErrorf("unable to parse %s: %s", kind, err) } switch v := value.(type) { @@ -2070,13 +2302,33 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { *v = u[:] return nil } - return unmarshalErrorf("can not unmarshal X %s into %T. Accepted types: *UUID, *[]byte, *string.", info, value) + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *UUID, *[]byte, *string, *interface{}.", kind, value) +} + +type timeUUIDType struct{} + +// Type returns the type itself. +func (timeUUIDType) Type() Type { + return TypeTimeUUID } -func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { +// Zero returns the zero value for the timeuuid CQL type. +func (timeUUIDType) Zero() interface{} { + return UUID{} +} + +// Marshal marshals the value into a byte slice. +func (t timeUUIDType) Marshal(value interface{}) ([]byte, error) { + switch val := value.(type) { + case time.Time: + return UUIDFromTime(val).Bytes(), nil + } + return uuidMarshal("timeuuid", value) +} + +// Unmarshal unmarshals the byte slice into the value. +func (t timeUUIDType) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *time.Time: id, err := UUIDFromBytes(data) if err != nil { @@ -2087,11 +2339,24 @@ func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { *v = id.Time() return nil default: - return unmarshalUUID(info, data, value) + return uuidUnmarshal("timeuuid", data, value) } } -func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { +type inetType struct{} + +// Type returns the type itself. +func (inetType) Type() Type { + return TypeInet +} + +// Zero returns the zero value for the inet CQL type. +func (inetType) Zero() interface{} { + return net.IP(nil) +} + +// Marshal marshals the value into a byte slice. +func (inetType) Marshal(value interface{}) ([]byte, error) { // we return either the 4 or 16 byte representation of an // ip address here otherwise the db value will be prefixed // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 @@ -2120,16 +2385,35 @@ func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { return nil, nil } - return nil, marshalErrorf("cannot marshal %T into %s. Accepted types: net.IP, string.", value, info) + return nil, marshalErrorf("cannot marshal %T into inet. Accepted types: net.IP, string.", value) } -func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (inetType) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case *net.IP: + if len(data) == 0 { + *v = nil + return nil + } + if x := len(data); !(x == 4 || x == 16) { + return unmarshalErrorf("cannot unmarshal inet into %T: invalid sized IP: got %d bytes not 4 or 16", value, x) + } + buf := copyBytes(data) + ip := net.IP(buf) + if v4 := ip.To4(); v4 != nil { + *v = v4 + return nil + } + *v = ip + return nil + case *interface{}: + if len(data) == 0 { + *v = net.IP(nil) + return nil + } if x := len(data); !(x == 4 || x == 16) { - return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) + return unmarshalErrorf("cannot unmarshal inet into %T: invalid sized IP: got %d bytes not 4 or 16", value, x) } buf := copyBytes(data) ip := net.IP(buf) @@ -2152,11 +2436,73 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { *v = ip.String() return nil } - return unmarshalErrorf("cannot unmarshal %s into %T. Accepted types: Unmarshaler, *net.IP, *string.", info, value) + return unmarshalErrorf("cannot unmarshal inet into %T. Accepted types: *net.IP, *string, *interface{}.", value) +} + +type tupleCQLType struct { + types *RegisteredTypes +} + +// Params returns the types to build the slice of params for TypeInfoFromParams. +func (tupleCQLType) Params(proto int) []interface{} { + return []interface{}{ + []TypeInfo(nil), + } +} + +// TypeInfoFromParams builds a TypeInfo implementation for the composite type with +// the given parameters. +func (tupleCQLType) TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) { + if len(params) != 1 { + return nil, fmt.Errorf("expected 1 param for tuple, got %d", len(params)) + } + elems, ok := params[0].([]TypeInfo) + if !ok { + return nil, fmt.Errorf("expected []TypeInfo for tuple, got %T", params[0]) + } + return TupleTypeInfo{ + Elems: elems, + }, nil +} + +// TypeInfoFromString builds a TypeInfo implementation for the composite type with +// the given names/classes. Only the portion within the parantheses or arrows +// are passed to this function. +func (t tupleCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + names := splitCompositeTypes(name) + types := make([]TypeInfo, len(names)) + var err error + for i, name := range names { + types[i], err = t.types.typeInfoFromString(proto, name) + if err != nil { + return nil, err + } + } + return TupleTypeInfo{ + Elems: types, + }, nil +} + +// TODO: move to types.go +type TupleTypeInfo struct { + Elems []TypeInfo +} + +func (TupleTypeInfo) Type() Type { + return TypeTuple +} + +// Zero returns the zero value for the tuple CQL type. +func (t TupleTypeInfo) Zero() interface{} { + s := make([]interface{}, len(t.Elems), len(t.Elems)) + for i := range s { + s[i] = t.Elems[i].Zero() + } + return s } -func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { - tuple := info.(TupleTypeInfo) +// Marshal marshals the value into a byte slice. +func (tuple TupleTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { case unsetColumn: return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") @@ -2166,13 +2512,13 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } var buf []byte - for i, elem := range v { - if elem == nil { + for i := range v { + if v[i] == nil { buf = appendInt(buf, int32(-1)) continue } - data, err := Marshal(tuple.Elems[i], elem) + data, err := Marshal(tuple.Elems[i], v[i]) if err != nil { return nil, err } @@ -2186,17 +2532,17 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() + typ := rv.Type() + k := typ.Kind() switch k { case reflect.Struct: - if v := t.NumField(); v != len(tuple.Elems) { - return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) + if v := typ.NumField(); v != len(tuple.Elems) { + return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", typ, v, len(tuple.Elems)) } var buf []byte - for i, elem := range tuple.Elems { + for i := range tuple.Elems { field := rv.Field(i) if field.Kind() == reflect.Ptr && field.IsNil() { @@ -2204,7 +2550,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { continue } - data, err := Marshal(elem, field.Interface()) + data, err := Marshal(tuple.Elems[i], field.Interface()) if err != nil { return nil, err } @@ -2222,7 +2568,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } var buf []byte - for i, elem := range tuple.Elems { + for i := range tuple.Elems { item := rv.Index(i) if item.Kind() == reflect.Ptr && item.IsNil() { @@ -2230,7 +2576,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { continue } - data, err := Marshal(elem, item.Interface()) + data, err := Marshal(tuple.Elems[i], item.Interface()) if err != nil { return nil, err } @@ -2243,7 +2589,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { return buf, nil } - return nil, marshalErrorf("cannot marshal %T into %s. Accepted types: struct, []interface{}, array, slice, UnsetValue.", value, tuple) + return nil, marshalErrorf("cannot marshal %T into tuple. Accepted types: struct, []interface{}, array, slice, UnsetValue.", value) } func readBytes(p []byte) ([]byte, []byte) { @@ -2256,29 +2602,42 @@ func readBytes(p []byte) ([]byte, []byte) { return p[:size], p[size:] } +// Unmarshal unmarshals the byte slice into the value. // currently only support unmarshal into a list of values, this makes it possible // to support tuples without changing the query API. In the future this can be extend // to allow unmarshalling into custom tuple types. -func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { - if v, ok := value.(Unmarshaler); ok { - return v.UnmarshalCQL(info, data) - } - - tuple := info.(TupleTypeInfo) +func (tuple TupleTypeInfo) Unmarshal(data []byte, value interface{}) error { switch v := value.(type) { case []interface{}: - for i, elem := range tuple.Elems { + if len(v) != len(tuple.Elems) { + return unmarshalErrorf("can not unmarshal tuple into slice of length %d need %d elements", len(v), len(tuple.Elems)) + } + for i := range tuple.Elems { // each element inside data is a [bytes] var p []byte if len(data) >= 4 { p, data = readBytes(data) } - err := Unmarshal(elem, p, v[i]) + err := Unmarshal(tuple.Elems[i], p, v[i]) if err != nil { return err } } - + return nil + case *interface{}: + s := make([]interface{}, len(tuple.Elems)) + for i := range tuple.Elems { + // each element inside data is a [bytes] + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + err := Unmarshal(tuple.Elems[i], p, &s[i]) + if err != nil { + return err + } + } + *v = s return nil } @@ -2293,33 +2652,25 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { switch k { case reflect.Struct: + // TODO: should we ignore private fields? if v := t.NumField(); v != len(tuple.Elems) { return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) } - for i, elem := range tuple.Elems { + for i := range tuple.Elems { var p []byte if len(data) >= 4 { p, data = readBytes(data) } - v, err := elem.NewWithError() - if err != nil { - return err - } - if err := Unmarshal(elem, p, v); err != nil { - return err + // handle null data + if p == nil && rv.Field(i).Kind() == reflect.Ptr { + rv.Field(i).Set(reflect.Zero(rv.Field(i).Type())) + continue } - switch rv.Field(i).Kind() { - case reflect.Ptr: - if p != nil { - rv.Field(i).Set(reflect.ValueOf(v)) - } else { - rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v))) - } - default: - rv.Field(i).Set(reflect.ValueOf(v).Elem()) + if err := Unmarshal(tuple.Elems[i], p, rv.Field(i).Addr().Interface()); err != nil { + return err } } @@ -2334,36 +2685,27 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems))) } - for i, elem := range tuple.Elems { + for i := range tuple.Elems { var p []byte if len(data) >= 4 { p, data = readBytes(data) } - v, err := elem.NewWithError() - if err != nil { - return err - } - if err := Unmarshal(elem, p, v); err != nil { - return err + // handle null data + if p == nil && rv.Index(i).Kind() == reflect.Ptr { + rv.Index(i).Set(reflect.Zero(rv.Index(i).Type())) + continue } - switch rv.Index(i).Kind() { - case reflect.Ptr: - if p != nil { - rv.Index(i).Set(reflect.ValueOf(v)) - } else { - rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v))) - } - default: - rv.Index(i).Set(reflect.ValueOf(v).Elem()) + if err := Unmarshal(tuple.Elems[i], p, rv.Index(i).Addr().Interface()); err != nil { + return err } } return nil } - return unmarshalErrorf("cannot unmarshal %s into %T. Accepted types: *struct, []interface{}, *array, *slice, Unmarshaler.", info, value) + return unmarshalErrorf("cannot unmarshal tuple into %T. Accepted types: *struct, []interface{}, *array, *slice, *interface{}.", value) } // UDTMarshaler is an interface which should be implemented by users wishing to @@ -2385,18 +2727,131 @@ type UDTUnmarshaler interface { UnmarshalUDT(name string, info TypeInfo, data []byte) error } -func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { - udt := info.(UDTTypeInfo) +type udtCQLType struct { + types *RegisteredTypes +} + +// Params returns the types to build the slice of params for TypeInfoFromParams. +func (udtCQLType) Params(proto int) []interface{} { + return []interface{}{ + "", + "", + []UDTField(nil), + } +} + +// TypeInfoFromParams builds a TypeInfo implementation for the composite type with +// the given parameters. +func (udtCQLType) TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) { + if len(params) != 3 { + return nil, fmt.Errorf("expected 3 param for udt, got %d", len(params)) + } + keyspace, ok := params[0].(string) + if !ok { + return nil, fmt.Errorf("expected string for udt, got %T", params[0]) + } + name, ok := params[1].(string) + if !ok { + return nil, fmt.Errorf("expected string for udt, got %T", params[1]) + } + elements, ok := params[2].([]UDTField) + if !ok { + return nil, fmt.Errorf("expected []UDTField for udt, got %T", params[2]) + } + return UDTTypeInfo{ + Keyspace: keyspace, + Name: name, + Elements: elements, + }, nil +} + +// TypeInfoFromString builds a TypeInfo implementation for the composite type with +// the given names/classes. Only the portion within the parantheses or arrows are +// passed to this function. +func (u udtCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + parts := splitCompositeTypes(name) + // let's check to see if its java or not because if its java then we can get + // everything we need + if strings.Contains(name, ":") { + if len(parts) < 3 { + return nil, fmt.Errorf("expected 3 parts for udt, got %s", name) + } + // first is keyspace, second is hex(name), third is elements + name, _ := hex.DecodeString(parts[1]) + ti := UDTTypeInfo{ + Keyspace: parts[0], + Name: string(name), + } + ti.Elements = make([]UDTField, 0, len(parts)-2) + for i := 2; i < len(parts); i++ { + colonIdx := strings.Index(parts[i], ":") + var name string + var typ string + if colonIdx == -1 { + typ = parts[i] + } else { + // name is hex(name) + nameb, _ := hex.DecodeString(parts[i][:colonIdx]) + name = string(nameb) + if len(parts[i]) > colonIdx+1 { + typ = parts[i][colonIdx+1:] + } + } + et, err := u.types.typeInfoFromString(proto, typ) + if err != nil { + return nil, err + } + ti.Elements = append(ti.Elements, UDTField{ + Name: name, + Type: et, + }) + } + return ti, nil + } + // we can't get the name or anything so we'll just try to parse the elements + ti := UDTTypeInfo{} + ti.Elements = make([]UDTField, 0, len(parts)) + for _, part := range parts { + et, err := u.types.typeInfoFromString(proto, part) + if err != nil { + return nil, err + } + ti.Elements = append(ti.Elements, UDTField{ + Type: et, + }) + } + return ti, nil +} + +type UDTField struct { + Name string + Type TypeInfo +} + +type UDTTypeInfo struct { + Keyspace string + Name string + Elements []UDTField +} + +func (u UDTTypeInfo) Type() Type { + return TypeUDT +} + +// Zero returns the zero value for the UDT CQL type. +func (UDTTypeInfo) Zero() interface{} { + return map[string]interface{}(nil) +} +// Marshal marshals the value into a byte slice. +func (udt UDTTypeInfo) Marshal(value interface{}) ([]byte, error) { switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) case unsetColumn: return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") case UDTMarshaler: var buf []byte - for _, e := range udt.Elements { - data, err := v.MarshalUDT(e.Name, e.Type) + for i := range udt.Elements { + data, err := v.MarshalUDT(udt.Elements[i].Name, udt.Elements[i].Type) if err != nil { return nil, err } @@ -2407,14 +2862,14 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { return buf, nil case map[string]interface{}: var buf []byte - for _, e := range udt.Elements { - val, ok := v[e.Name] + for i := range udt.Elements { + val, ok := v[udt.Elements[i].Name] var data []byte if ok { var err error - data, err = Marshal(e.Type, val) + data, err = Marshal(udt.Elements[i].Type, val) if err != nil { return nil, err } @@ -2429,13 +2884,13 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { k := reflect.ValueOf(value) if k.Kind() == reflect.Ptr { if k.IsNil() { - return nil, marshalErrorf("cannot marshal %T into %s", value, info) + return nil, marshalErrorf("cannot marshal %T into UDT", value) } k = k.Elem() } if k.Kind() != reflect.Struct || !k.IsValid() { - return nil, marshalErrorf("cannot marshal %T into %s. Accepted types: Marshaler, UDTMarshaler, map[string]interface{}, struct, UnsetValue.", value, info) + return nil, marshalErrorf("cannot marshal %T into UDT. Accepted types: UDTMarshaler, map[string]interface{}, struct, UnsetValue.", value) } fields := make(map[string]reflect.Value) @@ -2449,16 +2904,16 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { } var buf []byte - for _, e := range udt.Elements { - f, ok := fields[e.Name] + for i := range udt.Elements { + f, ok := fields[udt.Elements[i].Name] if !ok { - f = k.FieldByName(e.Name) + f = k.FieldByName(udt.Elements[i].Name) } var data []byte if f.IsValid() && f.CanInterface() { var err error - data, err = Marshal(e.Type, f.Interface()) + data, err = Marshal(udt.Elements[i].Type, f.Interface()) if err != nil { return nil, err } @@ -2470,19 +2925,22 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { return buf, nil } -func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { +// Unmarshal unmarshals the byte slice into the value. +func (udt UDTTypeInfo) Unmarshal(data []byte, value interface{}) error { + // do this up here so we don't need to duplicate all of the map logic below + if iptr, ok := value.(*interface{}); ok && iptr != nil { + v := map[string]interface{}{} + *iptr = v + value = &v + } switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) case UDTUnmarshaler: - udt := info.(UDTTypeInfo) - for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + return unmarshalErrorf("can not unmarshal UDT: field [%d]%s: unexpected eof", id, e.Name) } var p []byte @@ -2494,48 +2952,30 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { return nil case *map[string]interface{}: - udt := info.(UDTTypeInfo) - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - - rv = rv.Elem() - t := rv.Type() - if t.Kind() != reflect.Map { - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: Unmarshaler, UDTUnmarshaler, *map[string]interface{}, struct.", info, value) - } else if data == nil { - rv.Set(reflect.Zero(t)) + if data == nil { + *v = nil return nil } - rv.Set(reflect.MakeMap(t)) - m := *v + m := map[string]interface{}{} + *v = m for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) - } - - valType, err := goType(e.Type) - if err != nil { - return unmarshalErrorf("can not unmarshal %s: %v", info, err) + return unmarshalErrorf("can not unmarshal UDT: field [%d]%s: unexpected eof", id, e.Name) } - val := reflect.New(valType) - var p []byte p, data = readBytes(data) - if err := Unmarshal(e.Type, p, val.Interface()); err != nil { + v := reflect.New(reflect.TypeOf(e.Type.Zero())) + if err := Unmarshal(e.Type, p, v.Interface()); err != nil { return err } - - m[e.Name] = val.Elem().Interface() + m[e.Name] = v.Elem().Interface() } return nil @@ -2547,7 +2987,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } k := rv.Elem() if k.Kind() != reflect.Struct || !k.IsValid() { - return unmarshalErrorf("cannot unmarshal %s into %T. Accepted types: Unmarshaler, UDTUnmarshaler, *map[string]interface{}, *struct.", info, value) + return unmarshalErrorf("cannot unmarshal UDT into %T. Accepted types: UDTUnmarshaler, *map[string]interface{}, *struct.", value) } if len(data) == 0 { @@ -2568,14 +3008,13 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } } - udt := info.(UDTTypeInfo) for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { // UDT def does not match the column value - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + return unmarshalErrorf("can not unmarshal UDT: field [%d]%s: unexpected eof", id, e.Name) } var p []byte @@ -2584,7 +3023,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { f, ok := fields[e.Name] if !ok { f = k.FieldByName(e.Name) - if f == emptyValue { //nolint:govet // there is no other way to compare with empty value + if !f.IsValid() { // skip fields which exist in the UDT but not in // the struct passed in continue @@ -2592,7 +3031,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } if !f.IsValid() || !f.CanAddr() { - return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) + return unmarshalErrorf("cannot unmarshal UDT into %T: field %v is not valid", value, e.Name) } fk := f.Addr().Interface() @@ -2604,253 +3043,6 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { return nil } -// TypeInfo describes a Cassandra specific data type. -type TypeInfo interface { - Type() Type - Version() byte - Custom() string - - // NewWithError creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver. - // - // If there is no corresponding Go type for the CQL type, NewWithError returns an error. - NewWithError() (interface{}, error) -} - -type NativeType struct { - proto byte - typ Type - custom string // only used for TypeCustom -} - -func NewNativeType(proto byte, typ Type) NativeType { - return NativeType{proto, typ, ""} -} - -func NewCustomType(proto byte, typ Type, custom string) NativeType { - return NativeType{proto, typ, custom} -} - -func (t NativeType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (s NativeType) Type() Type { - return s.typ -} - -func (s NativeType) Version() byte { - return s.proto -} - -func (s NativeType) Custom() string { - return s.custom -} - -func (s NativeType) String() string { - switch s.typ { - case TypeCustom: - return fmt.Sprintf("%s(%s)", s.typ, s.custom) - default: - return s.typ.String() - } -} - -type CollectionType struct { - NativeType - Key TypeInfo // only used for TypeMap - Elem TypeInfo // only used for TypeMap, TypeList and TypeSet -} - -type VectorType struct { - NativeType - SubType TypeInfo - Dimensions int -} - -func (t CollectionType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (c CollectionType) String() string { - switch c.typ { - case TypeMap: - return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) - case TypeList, TypeSet: - return fmt.Sprintf("%s(%s)", c.typ, c.Elem) - case TypeCustom: - return fmt.Sprintf("%s(%s)", c.typ, c.custom) - default: - return c.typ.String() - } -} - -type TupleTypeInfo struct { - NativeType - Elems []TypeInfo -} - -func (t TupleTypeInfo) String() string { - var buf bytes.Buffer - buf.WriteString(fmt.Sprintf("%s(", t.typ)) - for _, elem := range t.Elems { - buf.WriteString(fmt.Sprintf("%s, ", elem)) - } - buf.Truncate(buf.Len() - 2) - buf.WriteByte(')') - return buf.String() -} - -func (t TupleTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -type UDTField struct { - Name string - Type TypeInfo -} - -type UDTTypeInfo struct { - NativeType - KeySpace string - Name string - Elements []UDTField -} - -func (u UDTTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(u) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (u UDTTypeInfo) String() string { - buf := &bytes.Buffer{} - - fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) - first := true - for _, e := range u.Elements { - if !first { - fmt.Fprint(buf, ",") - } else { - first = false - } - - fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) - } - fmt.Fprint(buf, "}") - - return buf.String() -} - -// String returns a human readable name for the Cassandra datatype -// described by t. -// Type is the identifier of a Cassandra internal datatype. -type Type int - -const ( - TypeCustom Type = 0x0000 - TypeAscii Type = 0x0001 - TypeBigInt Type = 0x0002 - TypeBlob Type = 0x0003 - TypeBoolean Type = 0x0004 - TypeCounter Type = 0x0005 - TypeDecimal Type = 0x0006 - TypeDouble Type = 0x0007 - TypeFloat Type = 0x0008 - TypeInt Type = 0x0009 - TypeText Type = 0x000A - TypeTimestamp Type = 0x000B - TypeUUID Type = 0x000C - TypeVarchar Type = 0x000D - TypeVarint Type = 0x000E - TypeTimeUUID Type = 0x000F - TypeInet Type = 0x0010 - TypeDate Type = 0x0011 - TypeTime Type = 0x0012 - TypeSmallInt Type = 0x0013 - TypeTinyInt Type = 0x0014 - TypeDuration Type = 0x0015 - TypeList Type = 0x0020 - TypeMap Type = 0x0021 - TypeSet Type = 0x0022 - TypeUDT Type = 0x0030 - TypeTuple Type = 0x0031 -) - -// String returns the name of the identifier. -func (t Type) String() string { - switch t { - case TypeCustom: - return "custom" - case TypeAscii: - return "ascii" - case TypeBigInt: - return "bigint" - case TypeBlob: - return "blob" - case TypeBoolean: - return "boolean" - case TypeCounter: - return "counter" - case TypeDecimal: - return "decimal" - case TypeDouble: - return "double" - case TypeFloat: - return "float" - case TypeInt: - return "int" - case TypeText: - return "text" - case TypeTimestamp: - return "timestamp" - case TypeUUID: - return "uuid" - case TypeVarchar: - return "varchar" - case TypeTimeUUID: - return "timeuuid" - case TypeInet: - return "inet" - case TypeDate: - return "date" - case TypeDuration: - return "duration" - case TypeTime: - return "time" - case TypeSmallInt: - return "smallint" - case TypeTinyInt: - return "tinyint" - case TypeList: - return "list" - case TypeMap: - return "map" - case TypeSet: - return "set" - case TypeVarint: - return "varint" - case TypeTuple: - return "tuple" - default: - return fmt.Sprintf("unknown_type_%d", t) - } -} - type MarshalError string func (m MarshalError) Error() string { diff --git a/marshal_test.go b/marshal_test.go index cdbc64674..5b518b131 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -30,7 +30,6 @@ package gocql import ( "bytes" "encoding/binary" - "fmt" "math" "math/big" "net" @@ -59,56 +58,56 @@ var marshalTests = []struct { UnmarshalError error }{ { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("hello world"), []byte("hello world"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("hello world"), "hello world", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte(nil), []byte(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("hello world"), MyString("hello world"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("HELLO WORLD"), CustomString("hello world"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBlob}, + varcharLikeTypeInfo{typ: TypeBlob}, []byte("hello\x00"), []byte("hello\x00"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBlob}, + varcharLikeTypeInfo{typ: TypeBlob}, []byte(nil), []byte(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimeUUID}, + timeUUIDType{}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, func() UUID { x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}) @@ -118,287 +117,287 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimeUUID}, + timeUUIDType{}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, MarshalError("can not marshal []byte 6 bytes long into timeuuid, must be exactly 16 bytes long"), - UnmarshalError("unable to parse UUID: UUIDs must be exactly 16 bytes long"), + UnmarshalError("unable to parse timeuuid: UUIDs must be exactly 16 bytes long"), }, { - NativeType{proto: protoVersion3, typ: TypeTimeUUID}, + timeUUIDType{}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, [16]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x00\x00\x00"), 0, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x01\x02\x03\x04"), int(16909060), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x01\x02\x03\x04"), AliasInt(16909060), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x80\x00\x00\x00"), int32(math.MinInt32), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x7f\xff\xff\xff"), int32(math.MaxInt32), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x00\x00\x00"), "0", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x01\x02\x03\x04"), "16909060", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x80\x00\x00\x00"), "-2147483648", // math.MinInt32 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x7f\xff\xff\xff"), "2147483647", // math.MaxInt32 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), 0, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), 72623859790382856, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), int64(math.MinInt64), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), int64(math.MaxInt64), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), "0", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), "72623859790382856", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), "-9223372036854775808", // math.MinInt64 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), "9223372036854775807", // math.MaxInt64 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBoolean}, + booleanTypeInfo{}, []byte("\x00"), false, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBoolean}, + booleanTypeInfo{}, []byte("\x01"), true, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeFloat}, + floatTypeInfo{}, []byte("\x40\x49\x0f\xdb"), float32(3.14159265), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDouble}, + doubleTypeInfo{}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), float64(3.14159265), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x00\x00"), inf.NewDec(0, 0), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x00\x64"), inf.NewDec(100, 0), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x02\x19"), decimalize("0.25"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"), decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"), decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"), decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"), decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x06\xe5\xde]\x98Y"), decimalize("-112233.441191"), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x14\x00\xfa\xce"), decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\x00\x00\x00\x14\xff\x052"), decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\xff\xff\xff\x9c\x00\xfa\xce"), inf.NewDec(64206, -100), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion4, typ: TypeTime}, + timeTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Duration(int64(1376387523000)), nil, nil, }, { - NativeType{proto: protoVersion4, typ: TypeTime}, + timeTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), nil, nil, }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"), Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323}, nil, nil, }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa1\xc3\xc2\x99\xe0F\x91\x05"), Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323}, nil, nil, }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x02\x04\x80\xe6"), Duration{Months: 1, Days: 2, Nanoseconds: 115}, nil, @@ -406,8 +405,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), []int{1, 2}, @@ -416,8 +415,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), [2]int{1, 2}, @@ -426,8 +425,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeSet}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeSet, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), []int{1, 2}, @@ -436,8 +435,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeSet}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeSet, + Elem: intTypeInfo{}, }, []byte{0, 0, 0, 0}, // encoding of a list should always include the size of the collection []int{}, @@ -446,9 +445,9 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x00\x00\x04\x00\x00\x00\x01"), map[string]int{"foo": 1}, @@ -457,9 +456,9 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte{0, 0, 0, 0}, map[string]int{}, @@ -468,8 +467,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeVarchar}, + typ: TypeList, + Elem: varcharLikeTypeInfo{typ: TypeVarchar}, }, bytes.Join([][]byte{ []byte("\x00\x00\x00\x01\x00\x00\xff\xff"), @@ -480,9 +479,9 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeVarchar}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: varcharLikeTypeInfo{typ: TypeVarchar}, }, bytes.Join([][]byte{ []byte("\x00\x00\x00\x01\x00\x00\xff\xff"), @@ -496,119 +495,119 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("\x00"), 0, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("\x37\xE2\x3C\xEC"), int32(937573612), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("\x37\xE2\x3C\xEC"), big.NewInt(937573612), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["), bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("\xC9v\x8D:\x86"), big.NewInt(-234234234234), // From the iconara/cql-rb test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"), bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarint}, + varintTypeInfo{}, []byte(nil), nil, nil, UnmarshalError("can not unmarshal into non-pointer "), }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\x7F\x00\x00\x01"), net.ParseIP("127.0.0.1").To4(), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\xFF\xFF\xFF\xFF"), net.ParseIP("255.255.255.255").To4(), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\x7F\x00\x00\x01"), "127.0.0.1", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\xFF\xFF\xFF\xFF"), "255.255.255.255", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"), "21da:d3:0:2f3b:2aa:ff:fe28:9c5a", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"), "fe80::202:b3ff:fe1e:8329", nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"), net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"), net.ParseIP("fe80::202:b3ff:fe1e:8329"), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte(nil), nil, nil, UnmarshalError("can not unmarshal into non-pointer "), }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("nullable string"), func() *string { value := "nullable string" @@ -618,14 +617,14 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte(nil), (*string)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x7f\xff\xff\xff"), func() *int { var value int = math.MaxInt32 @@ -635,28 +634,28 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte(nil), (*int)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimeUUID}, + timeUUIDType{}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, &UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimeUUID}, + timeUUIDType{}, []byte(nil), (*UUID)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), func() *time.Time { t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC) @@ -666,14 +665,14 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte(nil), (*time.Time)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBoolean}, + booleanTypeInfo{}, []byte("\x00"), func() *bool { b := false @@ -683,7 +682,7 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeBoolean}, + booleanTypeInfo{}, []byte("\x01"), func() *bool { b := true @@ -693,14 +692,14 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeBoolean}, + booleanTypeInfo{}, []byte(nil), (*bool)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeFloat}, + floatTypeInfo{}, []byte("\x40\x49\x0f\xdb"), func() *float32 { f := float32(3.14159265) @@ -710,14 +709,14 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeFloat}, + floatTypeInfo{}, []byte(nil), (*float32)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeDouble}, + doubleTypeInfo{}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), func() *float64 { d := float64(3.14159265) @@ -727,14 +726,14 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeDouble}, + doubleTypeInfo{}, []byte(nil), (*float64)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte("\x7F\x00\x00\x01"), func() *net.IP { ip := net.ParseIP("127.0.0.1").To4() @@ -744,7 +743,7 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeInet}, + inetType{}, []byte(nil), (*net.IP)(nil), nil, @@ -752,8 +751,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), func() *[]int { @@ -765,21 +764,8 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, - }, - []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), - func() *[]int { - l := []int{1, 2} - return &l - }(), - nil, - nil, - }, - { - CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, []byte(nil), (*[]int)(nil), @@ -788,9 +774,9 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x00\x00\x04\x00\x00\x00\x01"), func() *map[string]int { @@ -802,9 +788,9 @@ var marshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte(nil), (*map[string]int)(nil), @@ -812,7 +798,7 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte("HELLO WORLD"), func() *CustomString { customString := CustomString("hello world") @@ -822,252 +808,252 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte(nil), (*CustomString)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x7f\xff"), 32767, // math.MaxInt16 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x7f\xff"), "32767", // math.MaxInt16 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x00\x01"), int16(1), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), int16(-1), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x00\xff"), uint8(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), uint16(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), uint32(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), uint64(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x00\xff"), AliasUint8(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), AliasUint16(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), AliasUint32(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), AliasUint64(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), AliasUint(65535), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\x7f"), 127, // math.MaxInt8 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\x7f"), "127", // math.MaxInt8 nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\x01"), int16(1), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), int16(-1), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), uint8(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), uint64(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), uint32(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), uint16(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), uint(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), AliasUint8(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), AliasUint64(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), AliasUint32(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), AliasUint16(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeTinyInt}, + tinyIntTypeInfo{}, []byte("\xff"), AliasUint(255), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\xff"), uint8(math.MaxUint8), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\xff\xff"), uint64(math.MaxUint16), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\xff\xff\xff\xff"), uint64(math.MaxUint32), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint64(math.MaxUint64), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), uint32(math.MaxUint32), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), uint64(math.MaxUint32), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeBlob}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte(nil), ([]byte)(nil), nil, nil, }, { - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, []byte{}, func() interface{} { var s string @@ -1077,7 +1063,7 @@ var marshalTests = []struct { nil, }, { - NativeType{proto: protoVersion3, typ: TypeTime}, + timeTypeInfo{}, encBigInt(1000), time.Duration(1000), nil, @@ -1092,177 +1078,177 @@ var unmarshalTests = []struct { UnmarshalError error }{ { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), uint16(0), UnmarshalError("unmarshal int: value -1 out of range for uint16"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x01\x00\x00"), uint16(0), UnmarshalError("unmarshal int: value 65536 out of range for uint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint16(0), UnmarshalError("unmarshal int: value -1 out of range for uint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), uint16(0), UnmarshalError("unmarshal int: value 65536 out of range for uint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint32(0), UnmarshalError("unmarshal int: value -1 out of range for uint32"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), uint32(0), UnmarshalError("unmarshal int: value 4294967296 out of range for uint32"), }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeSmallInt}, + smallIntTypeInfo{}, []byte("\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\xff\xff\xff\xff"), AliasUint16(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), }, { - NativeType{proto: protoVersion3, typ: TypeInt}, + intTypeInfo{}, []byte("\x00\x01\x00\x00"), AliasUint16(0), UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint16(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), AliasUint16(0), UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint32(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint32"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), AliasUint32(0), UnmarshalError("unmarshal int: value 4294967296 out of range for gocql.AliasUint32"), }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeList, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00"), // truncated data func() *[]int { @@ -1273,9 +1259,9 @@ var unmarshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x01\x00\x00\x00\x03fo"), map[string]int{"foo": 1}, @@ -1283,50 +1269,59 @@ var unmarshalTests = []struct { }, { CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeMap}, - Key: NativeType{proto: protoVersion3, typ: TypeVarchar}, - Elem: NativeType{proto: protoVersion3, typ: TypeInt}, + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, }, []byte("\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x04\x00\x00"), map[string]int{"foo": 1}, UnmarshalError("unmarshal map: unexpected eof"), }, { - NativeType{proto: protoVersion3, typ: TypeDecimal}, + decimalTypeInfo{}, []byte("\xff\xff\xff"), inf.NewDec(0, 0), // From the datastax/python-driver test suite UnmarshalError("inf.Dec needs at least 4 bytes, while value has only 3"), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: data expect to have 9 bytes, but it has only 8"), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa2\xc3\xc2\x9a"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: unexpected eof"), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa2\xc3\xc2"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: data expect to have 5 bytes, but it has only 4"), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89\xa2"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: unexpected eof"), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, []byte("\x89"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract month: data expect to have 2 bytes, but it has only 1"), }, + { + varcharLikeTypeInfo{typ: TypeVarchar}, + []byte("HELLO WORLD"), + func() *CustomString { + s := CustomString("hello world") + return &s + }(), + nil, + }, } func decimalize(s string) *inf.Dec { @@ -1364,15 +1359,15 @@ func TestMarshal_Decode(t *testing.T) { v := reflect.New(reflect.TypeOf(test.Value)) err := Unmarshal(test.Info, test.Data, v.Interface()) if err != nil { - t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err) + t.Errorf("marshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err) continue } if !reflect.DeepEqual(v.Elem().Interface(), test.Value) { - t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface()) + t.Errorf("marshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface()) } } else { if err := Unmarshal(test.Info, test.Data, test.Value); err != test.UnmarshalError { - t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) + t.Errorf("marshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) } } } @@ -1459,7 +1454,7 @@ func TestMarshalVarint(t *testing.T) { } for i, test := range varintTests { - data, err := Marshal(NativeType{proto: protoVersion3, typ: TypeVarint}, test.Value) + data, err := Marshal(varintTypeInfo{}, test.Value) if err != nil { t.Errorf("error marshaling varint: %v (test #%d)", err, i) } @@ -1469,7 +1464,7 @@ func TestMarshalVarint(t *testing.T) { } binder := new(big.Int) - err = Unmarshal(NativeType{proto: protoVersion3, typ: TypeVarint}, test.Marshaled, binder) + err = Unmarshal(varintTypeInfo{}, test.Marshaled, binder) if err != nil { t.Errorf("error unmarshaling varint: %v (test #%d)", err, i) } @@ -1517,7 +1512,7 @@ func TestMarshalVarint(t *testing.T) { } for i, test := range varintUint64Tests { - data, err := Marshal(NativeType{proto: protoVersion3, typ: TypeVarint}, test.Value) + data, err := Marshal(varintTypeInfo{}, test.Value) if err != nil { t.Errorf("error marshaling varint: %v (test #%d)", err, i) } @@ -1527,7 +1522,7 @@ func TestMarshalVarint(t *testing.T) { } var binder uint64 - err = Unmarshal(NativeType{proto: protoVersion3, typ: TypeVarint}, test.Marshaled, &binder) + err = Unmarshal(varintTypeInfo{}, test.Marshaled, &binder) if err != nil { t.Errorf("error unmarshaling varint to uint64: %v (test #%d)", err, i) } @@ -1545,12 +1540,12 @@ func TestMarshalBigInt(t *testing.T) { MarshalError error }{ { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, "-78635384813432117863538481343211", MarshalError("can not marshal string to bigint: strconv.ParseInt: parsing \"-78635384813432117863538481343211\": value out of range"), }, { - NativeType{proto: protoVersion3, typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeBigInt}, "922337203685477692259749625974294", MarshalError("can not marshal string to bigint: strconv.ParseInt: parsing \"922337203685477692259749625974294\": value out of range"), }, @@ -1577,9 +1572,9 @@ func equalStringPointerSlice(leftList, rightList []*string) bool { } func TestMarshalList(t *testing.T) { - typeInfoV3 := CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeVarchar}, + typeInfo := CollectionType{ + typ: TypeList, + Elem: varcharLikeTypeInfo{typ: TypeVarchar}, } type tc struct { @@ -1593,17 +1588,17 @@ func TestMarshalList(t *testing.T) { valueEmpty := "" testCases := []tc{ { - typeInfo: typeInfoV3, + typeInfo: typeInfo, input: []*string{&valueEmpty}, expected: []*string{&valueEmpty}, }, { - typeInfo: typeInfoV3, + typeInfo: typeInfo, input: []*string{nil}, expected: []*string{nil}, }, { - typeInfo: typeInfoV3, + typeInfo: typeInfo, input: []*string{&valueA, nil, &valueB}, expected: []*string{&valueA, nil, &valueB}, }, @@ -1657,46 +1652,6 @@ func (c *CustomString) UnmarshalCQL(info TypeInfo, data []byte) error { type MyString string -var typeLookupTest = []struct { - TypeName string - ExpectedType Type -}{ - {"AsciiType", TypeAscii}, - {"LongType", TypeBigInt}, - {"BytesType", TypeBlob}, - {"BooleanType", TypeBoolean}, - {"CounterColumnType", TypeCounter}, - {"DecimalType", TypeDecimal}, - {"DoubleType", TypeDouble}, - {"FloatType", TypeFloat}, - {"Int32Type", TypeInt}, - {"DateType", TypeTimestamp}, - {"TimestampType", TypeTimestamp}, - {"UUIDType", TypeUUID}, - {"UTF8Type", TypeVarchar}, - {"IntegerType", TypeVarint}, - {"TimeUUIDType", TypeTimeUUID}, - {"InetAddressType", TypeInet}, - {"MapType", TypeMap}, - {"ListType", TypeList}, - {"SetType", TypeSet}, - {"unknown", TypeCustom}, - {"ShortType", TypeSmallInt}, - {"ByteType", TypeTinyInt}, -} - -func testType(t *testing.T, cassType string, expectedType Type) { - if computedType := getApacheCassandraType(apacheCassandraTypePrefix + cassType); computedType != expectedType { - t.Errorf("Cassandra custom type lookup for %s failed. Expected %s, got %s.", cassType, expectedType.String(), computedType.String()) - } -} - -func TestLookupCassType(t *testing.T) { - for _, lookupTest := range typeLookupTest { - testType(t, lookupTest.TypeName, lookupTest.ExpectedType) - } -} - type MyPointerMarshaler struct{} func (m *MyPointerMarshaler) MarshalCQL(_ TypeInfo) ([]byte, error) { @@ -1705,7 +1660,7 @@ func (m *MyPointerMarshaler) MarshalCQL(_ TypeInfo) ([]byte, error) { func TestMarshalPointer(t *testing.T) { m := &MyPointerMarshaler{} - typ := NativeType{proto: protoVersion3, typ: TypeInt} + typ := intTypeInfo{} data, err := Marshal(typ, m) @@ -1727,24 +1682,23 @@ func TestMarshalTime(t *testing.T) { Value interface{} }{ { - NativeType{proto: protoVersion4, typ: TypeTime}, + timeTypeInfo{}, expectedData, duration.Nanoseconds(), }, { - NativeType{proto: protoVersion4, typ: TypeTime}, + timeTypeInfo{}, expectedData, duration, }, { - NativeType{proto: protoVersion4, typ: TypeTime}, + timeTypeInfo{}, expectedData, &duration, }, } for i, test := range marshalTimeTests { - t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) @@ -1764,53 +1718,52 @@ func TestMarshalTimestamp(t *testing.T) { Value interface{} }{ { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), }, { - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), }, { // 9223372036854 is the maximum time representable in ms since the epoch // with int64 if using UnixNano to convert - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), time.Date(2262, time.April, 11, 23, 47, 16, 854775807, time.UTC), }, { // One nanosecond after causes overflow when using UnixNano // Instead it should resolve to the same time in ms - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), time.Date(2262, time.April, 11, 23, 47, 16, 854775808, time.UTC), }, { // -9223372036855 is the minimum time representable in ms since the epoch // with int64 if using UnixNano to convert - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), time.Date(1677, time.September, 21, 00, 12, 43, 145224192, time.UTC), }, { // One nanosecond earlier causes overflow when using UnixNano // it should resolve to the same time in ms - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), time.Date(1677, time.September, 21, 00, 12, 43, 145224191, time.UTC), }, { // Store the zero time as a blank slice - NativeType{proto: protoVersion3, typ: TypeTimestamp}, + timestampTypeInfo{}, []byte{}, time.Time{}, }, } for i, test := range marshalTimestampTests { - t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) @@ -1825,10 +1778,9 @@ func TestMarshalTimestamp(t *testing.T) { func TestMarshalTuple(t *testing.T) { info := TupleTypeInfo{ - NativeType: NativeType{proto: protoVersion3, typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{proto: protoVersion3, typ: TypeVarchar}, - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, }, } @@ -1972,10 +1924,9 @@ func TestMarshalTuple(t *testing.T) { func TestUnmarshalTuple(t *testing.T) { info := TupleTypeInfo{ - NativeType: NativeType{proto: protoVersion3, typ: TypeTuple}, Elems: []TypeInfo{ - NativeType{proto: protoVersion3, typ: TypeVarchar}, - NativeType{proto: protoVersion3, typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, }, } @@ -1995,11 +1946,23 @@ func TestUnmarshalTuple(t *testing.T) { t.Errorf("unmarshalTest: %v", err) return } + if tmp.A != nil || *tmp.B != "foo" { + t.Errorf("unmarshalTest: expected [nil, foo], got [%#v, %#v]", *tmp.A, *tmp.B) + } + + tmp.A = new(string) + *tmp.A = "bar" + err = Unmarshal(info, data, &tmp) + if err != nil { + t.Errorf("unmarshalTest: %v", err) + return + } if tmp.A != nil || *tmp.B != "foo" { - t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", *tmp.A, *tmp.B) + t.Errorf("unmarshalTest: expected [nil, foo], got [%#v, %#v]", *tmp.A, *tmp.B) } }) + t.Run("struct-nonptr", func(t *testing.T) { var tmp struct { A string @@ -2011,7 +1974,17 @@ func TestUnmarshalTuple(t *testing.T) { t.Errorf("unmarshalTest: %v", err) return } + if tmp.A != "" || tmp.B != "foo" { + t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp.A, tmp.B) + } + tmp.A = "bar" + + err = Unmarshal(info, data, &tmp) + if err != nil { + t.Errorf("unmarshalTest: %v", err) + return + } if tmp.A != "" || tmp.B != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp.A, tmp.B) } @@ -2025,11 +1998,23 @@ func TestUnmarshalTuple(t *testing.T) { t.Errorf("unmarshalTest: %v", err) return } + if tmp[0] != nil || *tmp[1] != "foo" { + t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", *tmp[0], *tmp[1]) + } + + tmp[0] = new(string) + *tmp[0] = "bar" + err = Unmarshal(info, data, &tmp) + if err != nil { + t.Errorf("unmarshalTest: %v", err) + return + } if tmp[0] != nil || *tmp[1] != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", *tmp[0], *tmp[1]) } }) + t.Run("array-nonptr", func(t *testing.T) { var tmp [2]string @@ -2038,7 +2023,17 @@ func TestUnmarshalTuple(t *testing.T) { t.Errorf("unmarshalTest: %v", err) return } + if tmp[0] != "" || tmp[1] != "foo" { + t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp[0], tmp[1]) + } + + tmp[0] = "bar" + err = Unmarshal(info, data, &tmp) + if err != nil { + t.Errorf("unmarshalTest: %v", err) + return + } if tmp[0] != "" || tmp[1] != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp[0], tmp[1]) } @@ -2046,11 +2041,14 @@ func TestUnmarshalTuple(t *testing.T) { } func TestMarshalUDTMap(t *testing.T) { - typeInfo := UDTTypeInfo{NativeType{proto: protoVersion3, typ: TypeUDT}, "", "xyz", []UDTField{ - {Name: "x", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - {Name: "y", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - {Name: "z", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - }} + typeInfo := UDTTypeInfo{ + Name: "xyz", + Elements: []UDTField{ + {Name: "x", Type: intTypeInfo{}}, + {Name: "y", Type: intTypeInfo{}}, + {Name: "z", Type: intTypeInfo{}}, + }, + } t.Run("partially bound", func(t *testing.T) { value := map[string]interface{}{ @@ -2101,11 +2099,14 @@ func TestMarshalUDTMap(t *testing.T) { } func TestMarshalUDTStruct(t *testing.T) { - typeInfo := UDTTypeInfo{NativeType{proto: protoVersion3, typ: TypeUDT}, "", "xyz", []UDTField{ - {Name: "x", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - {Name: "y", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - {Name: "z", Type: NativeType{proto: protoVersion3, typ: TypeInt}}, - }} + typeInfo := UDTTypeInfo{ + Name: "xyz", + Elements: []UDTField{ + {Name: "x", Type: intTypeInfo{}}, + {Name: "y", Type: intTypeInfo{}}, + {Name: "z", Type: intTypeInfo{}}, + }, + } type xyzStruct struct { X int32 `cql:"x"` @@ -2170,26 +2171,26 @@ func TestMarshalUDTStruct(t *testing.T) { } func TestMarshalNil(t *testing.T) { - types := []Type{ - TypeAscii, - TypeBlob, - TypeBoolean, - TypeBigInt, - TypeCounter, - TypeDecimal, - TypeDouble, - TypeFloat, - TypeInt, - TypeTimestamp, - TypeUUID, - TypeVarchar, - TypeVarint, - TypeTimeUUID, - TypeInet, + types := []TypeInfo{ + varcharLikeTypeInfo{typ: TypeAscii}, + varcharLikeTypeInfo{typ: TypeBlob}, + booleanTypeInfo{}, + bigIntLikeTypeInfo{typ: TypeBigInt}, + bigIntLikeTypeInfo{typ: TypeCounter}, + decimalTypeInfo{}, + doubleTypeInfo{}, + floatTypeInfo{}, + intTypeInfo{}, + timestampTypeInfo{}, + uuidType{}, + varcharLikeTypeInfo{typ: TypeVarchar}, + varintTypeInfo{}, + timeUUIDType{}, + inetType{}, } for _, typ := range types { - data, err := Marshal(NativeType{proto: protoVersion3, typ: typ}, nil) + data, err := Marshal(typ, nil) if err != nil { t.Errorf("unable to marshal nil %v: %v\n", typ, err) } else if data != nil { @@ -2198,10 +2199,20 @@ func TestMarshalNil(t *testing.T) { } } -func TestUnmarshalInetCopyBytes(t *testing.T) { +func TestUnmarshalInet_Nil(t *testing.T) { + var ip net.IP + if err := Unmarshal(inetType{}, []byte(nil), &ip); err != nil { + t.Fatal(err) + } + if ip != nil { + t.Fatalf("expected nil ip, got %v", ip) + } +} + +func TestUnmarshalInet_CopyBytes(t *testing.T) { data := []byte{127, 0, 0, 1} var ip net.IP - if err := unmarshalInet(NativeType{proto: protoVersion3, typ: TypeInet}, data, &ip); err != nil { + if err := Unmarshal(inetType{}, data, &ip); err != nil { t.Fatal(err) } @@ -2215,7 +2226,7 @@ func TestUnmarshalInetCopyBytes(t *testing.T) { func TestUnmarshalDate(t *testing.T) { data := []uint8{0x80, 0x0, 0x43, 0x31} var date time.Time - if err := unmarshalDate(NativeType{proto: protoVersion3, typ: TypeDate}, data, &date); err != nil { + if err := Unmarshal(dateTypeInfo{}, data, &date); err != nil { t.Fatal(err) } @@ -2226,7 +2237,7 @@ func TestUnmarshalDate(t *testing.T) { return } var stringDate string - if err2 := unmarshalDate(NativeType{proto: protoVersion3, typ: TypeDate}, data, &stringDate); err2 != nil { + if err2 := Unmarshal(dateTypeInfo{}, data, &stringDate); err2 != nil { t.Fatal(err2) } if expectedDate != stringDate { @@ -2246,29 +2257,28 @@ func TestMarshalDate(t *testing.T) { Value interface{} }{ { - NativeType{proto: protoVersion4, typ: TypeDate}, + dateTypeInfo{}, expectedData, timestamp, }, { - NativeType{proto: protoVersion4, typ: TypeDate}, + dateTypeInfo{}, expectedData, now, }, { - NativeType{proto: protoVersion4, typ: TypeDate}, + dateTypeInfo{}, expectedData, &now, }, { - NativeType{proto: protoVersion4, typ: TypeDate}, + dateTypeInfo{}, expectedData, now.Format("2006-01-02"), }, } for i, test := range marshalDateTests { - t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) @@ -2305,12 +2315,10 @@ func TestLargeDate(t *testing.T) { }, } - nativeType := NativeType{proto: protoVersion4, typ: TypeDate} + typ := dateTypeInfo{} for i, test := range marshalDateTests { - t.Log(i, test) - - data, err := Marshal(nativeType, test.Value) + data, err := Marshal(typ, test.Value) if err != nil { t.Errorf("largeDateTest[%d]: %v", i, err) continue @@ -2321,7 +2329,7 @@ func TestLargeDate(t *testing.T) { } var date time.Time - if err := Unmarshal(nativeType, data, &date); err != nil { + if err := Unmarshal(typ, data, &date); err != nil { t.Fatal(err) } @@ -2332,19 +2340,6 @@ func TestLargeDate(t *testing.T) { } } -func BenchmarkUnmarshalVarchar(b *testing.B) { - b.ReportAllocs() - src := make([]byte, 1024) - dst := make([]byte, len(src)) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := unmarshalVarchar(NativeType{}, src, &dst); err != nil { - b.Fatal(err) - } - } -} - func TestMarshalDuration(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) @@ -2355,22 +2350,22 @@ func TestMarshalDuration(t *testing.T) { Value interface{} }{ { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, expectedData, duration.Nanoseconds(), }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, expectedData, duration, }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, expectedData, durationS, }, { - NativeType{proto: protoVersion5, typ: TypeDuration}, + durationTypeInfo{}, expectedData, &duration, }, @@ -2391,9 +2386,9 @@ func TestMarshalDuration(t *testing.T) { } func TestReadCollectionSize(t *testing.T) { - listV3 := CollectionType{ - NativeType: NativeType{proto: protoVersion3, typ: TypeList}, - Elem: NativeType{proto: protoVersion3, typ: TypeVarchar}, + list := CollectionType{ + typ: TypeList, + Elem: varcharLikeTypeInfo{typ: TypeVarchar}, } tests := []struct { @@ -2405,31 +2400,31 @@ func TestReadCollectionSize(t *testing.T) { }{ { name: "short read 0 proto 3", - info: listV3, + info: list, data: []byte{}, isError: true, }, { name: "short read 1 proto 3", - info: listV3, + info: list, data: []byte{0x01}, isError: true, }, { name: "short read 2 proto 3", - info: listV3, + info: list, data: []byte{0x01, 0x38}, isError: true, }, { name: "short read 3 proto 3", - info: listV3, + info: list, data: []byte{0x01, 0x38, 0x42}, isError: true, }, { name: "good read proto 3", - info: listV3, + info: list, data: []byte{0x01, 0x38, 0x42, 0x22}, expectedSize: 0x01384222, }, @@ -2453,70 +2448,24 @@ func TestReadCollectionSize(t *testing.T) { } } -func TestReadUnsignedVInt(t *testing.T) { - tests := []struct { - decodedInt uint64 - encodedVint []byte - }{ - { - decodedInt: 0, - encodedVint: []byte{0}, - }, - { - decodedInt: 100, - encodedVint: []byte{100}, - }, - { - decodedInt: 256000, - encodedVint: []byte{195, 232, 0}, - }, - } - for _, test := range tests { - t.Run(fmt.Sprintf("%d", test.decodedInt), func(t *testing.T) { - actual, _, err := readUnsignedVInt(test.encodedVint) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if actual != test.decodedInt { - t.Fatalf("Expected %d, but got %d", test.decodedInt, actual) - } - }) - } -} - -func BenchmarkUnmarshalUUID(b *testing.B) { - b.ReportAllocs() - src := make([]byte, 16) - dst := UUID{} - var ti TypeInfo = NativeType{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := unmarshalUUID(ti, src, &dst); err != nil { - b.Fatal(err) - } - } -} - func TestUnmarshalUDT(t *testing.T) { info := UDTTypeInfo{ - NativeType: NativeType{proto: protoVersion4, typ: TypeUDT}, - Name: "myudt", - KeySpace: "myks", + Name: "myudt", + Keyspace: "myks", Elements: []UDTField{ { Name: "first", - Type: NativeType{proto: protoVersion4, typ: TypeAscii}, + Type: varcharLikeTypeInfo{typ: TypeAscii}, }, { Name: "second", - Type: NativeType{proto: protoVersion4, typ: TypeSmallInt}, + Type: smallIntTypeInfo{}, }, }, } - data := bytesWithLength( // UDT - bytesWithLength([]byte("Hello")), // first - bytesWithLength([]byte("\x00\x2a")), // second + data := append( + bytesWithLength([]byte("Hello")), // first + bytesWithLength([]byte("\x00\x2a"))..., // second ) value := map[string]interface{}{} expectedErr := UnmarshalError("can not unmarshal into non-pointer map[string]interface {}") @@ -2525,6 +2474,18 @@ func TestUnmarshalUDT(t *testing.T) { t.Errorf("(%v=>%T): %#v returned error %#v, want %#v.", info, value, value, err, expectedErr) } + + err := Unmarshal(info, data, &value) + if err != nil { + t.Errorf("unexpected error %v", err) + } else { + if value["first"] != "Hello" { + t.Errorf(`Expected "Hello" for first but received: %T(%v)`, value["first"], value["first"]) + } + if value["second"] != int16(42) { + t.Errorf(`Expected 42 for second but received: %T(%v)`, value["second"], value["second"]) + } + } } // bytesWithLength concatenates all data slices and prepends the total length as uint32. @@ -2546,3 +2507,361 @@ func bytesWithLength(data ...[]byte) []byte { } return ret } + +func TestUnmarshal_PointerToPointer(t *testing.T) { + var a string + b := &a + data := []byte("foo") + info := varcharLikeTypeInfo{ + typ: TypeVarchar, + } + err := Unmarshal(info, data, &b) + if err != nil { + t.Error(err) + } else { + if b == nil || *b != "foo" { + t.Errorf("expected b to be *foo, got %+v", b) + } + if a != "" { + t.Errorf("expected a to be empty, got %v", a) + } + } +} + +func TestUnmarshal_PointerToInterface(t *testing.T) { + var a string + var b interface{} = &a + data := []byte("foo") + info := varcharLikeTypeInfo{ + typ: TypeVarchar, + } + err := Unmarshal(info, data, &b) + if err != nil { + t.Error(err) + } else { + if b == nil { + t.Error("expected b to be *foo, got nil") + } else if bstr, ok := b.(*string); !ok { + t.Errorf("expected b to be *foo, got %T", b) + } else if bstr == nil || *bstr != "foo" { + t.Errorf("expected b to be *foo, got %+v", bstr) + } + if a != "foo" { + t.Errorf("expected a to be foo, got %v", a) + } + } +} + +func BenchmarkUnmarshal_BigInt(b *testing.B) { + b.ReportAllocs() + src := []byte("\x01\x02\x03\x04\x05\x06\x07\x08") + var dst int64 + var ti TypeInfo = GlobalTypes.fastTypeInfoLookup(TypeBigInt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Blob(b *testing.B) { + b.ReportAllocs() + src := []byte("hello\x00") + var dst []byte + var ti TypeInfo = GlobalTypes.fastTypeInfoLookup(TypeBlob) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Boolean(b *testing.B) { + b.ReportAllocs() + src := []byte("\x01") + var dst bool + var ti TypeInfo = GlobalTypes.fastTypeInfoLookup(TypeBoolean) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Date(b *testing.B) { + b.ReportAllocs() + src := []byte("\x80\x00\x43\x31") + var dst time.Time + var ti TypeInfo = GlobalTypes.fastTypeInfoLookup(TypeDate) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Decimal(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o") + dst := new(inf.Dec) + var ti TypeInfo = NewNativeType(4, TypeDecimal, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Double(b *testing.B) { + b.ReportAllocs() + src := []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1") + var dst float64 + var ti TypeInfo = NewNativeType(4, TypeDouble, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Duration(b *testing.B) { + b.ReportAllocs() + src := []byte("\x02\x04\x80\xe6") + var dst Duration + var ti TypeInfo = NewNativeType(4, TypeDuration, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Float(b *testing.B) { + b.ReportAllocs() + src := []byte("\x40\x49\x0f\xdb") + var dst float32 + var ti TypeInfo = NewNativeType(4, TypeFloat, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Int(b *testing.B) { + b.ReportAllocs() + src := []byte("\x01\x02\x03\x04") + var dst int32 + var ti TypeInfo = NewNativeType(4, TypeInt, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Inet(b *testing.B) { + b.ReportAllocs() + src := []byte("\x7F\x00\x00\x01") + var dst net.IP + var ti TypeInfo = NewNativeType(4, TypeInet, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_SmallInt(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\xff") + var dst int16 + var ti TypeInfo = NewNativeType(4, TypeSmallInt, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Time(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8") + var dst time.Duration + var ti TypeInfo = NewNativeType(4, TypeTime, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Timestamp(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8") + var dst int64 + var ti TypeInfo = NewNativeType(4, TypeTimestamp, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_TinyInt(b *testing.B) { + b.ReportAllocs() + src := []byte("\x01") + var dst int8 + var ti TypeInfo = NewNativeType(4, TypeTinyInt, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_UUID(b *testing.B) { + b.ReportAllocs() + src := make([]byte, 16) + dst := UUID{} + var ti TypeInfo = NewNativeType(4, TypeUUID, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Varchar(b *testing.B) { + b.ReportAllocs() + src := []byte("foo") + dst := make([]byte, len(src)) + var ti TypeInfo = NewNativeType(4, TypeVarchar, "") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_List(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02") + dst := make([]int32, 2) + var ti TypeInfo = CollectionType{ + typ: TypeList, + Elem: intTypeInfo{}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Set(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02") + dst := make([]int32, 2) + var ti TypeInfo = CollectionType{ + typ: TypeSet, + Elem: intTypeInfo{}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_Map(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x00\x00\x04\x00\x00\x00\x01") + dst := map[string]int32{} + var ti TypeInfo = CollectionType{ + typ: TypeMap, + Key: varcharLikeTypeInfo{typ: TypeVarchar}, + Elem: intTypeInfo{}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_TupleStrings(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x03foo\x00\x00\x00\x03bar") + dst := make([]string, 2) + var ti TypeInfo = TupleTypeInfo{ + Elems: []TypeInfo{ + varcharLikeTypeInfo{typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshal_TupleInterfaces(b *testing.B) { + b.ReportAllocs() + src := []byte("\x00\x00\x00\x03foo\x00\x00\x00\x03bar") + dst := make([]interface{}, 2) + var ti TypeInfo = TupleTypeInfo{ + Elems: []TypeInfo{ + varcharLikeTypeInfo{typ: TypeVarchar}, + varcharLikeTypeInfo{typ: TypeVarchar}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Unmarshal(ti, src, &dst); err != nil { + b.Fatal(err) + } + } +} diff --git a/metadata.go b/metadata.go index c3e6ad33d..a8996f0da 100644 --- a/metadata.go +++ b/metadata.go @@ -304,8 +304,8 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error { } // organize the schema data - compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates, userTypes, - materializedViews, s.session.logger) + compileMetadata(s.session, keyspace, tables, columns, functions, aggregates, userTypes, + materializedViews) // update the cache s.cache[keyspaceName] = keyspace @@ -319,7 +319,7 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error { // Links the metadata objects together and derives the column composition of // the partition key and clustering key for a table. func compileMetadata( - protoVersion int, + session *Session, keyspace *KeyspaceMetadata, tables []TableMetadata, columns []ColumnMetadata, @@ -327,8 +327,7 @@ func compileMetadata( aggregates []AggregateMetadata, uTypes []UserTypeMetadata, materializedViews []MaterializedViewMetadata, - logger StdLogger, -) { +) error { keyspace.Tables = make(map[string]*TableMetadata) for i := range tables { tables[i].Columns = make(map[string]*ColumnMetadata) @@ -356,17 +355,26 @@ func compileMetadata( } // add columns from the schema data + var err error for i := range columns { col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ - col.Type = getCassandraType(col.Validator, byte(protoVersion), logger) + col.Type, err = session.types.typeInfoFromString(session.cfg.ProtoVersion, col.Validator) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + col.Type = unknownTypeInfo(col.Validator) + } col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC } } else { - validatorParsed := parseType(col.Validator, byte(protoVersion), logger) + validatorParsed, err := parseType(session, col.Validator) + if err != nil { + return err + } col.Type = validatorParsed.types[0] col.Order = ASC if validatorParsed.reversed[0] { @@ -387,11 +395,11 @@ func compileMetadata( table.OrderedColumns = append(table.OrderedColumns, col.Name) } - compileV2Metadata(tables, protoVersion, logger) + return compileV2Metadata(tables, session) } // The simpler compile case for V2+ protocol -func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { +func compileV2Metadata(tables []TableMetadata, session *Session) error { for i := range tables { table := &tables[i] @@ -399,7 +407,10 @@ func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) if table.KeyValidator != "" { - keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) + keyValidatorParsed, err := parseType(session, table.KeyValidator) + if err != nil { + return err + } table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) } else { // Cassandra 3.x+ partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey) @@ -415,6 +426,7 @@ func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { } } } + return nil } // returns the count of coluns with the given "kind" value. @@ -753,13 +765,6 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, return columns, nil } -func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo { - if strings.HasPrefix(t, apacheCassandraTypePrefix) { - return getCassandraLongType(t, protoVer, logger) - } - return getCassandraType(t, protoVer, logger) -} - func getUserTypeMetadata(session *Session, keyspaceName string) ([]UserTypeMetadata, error) { var tableName string if session.useSystemSchema { @@ -790,10 +795,16 @@ func getUserTypeMetadata(session *Session, keyspaceName string) ([]UserTypeMetad } uType.FieldTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - uType.FieldTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) + uType.FieldTypes[i], err = session.types.typeInfoFromString(session.cfg.ProtoVersion, argumentType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + uType.FieldTypes[i] = unknownTypeInfo(argumentType) + } } uTypes = append(uTypes, uType) } + // TODO: if a UDT refers to another UDT, should we resolve it? if err := rows.Err(); err != nil { return nil, err @@ -1043,10 +1054,20 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta if err != nil { return nil, err } - function.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) + function.ReturnType, err = session.types.typeInfoFromString(session.cfg.ProtoVersion, returnType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + function.ReturnType = unknownTypeInfo(returnType) + } function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - function.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) + function.ArgumentTypes[i], err = session.types.typeInfoFromString(session.cfg.ProtoVersion, argumentType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + function.ArgumentTypes[i] = unknownTypeInfo(argumentType) + } } functions = append(functions, function) } @@ -1100,11 +1121,26 @@ func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMe if err != nil { return nil, err } - aggregate.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) - aggregate.StateType = getTypeInfo(stateType, byte(session.cfg.ProtoVersion), session.logger) + aggregate.ReturnType, err = session.types.typeInfoFromString(session.cfg.ProtoVersion, returnType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + aggregate.ReturnType = unknownTypeInfo(returnType) + } + aggregate.StateType, err = session.types.typeInfoFromString(session.cfg.ProtoVersion, stateType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + aggregate.StateType = unknownTypeInfo(stateType) + } aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) + aggregate.ArgumentTypes[i], err = session.types.typeInfoFromString(session.cfg.ProtoVersion, argumentType) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + aggregate.ArgumentTypes[i] = unknownTypeInfo(argumentType) + } } aggregates = append(aggregates, aggregate) } @@ -1118,10 +1154,9 @@ func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMe // type definition parser state type typeParser struct { - input string - index int - logger StdLogger - proto byte + input string + index int + session *Session } // the type definition parser result @@ -1133,20 +1168,37 @@ type typeParserResult struct { } // Parse the type definition used for validator and comparator schema data -func parseType(def string, protoVer byte, logger StdLogger) typeParserResult { - parser := &typeParser{input: def, proto: protoVer, logger: logger} - return parser.parse() +func parseType(session *Session, def string) (typeParserResult, error) { + parser := &typeParser{ + input: def, + session: session, + } + res, ok, err := parser.parse() + if err != nil { + return typeParserResult{}, err + } + if !ok { + t, err := session.types.typeInfoFromString(session.cfg.ProtoVersion, def) + if err != nil { + // we don't error out completely for unknown types because we didn't before + // and the caller might not care about this type + t = unknownTypeInfo(def) + } + // treat this is a custom type + return typeParserResult{ + isComposite: false, + types: []TypeInfo{t}, + reversed: []bool{false}, + collections: nil, + }, nil + } + return res, err } const ( REVERSED_TYPE = "org.apache.cassandra.db.marshal.ReversedType" COMPOSITE_TYPE = "org.apache.cassandra.db.marshal.CompositeType" COLLECTION_TYPE = "org.apache.cassandra.db.marshal.ColumnToCollectionType" - LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" - SET_TYPE = "org.apache.cassandra.db.marshal.SetType" - MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" - UDT_TYPE = "org.apache.cassandra.db.marshal.UserType" - TUPLE_TYPE = "org.apache.cassandra.db.marshal.TupleType" VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) @@ -1155,8 +1207,8 @@ type typeParserClassNode struct { name string params []typeParserParamNode // this is the segment of the input string that defined this node - input string - proto byte + input string + session *Session } // represents a class parameter in the type def AST @@ -1165,25 +1217,14 @@ type typeParserParamNode struct { class typeParserClassNode } -func (t *typeParser) parse() typeParserResult { +func (t *typeParser) parse() (typeParserResult, bool, error) { // parse the AST ast, ok := t.parseClassNode() if !ok { - // treat this is a custom type - return typeParserResult{ - isComposite: false, - types: []TypeInfo{ - NativeType{ - typ: TypeCustom, - custom: t.input, - proto: t.proto, - }, - }, - reversed: []bool{false}, - collections: nil, - } + return typeParserResult{}, false, nil } + var err error // interpret the AST if strings.HasPrefix(ast.name, COMPOSITE_TYPE) { count := len(ast.params) @@ -1195,22 +1236,14 @@ func (t *typeParser) parse() typeParserResult { count-- for _, param := range last.class.params { - // decode the name - var name string decoded, err := hex.DecodeString(*param.name) if err != nil { - t.logger.Printf( - "Error parsing type '%s', contains collection name '%s' with an invalid format: %v", - t.input, - *param.name, - err, - ) - // just use the provided name - name = *param.name - } else { - name = string(decoded) + return typeParserResult{}, false, fmt.Errorf("type '%s' contains collection name '%s' with an invalid format: %w", t.input, *param.name, err) + } + collections[string(decoded)], err = param.class.asTypeInfo() + if err != nil { + return typeParserResult{}, false, err } - collections[name] = param.class.asTypeInfo() } } @@ -1223,7 +1256,10 @@ func (t *typeParser) parse() typeParserResult { if reversed[i] { class = class.params[0].class } - types[i] = class.asTypeInfo() + types[i], err = class.asTypeInfo() + if err != nil { + return typeParserResult{}, false, err + } } return typeParserResult{ @@ -1231,7 +1267,7 @@ func (t *typeParser) parse() typeParserResult { types: types, reversed: reversed, collections: collections, - } + }, true, nil } else { // not composite, so one type class := *ast @@ -1239,57 +1275,43 @@ func (t *typeParser) parse() typeParserResult { if reversed { class = class.params[0].class } - typeInfo := class.asTypeInfo() + typeInfo, err := class.asTypeInfo() + if err != nil { + return typeParserResult{}, false, err + } return typeParserResult{ isComposite: false, types: []TypeInfo{typeInfo}, reversed: []bool{reversed}, - } + }, true, nil } } -func (class *typeParserClassNode) asTypeInfo() TypeInfo { - if strings.HasPrefix(class.name, LIST_TYPE) { - elem := class.params[0].class.asTypeInfo() - return CollectionType{ - NativeType: NativeType{ - typ: TypeList, - proto: class.proto, - }, - Elem: elem, - } - } - if strings.HasPrefix(class.name, SET_TYPE) { - elem := class.params[0].class.asTypeInfo() - return CollectionType{ - NativeType: NativeType{ - typ: TypeSet, - proto: class.proto, - }, - Elem: elem, - } - } - if strings.HasPrefix(class.name, MAP_TYPE) { - key := class.params[0].class.asTypeInfo() - elem := class.params[1].class.asTypeInfo() - return CollectionType{ - NativeType: NativeType{ - typ: TypeMap, - proto: class.proto, - }, - Key: key, - Elem: elem, +func (class *typeParserClassNode) asTypeInfo() (TypeInfo, error) { + // TODO: should we just use types.typeInfoFromString(class.input) but then it + // wouldn't be reversed + t, ok := class.session.types.getType(class.name) + if !ok { + return unknownTypeInfo(class.input), nil + } + var params string + if len(class.params) > 0 { + // compile the params just like they are in 3.x + for i, param := range class.params { + if i > 0 { + params += ", " + } + if param.name != nil { + params += (*param.name) + ":" + } + params += param.class.name } } - - // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto} - if info.typ == TypeCustom { - // add the entire class definition - info.custom = class.input - } - return info + // we are returning the error here because we're failing to parse it, but if + // this ends up unnecessarily breaking things we could do the same thing as + // above and return unknownTypeInfo + return t.TypeInfoFromString(class.session.cfg.ProtoVersion, params) } // CLASS := ID [ PARAMS ] @@ -1311,10 +1333,10 @@ func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) { endIndex := t.index node = &typeParserClassNode{ - name: name, - params: params, - input: t.input[startIndex:endIndex], - proto: t.proto, + name: name, + params: params, + input: t.input[startIndex:endIndex], + session: t.session, } return node, true } diff --git a/metadata_test.go b/metadata_test.go index 21829b33d..b6a7a88f2 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -36,8 +36,13 @@ import ( // Tests V1 and V2 metadata "compilation" from example data which might be returned // from metadata schema queries (see getKeyspaceMetadata, getTableMetadata, and getColumnMetadata) func TestCompileMetadata(t *testing.T) { - // V1 tests - these are all based on real examples from the integration test ccm cluster - log := &defaultLogger{} + session := &Session{ + cfg: ClusterConfig{ + ProtoVersion: 1, + }, + logger: &defaultLogger{}, + types: GlobalTypes, + } // V2 test - V2+ protocol is simpler so here are some toy examples to verify that the mapping works keyspace := &KeyspaceMetadata{ Name: "V2Keyspace", @@ -101,7 +106,7 @@ func TestCompileMetadata(t *testing.T) { Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, } - compileMetadata(2, keyspace, tables, columns, nil, nil, nil, nil, log) + compileMetadata(session, keyspace, tables, columns, nil, nil, nil, nil) assertKeyspaceMetadata( t, keyspace, @@ -112,19 +117,25 @@ func TestCompileMetadata(t *testing.T) { PartitionKey: []*ColumnMetadata{ { Name: "Key1", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, }, }, ClusteringColumns: []*ColumnMetadata{}, Columns: map[string]*ColumnMetadata{ "KEY1": { Name: "KEY1", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Kind: ColumnPartitionKey, }, "Key1": { Name: "Key1", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Kind: ColumnPartitionKey, }, }, @@ -133,42 +144,56 @@ func TestCompileMetadata(t *testing.T) { PartitionKey: []*ColumnMetadata{ { Name: "Column1", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, }, }, ClusteringColumns: []*ColumnMetadata{ { - Name: "Column2", - Type: NativeType{typ: TypeVarchar}, + Name: "Column2", + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Order: ASC, }, { - Name: "Column3", - Type: NativeType{typ: TypeVarchar}, + Name: "Column3", + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Order: DESC, }, }, Columns: map[string]*ColumnMetadata{ "Column1": { Name: "Column1", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Kind: ColumnPartitionKey, }, "Column2": { - Name: "Column2", - Type: NativeType{typ: TypeVarchar}, + Name: "Column2", + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Order: ASC, Kind: ColumnClusteringKey, }, "Column3": { - Name: "Column3", - Type: NativeType{typ: TypeVarchar}, + Name: "Column3", + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Order: DESC, Kind: ColumnClusteringKey, }, "Column4": { Name: "Column4", - Type: NativeType{typ: TypeVarchar}, + Type: varcharLikeTypeInfo{ + typ: TypeVarchar, + }, Kind: ColumnRegular, }, }, @@ -402,10 +427,20 @@ func assertParseNonCompositeType( typeExpected assertTypeInfo, ) { - log := &defaultLogger{} - result := parseType(def, protoVersion4, log) + session := &Session{ + cfg: ClusterConfig{ + ProtoVersion: 4, + }, + logger: &defaultLogger{}, + types: GlobalTypes, + } + result, err := parseType(session, def) + if err != nil { + t.Fatal(err) + } if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed)) + return } assertParseNonCompositeTypes( @@ -433,8 +468,17 @@ func assertParseCompositeType( collectionsExpected map[string]assertTypeInfo, ) { - log := &defaultLogger{} - result := parseType(def, protoVersion4, log) + session := &Session{ + cfg: ClusterConfig{ + ProtoVersion: 4, + }, + logger: &defaultLogger{}, + types: GlobalTypes, + } + result, err := parseType(session, def) + if err != nil { + t.Fatal(err) + } if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } @@ -502,11 +546,17 @@ func assertParseNonCompositeTypes( // check the type if typeActual.Type() != typeExpected.Type { - t.Errorf("%s: Expected to parse Type to %s but was %s", context, typeExpected.Type, typeActual.Type()) + t.Errorf("%s: Expected to parse Type to %v but was %v", context, typeExpected.Type, typeActual.Type()) } - // check the custom - if typeActual.Custom() != typeExpected.Custom { - t.Errorf("%s: Expected to parse Custom %s but was %s", context, typeExpected.Custom, typeActual.Custom()) + if typeExpected.Custom != "" { + ct, ok := typeActual.(unknownTypeInfo) + if !ok { + t.Errorf("%s: Expected to get unknownCustomTypeInfo but was %T", context, typeActual) + continue + } + if string(ct) != typeExpected.Custom { + t.Errorf("%s: Expected to parse Custom %s but was %s", context, typeExpected.Custom, string(ct)) + } } collection, _ := typeActual.(CollectionType) diff --git a/session.go b/session.go index dd972f76f..bdda06406 100644 --- a/session.go +++ b/session.go @@ -65,6 +65,7 @@ type Session struct { hostSource *ringDescriber ringRefresher *refreshDebouncer stmtsLRU *preparedLRU + types *RegisteredTypes connCfg *ConnConfig @@ -161,6 +162,11 @@ func NewSession(cfg ClusterConfig) (*Session, error) { logger: cfg.logger(), trace: cfg.Tracer, } + if cfg.RegisteredTypes == nil { + s.types = GlobalTypes.Copy() + } else { + s.types = cfg.RegisteredTypes.Copy() + } s.schemaDescriber = newSchemaDescriber(s) @@ -1634,7 +1640,7 @@ func (iter *Iter) Scanner() Scanner { } func (iter *Iter) readColumn() ([]byte, error) { - return iter.framer.readBytesInternal() + return iter.framer.readBytes() } // Scan consumes the next row of the iterator and copies the columns of the diff --git a/token_test.go b/token_test.go index 90e0d4fd8..1152760e8 100644 --- a/token_test.go +++ b/token_test.go @@ -48,7 +48,7 @@ func TestMurmur3Partitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil - pk, _ := marshalInt(nil, 1) + pk, _ := Marshal(intTypeInfo{}, 1) token = murmur3Partitioner{}.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -73,7 +73,7 @@ func TestOrderedPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil p := orderedPartitioner{} - pk, _ := marshalInt(nil, 1) + pk, _ := Marshal(intTypeInfo{}, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -109,7 +109,7 @@ func TestRandomPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil p := randomPartitioner{} - pk, _ := marshalInt(nil, 1) + pk, _ := Marshal(intTypeInfo{}, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") diff --git a/tuple_test.go b/tuple_test.go index de5317b34..9194f0c75 100644 --- a/tuple_test.go +++ b/tuple_test.go @@ -138,10 +138,10 @@ func TestTuple_TupleNotSet(t *testing.T) { if err := iter.Scan(x, y); err != nil { t.Fatal(err) } - if x == nil || *x != 1 { + if *x != 1 { t.Fatalf("x should be %d got %+#v, value=%d", 1, x, *x) } - if y == nil || *y != 2 { + if *y != 2 { t.Fatalf("y should be %d got %+#v, value=%d", 2, y, *y) } @@ -150,10 +150,10 @@ func TestTuple_TupleNotSet(t *testing.T) { if err := iter.Scan(x, y); err != nil { t.Fatal(err) } - if x == nil || *x != 0 { + if *x != 0 { t.Fatalf("x should be %d got %+#v, value=%d", 0, x, *x) } - if y == nil || *y != 0 { + if *y != 0 { t.Fatalf("y should be %d got %+#v, value=%d", 0, y, *y) } } diff --git a/types.go b/types.go new file mode 100644 index 000000000..ce8fee0dc --- /dev/null +++ b/types.go @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "errors" + "fmt" + "strings" + "sync" +) + +// CQLType is the interface that must be implemented by all registered types. +// For simple types, you can just wrap a TypeInfo with SimpleCQLType. +type CQLType interface { + // Params should return a new slice of zero values of the closest Go types to' + // be filled when parsing a frame. These values associcated with the types + // are sent to TypeInfoFromParams after being read from the frame. + // + // The supported types are: Type, TypeInfo, []UDTField, []byte, string, int, uint16, byte. + // Pointers are followed when filling the params but the slice sent to + // TypeInfoFromParams will only contain the underlying values. Since + // TypeInfo(nil) isn't supported, you can send (*TypeInfo)(nil) and + // TypeInfoFromParams will get a TypeInfo (not *TypeInfo since that's not usable). + // + // If no params are needed this can return a nil slice and TypeInfoFromParams + // will be sent nil. + Params(proto int) []interface{} + + // TypeInfoFromParams should return a TypeInfo implementation for the type + // with the given filled parameters. See the Params() method for what values + // to expect. + TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) + + // TypeInfoFromString should return a TypeInfo implementation for the type with + // the given names/classes. Only the portion within the parantheses or arrows + // are passed to this function. For simple types, the name passed might be empty. + TypeInfoFromString(proto int, name string) (TypeInfo, error) +} + +// SimpleCQLType is a convenience wrapper around a TypeInfo that implements +// CQLType by returning nil for Params, and the TypeInfo for TypeInfoFromParams +// and TypeInfoFromString. +type SimpleCQLType struct { + TypeInfo +} + +// Params returns nil. +func (SimpleCQLType) Params(int) []interface{} { + return nil +} + +// TypeInfoFromParams returns the wrapped TypeInfo. +func (s SimpleCQLType) TypeInfoFromParams(proto int, params []interface{}) (TypeInfo, error) { + return s.TypeInfo, nil +} + +// TypeInfoFromString returns the wrapped TypeInfo. +func (s SimpleCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + return s.TypeInfo, nil +} + +// TypeInfo describes a Cassandra specific data type and handles marshalling +// and unmarshalling. +type TypeInfo interface { + // Type returns the Type id for the TypeInfo. + Type() Type + + // Zero returns the Go zero value. For types that directly map to a Go type like + // list it should return []int(nil) but for complex types like a + // tuple it should be []interface{}{int(0), bool(false)}. + Zero() interface{} + + // Marshal should marshal the value for the given TypeInfo into a byte slice + Marshal(value interface{}) ([]byte, error) + + // Unmarshal should unmarshal the byte slice into the value for the given + // TypeInfo. + Unmarshal(data []byte, value interface{}) error +} + +// RegisteredTypes is a collection of CQL types +type RegisteredTypes struct { + byType map[Type]CQLType + simples map[Type]TypeInfo + byString map[string]Type + custom map[string]CQLType + + // this mutex is only used for registration and after that it's assumed that + // the types are immutable + mut sync.Mutex + initialized sync.Once +} + +func (r *RegisteredTypes) init() { + r.initialized.Do(func() { + if r.byType == nil { + r.byType = map[Type]CQLType{} + } + if r.simples == nil { + r.simples = map[Type]TypeInfo{} + } + if r.byString == nil { + r.byString = map[string]Type{} + } + if r.custom == nil { + r.custom = map[string]CQLType{} + } + }) +} + +func (r *RegisteredTypes) addDefaultTypes() { + r.init() + r.mut.Lock() + defer r.mut.Unlock() + + r.mustRegisterType(TypeAscii, "ascii", SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeAscii, + }}) + r.mustRegisterAlias("AsciiType", "ascii") + + r.mustRegisterType(TypeBigInt, "bigint", SimpleCQLType{bigIntLikeTypeInfo{ + typ: TypeBigInt, + }}) + r.mustRegisterAlias("LongType", "bigint") + + r.mustRegisterType(TypeBlob, "blob", SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeBlob, + }}) + r.mustRegisterAlias("BytesType", "blob") + + r.mustRegisterType(TypeBoolean, "boolean", SimpleCQLType{booleanTypeInfo{}}) + r.mustRegisterAlias("BooleanType", "boolean") + + r.mustRegisterType(TypeCounter, "counter", SimpleCQLType{bigIntLikeTypeInfo{ + typ: TypeCounter, + }}) + r.mustRegisterAlias("CounterColumnType", "counter") + + r.mustRegisterType(TypeDate, "date", SimpleCQLType{dateTypeInfo{}}) + r.mustRegisterAlias("SimpleDateType", "date") + + r.mustRegisterType(TypeDecimal, "decimal", SimpleCQLType{decimalTypeInfo{}}) + r.mustRegisterAlias("DecimalType", "decimal") + + r.mustRegisterType(TypeDouble, "double", SimpleCQLType{doubleTypeInfo{}}) + r.mustRegisterAlias("DoubleType", "double") + + r.mustRegisterType(TypeDuration, "duration", SimpleCQLType{durationTypeInfo{}}) + r.mustRegisterAlias("DurationType", "duration") + + r.mustRegisterType(TypeFloat, "float", SimpleCQLType{floatTypeInfo{}}) + r.mustRegisterAlias("FloatType", "float") + + r.mustRegisterType(TypeInet, "inet", SimpleCQLType{inetType{}}) + r.mustRegisterAlias("InetAddressType", "inet") + + r.mustRegisterType(TypeInt, "int", SimpleCQLType{intTypeInfo{}}) + r.mustRegisterAlias("Int32Type", "int") + + r.mustRegisterType(TypeSmallInt, "smallint", SimpleCQLType{smallIntTypeInfo{}}) + r.mustRegisterAlias("ShortType", "smallint") + + r.mustRegisterType(TypeText, "text", SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeText, + }}) + + r.mustRegisterType(TypeTime, "time", SimpleCQLType{timeTypeInfo{}}) + r.mustRegisterAlias("TimeType", "time") + + r.mustRegisterType(TypeTimestamp, "timestamp", SimpleCQLType{timestampTypeInfo{}}) + r.mustRegisterAlias("TimestampType", "timestamp") + // DateType was a timestamp when date didn't exist + r.mustRegisterAlias("DateType", "timestamp") + + r.mustRegisterType(TypeTimeUUID, "timeuuid", SimpleCQLType{timeUUIDType{}}) + r.mustRegisterAlias("TimeUUIDType", "timeuuid") + + r.mustRegisterType(TypeTinyInt, "tinyint", SimpleCQLType{tinyIntTypeInfo{}}) + r.mustRegisterAlias("ByteType", "tinyint") + + r.mustRegisterType(TypeUUID, "uuid", SimpleCQLType{uuidType{}}) + r.mustRegisterAlias("UUIDType", "uuid") + r.mustRegisterAlias("LexicalUUIDType", "uuid") + + r.mustRegisterType(TypeVarchar, "varchar", SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeVarchar, + }}) + r.mustRegisterAlias("UTF8Type", "varchar") + + r.mustRegisterType(TypeVarint, "varint", SimpleCQLType{varintTypeInfo{}}) + r.mustRegisterAlias("IntegerType", "varint") + + // these types need references to the registered types + r.mustRegisterType(TypeList, "list", listSetCQLType{ + typ: TypeList, + types: r, + }) + r.mustRegisterAlias("ListType", "list") + + r.mustRegisterType(TypeSet, "set", listSetCQLType{ + typ: TypeSet, + types: r, + }) + r.mustRegisterAlias("SetType", "set") + + r.mustRegisterType(TypeMap, "map", mapCQLType{ + types: r, + }) + r.mustRegisterAlias("MapType", "map") + + r.mustRegisterType(TypeTuple, "tuple", tupleCQLType{ + types: r, + }) + r.mustRegisterAlias("TupleType", "tuple") + + r.mustRegisterType(TypeUDT, "udt", udtCQLType{ + types: r, + }) + r.mustRegisterAlias("UserType", "udt") + + r.mustRegisterCustom("vector", vectorCQLType{ + types: r, + }) + r.mustRegisterAlias("VectorType", "vector") +} + +// RegisterType registers a new CQL data type. Type should be the CQL id for +// the type. Name is the name of the type as returned in the metadata for the +// column. CQLType is the implementation of the type. +// This function must not be called after a session has been created. +func (r *RegisteredTypes) RegisterType(typ Type, name string, t CQLType) error { + r.init() + r.mut.Lock() + defer r.mut.Unlock() + return r.registerType(typ, name, t) +} + +func (r *RegisteredTypes) registerType(typ Type, name string, t CQLType) error { + if typ == TypeCustom { + return errors.New("custom types must be registered with RegisterCustom") + } + + if _, ok := r.byType[typ]; ok { + return fmt.Errorf("type %d already registered", typ) + } + if _, ok := r.byString[name]; ok { + return fmt.Errorf("type name %s already registered", name) + } + r.byType[typ] = t + if s, ok := t.(SimpleCQLType); ok { + r.simples[typ] = s.TypeInfo + } + r.byString[name] = typ + return nil +} + +func (r *RegisteredTypes) mustRegisterType(typ Type, name string, t CQLType) { + if err := r.registerType(typ, name, t); err != nil { + panic(err) + } +} + +// RegisterCustom registers a new custom CQL type. Name is the name of the type +// as returned in the metadata for the column. CQLType is the implementation of +// the type. +// This function must not be called after a session has been created. +func (r *RegisteredTypes) RegisterCustom(name string, t CQLType) error { + r.init() + r.mut.Lock() + defer r.mut.Unlock() + return r.registerCustom(name, t) +} + +func (r *RegisteredTypes) registerCustom(name string, t CQLType) error { + if r.custom == nil { + r.custom = map[string]CQLType{} + } + if _, ok := r.custom[name]; ok { + return fmt.Errorf("custom type %s already registered", name) + } + if _, ok := r.byString[name]; ok { + return fmt.Errorf("type name %s already registered", name) + } + + r.custom[name] = t + r.byString[name] = TypeCustom + return nil +} + +func (r *RegisteredTypes) mustRegisterCustom(name string, t CQLType) { + if err := r.registerCustom(name, t); err != nil { + panic(err) + } +} + +// AddAlias adds an alias for an already registered type. If you expect a type +// to be referenced as multiple different types or if you need to add the Java +// marshal class for a type you should call this method. +// This function must not be called after a session has been created. +func (r *RegisteredTypes) AddAlias(name, as string) error { + r.init() + r.mut.Lock() + defer r.mut.Unlock() + return r.registerAlias(name, as) +} + +func (r *RegisteredTypes) registerAlias(name, as string) error { + if strings.HasPrefix(name, apacheCassandraTypePrefix) { + name = strings.TrimPrefix(name, apacheCassandraTypePrefix) + } + if strings.HasPrefix(as, apacheCassandraTypePrefix) { + as = strings.TrimPrefix(as, apacheCassandraTypePrefix) + } + if _, ok := r.byString[name]; ok { + return fmt.Errorf("type name %s already registered", name) + } + if _, ok := r.byString[as]; !ok { + return fmt.Errorf("type name %s was not registered", as) + } + if _, ok := r.custom[as]; ok { + r.custom[name] = r.custom[as] + } + r.byString[name] = r.byString[as] + return nil +} + +func (r *RegisteredTypes) mustRegisterAlias(name, as string) { + if err := r.registerAlias(name, as); err != nil { + panic(err) + } +} + +func (r *RegisteredTypes) typeInfoFromJavaString(proto int, fullName string) (TypeInfo, error) { + name := strings.TrimPrefix(fullName, apacheCassandraTypePrefix) + compositeNameIdx := strings.Index(name, "(") + var params string + if compositeNameIdx != -1 { + compositeParamsEndIdx := strings.LastIndex(name, ")") + if compositeParamsEndIdx == -1 { + return nil, fmt.Errorf("invalid type string %v", fullName) + } + params = name[compositeNameIdx+1 : compositeParamsEndIdx] + name = name[:compositeNameIdx] + } + t, ok := r.getType(name) + if !ok { + return nil, unknownTypeError(fullName) + } + return t.TypeInfoFromString(proto, params) +} + +func (r *RegisteredTypes) typeInfoFromString(proto int, name string) (TypeInfo, error) { + // check for java long-form type + if strings.HasPrefix(name, apacheCassandraTypePrefix) { + return r.typeInfoFromJavaString(proto, name) + } + + compositeNameIdx := strings.Index(name, "<") + var params string + if compositeNameIdx != -1 { + compositeParamsEndIdx := strings.LastIndex(name, ">") + if compositeParamsEndIdx == -1 { + return nil, fmt.Errorf("invalid type string %v", name) + } + params = name[compositeNameIdx+1 : compositeParamsEndIdx] + name = name[:compositeNameIdx] + // frozen is a special case + if name == "frozen" { + return r.typeInfoFromString(proto, params) + } + } else if strings.Contains(name, "(") { + // most likely a java long-form type + return r.typeInfoFromJavaString(proto, name) + } + + t, ok := r.getType(name) + if !ok { + return nil, unknownTypeError(name) + } + return t.TypeInfoFromString(proto, params) +} + +func splitCompositeTypes(name string) []string { + // check for the simple case without any composite types + if !strings.Contains(name, "(") && !strings.Contains(name, "<") { + parts := strings.Split(name, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts + } + var parts []string + lessCount := 0 + segment := "" + var openChar, closeChar rune + for _, char := range name { + if char == ',' && lessCount == 0 { + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + segment = "" + continue + } + segment += string(char) + // determine which open/close characters to use + if openChar == 0 { + if char == '<' { + openChar = '<' + closeChar = '>' + lessCount++ + } else if char == '(' { + openChar = '(' + closeChar = ')' + lessCount++ + } + } else if char == openChar { + lessCount++ + } else if char == closeChar { + lessCount-- + } + } + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + return parts +} + +func (r *RegisteredTypes) getType(classOrName string) (CQLType, bool) { + classOrName = strings.TrimPrefix(classOrName, apacheCassandraTypePrefix) + typ, ok := r.byString[classOrName] + if !ok { + // it could also be a UDT but this is what we've always done + typ = TypeCustom + } + var t CQLType + if typ == TypeCustom { + t, ok = r.custom[classOrName] + } else { + t = r.fastRegisteredTypeLookup(typ) + ok = t != nil + } + return t, ok +} + +func (r *RegisteredTypes) fastTypeInfoLookup(typ Type) TypeInfo { + switch typ { + case TypeAscii: + return varcharLikeTypeInfo{ + typ: TypeAscii, + } + case TypeBigInt: + return bigIntLikeTypeInfo{ + typ: TypeBigInt, + } + case TypeBlob: + return varcharLikeTypeInfo{ + typ: TypeBlob, + } + case TypeBoolean: + return booleanTypeInfo{} + case TypeCounter: + return bigIntLikeTypeInfo{ + typ: TypeCounter, + } + case TypeDate: + return dateTypeInfo{} + case TypeDecimal: + return decimalTypeInfo{} + case TypeDouble: + return doubleTypeInfo{} + case TypeDuration: + return durationTypeInfo{} + case TypeFloat: + return floatTypeInfo{} + case TypeInet: + return inetType{} + case TypeInt: + return intTypeInfo{} + case TypeSmallInt: + return smallIntTypeInfo{} + case TypeText: + return varcharLikeTypeInfo{ + typ: TypeText, + } + case TypeTime: + return timeTypeInfo{} + case TypeTimestamp: + return timestampTypeInfo{} + case TypeTimeUUID: + return timeUUIDType{} + case TypeTinyInt: + return tinyIntTypeInfo{} + case TypeUUID: + return uuidType{} + case TypeVarchar: + return varcharLikeTypeInfo{ + typ: TypeVarchar, + } + case TypeVarint: + return varintTypeInfo{} + default: + return r.simples[typ] + } +} + +// fastRegisteredTypeLookup is a fast lookup for the registered type that avoids +// the need for a map lookup which was shown to be significant +// in cases where it's necessary you should consider manually inlining this method +func (r *RegisteredTypes) fastRegisteredTypeLookup(typ Type) CQLType { + switch typ { + case TypeAscii: + return SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeAscii, + }} + case TypeBigInt: + return SimpleCQLType{bigIntLikeTypeInfo{ + typ: TypeBigInt, + }} + case TypeBlob: + return SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeBlob, + }} + case TypeBoolean: + return SimpleCQLType{booleanTypeInfo{}} + case TypeCounter: + return SimpleCQLType{bigIntLikeTypeInfo{ + typ: TypeCounter, + }} + case TypeDate: + return SimpleCQLType{dateTypeInfo{}} + case TypeDecimal: + return SimpleCQLType{decimalTypeInfo{}} + case TypeDouble: + return SimpleCQLType{doubleTypeInfo{}} + case TypeDuration: + return SimpleCQLType{durationTypeInfo{}} + case TypeFloat: + return SimpleCQLType{floatTypeInfo{}} + case TypeInet: + return SimpleCQLType{inetType{}} + case TypeInt: + return SimpleCQLType{intTypeInfo{}} + case TypeSmallInt: + return SimpleCQLType{smallIntTypeInfo{}} + case TypeText: + return SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeText, + }} + case TypeTime: + return SimpleCQLType{timeTypeInfo{}} + case TypeTimestamp: + return SimpleCQLType{timestampTypeInfo{}} + case TypeTimeUUID: + return SimpleCQLType{timeUUIDType{}} + case TypeTinyInt: + return SimpleCQLType{tinyIntTypeInfo{}} + case TypeUUID: + return SimpleCQLType{uuidType{}} + case TypeVarchar: + return SimpleCQLType{varcharLikeTypeInfo{ + typ: TypeVarchar, + }} + case TypeVarint: + return SimpleCQLType{varintTypeInfo{}} + case TypeCustom: + // this should never happen + panic("custom types cannot be returned from fastRegisteredTypeLookup") + default: + return r.byType[typ] + } +} + +// Copy returns a new shallow copy of the RegisteredTypes +func (r *RegisteredTypes) Copy() *RegisteredTypes { + r.mut.Lock() + defer r.mut.Unlock() + + copy := &RegisteredTypes{} + copy.init() + for typ, t := range r.byType { + copy.byType[typ] = t + } + for typ, t := range r.simples { + copy.simples[typ] = t + } + for name, typ := range r.byString { + copy.byString[name] = typ + } + for name, t := range r.custom { + copy.custom[name] = t + } + return copy +} + +// GlobalTypes is the set of types that are registered globally and are copied +// by all sessions that don't define their own RegisteredTypes in ClusterConfig. +// Since a new session copies this, you should be modifying this or creating your +// own before a session is created. +var GlobalTypes = func() *RegisteredTypes { + r := &RegisteredTypes{} + // we init because we end up calling GlobalTypes in tests and other spots before + // init would get called + r.init() + r.addDefaultTypes() + return r +}() + +// Type is the identifier of a Cassandra internal datatype. +type Type int + +const ( + TypeCustom Type = 0x0000 + TypeAscii Type = 0x0001 + TypeBigInt Type = 0x0002 + TypeBlob Type = 0x0003 + TypeBoolean Type = 0x0004 + TypeCounter Type = 0x0005 + TypeDecimal Type = 0x0006 + TypeDouble Type = 0x0007 + TypeFloat Type = 0x0008 + TypeInt Type = 0x0009 + TypeText Type = 0x000A + TypeTimestamp Type = 0x000B + TypeUUID Type = 0x000C + TypeVarchar Type = 0x000D + TypeVarint Type = 0x000E + TypeTimeUUID Type = 0x000F + TypeInet Type = 0x0010 + TypeDate Type = 0x0011 + TypeTime Type = 0x0012 + TypeSmallInt Type = 0x0013 + TypeTinyInt Type = 0x0014 + TypeDuration Type = 0x0015 + TypeList Type = 0x0020 + TypeMap Type = 0x0021 + TypeSet Type = 0x0022 + TypeUDT Type = 0x0030 + TypeTuple Type = 0x0031 +) + +// NewNativeType returns a TypeInfo from the global registered types. +// Deprecated. +func NewNativeType(proto byte, typ Type, custom string) TypeInfo { + if typ == TypeCustom { + t, err := GlobalTypes.typeInfoFromString(int(proto), custom) + if err != nil { + panic(err) + } + return t + } + rt := GlobalTypes.fastRegisteredTypeLookup(typ) + if rt == nil { + return unknownTypeInfo(fmt.Sprintf("%d", typ)) + } + // most of the time this will do nothing because it's a SimpleCQLType but if + // it's not then we don't have anything to pass but custom + t, err := rt.TypeInfoFromString(int(proto), custom) + if err != nil { + panic(err) + } + return t +} + +type unknownTypeInfo string + +func (unknownTypeInfo) Type() Type { + return TypeCustom +} + +// Zero returns the zero value for the unknown custom type. +func (unknownTypeInfo) Zero() interface{} { + return nil +} + +func (u unknownTypeInfo) Marshal(value interface{}) ([]byte, error) { + return nil, fmt.Errorf("can not marshal %T into %s", value, string(u)) +} + +func (u unknownTypeInfo) Unmarshal(_ []byte, value interface{}) error { + return fmt.Errorf("can not unmarshal %s into %T", string(u), value) +} + +type unknownTypeError string + +func (e unknownTypeError) Error() string { + return fmt.Sprintf("unknown type %v", string(e)) +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 000000000..c3764dce9 --- /dev/null +++ b/types_test.go @@ -0,0 +1,120 @@ +//go:build all || unit +// +build all unit + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "reflect" + "testing" +) + +var defaultLongTypes = []struct { + TypeName string +}{ + {"AsciiType"}, + {"LongType"}, + {"BytesType"}, + {"BooleanType"}, + {"CounterColumnType"}, + {"DecimalType"}, + {"DoubleType"}, + {"FloatType"}, + {"Int32Type"}, + {"DateType"}, + {"TimestampType"}, + {"UUIDType"}, + {"UTF8Type"}, + {"IntegerType"}, + {"TimeUUIDType"}, + {"InetAddressType"}, + {"MapType"}, + {"ListType"}, + {"SetType"}, + {"ShortType"}, + {"ByteType"}, + {"TupleType"}, + {"UserType"}, + {"VectorType"}, +} + +func testType(t *testing.T, str string) { + _, ok := GlobalTypes.getType(apacheCassandraTypePrefix + str) + if !ok { + t.Errorf("failed to get type for %v", apacheCassandraTypePrefix+str) + } +} + +func TestDefaultLongTypes(t *testing.T) { + for _, lookupTest := range defaultLongTypes { + testType(t, lookupTest.TypeName) + } +} + +func TestSplitCompositeTypes(t *testing.T) { + var testCases = []struct { + Name string + Split []string + }{ + { + Name: "boolean", + Split: []string{"boolean"}, + }, + { + Name: "boolean, int", + Split: []string{"boolean", "int"}, + }, + { + Name: apacheCassandraTypePrefix + "TupleType(a, b)", + Split: []string{apacheCassandraTypePrefix + "TupleType(a, b)"}, + }, + { + Name: "tuple", + Split: []string{"tuple"}, + }, + { + Name: "tuple,b>", + Split: []string{"tuple,b>"}, + }, + { + Name: "tuple,b>, int", + Split: []string{"tuple,b>", "int"}, + }, + { + Name: apacheCassandraTypePrefix + "TupleType(a, b), " + apacheCassandraTypePrefix + "IntType", + Split: []string{ + apacheCassandraTypePrefix + "TupleType(a, b)", + apacheCassandraTypePrefix + "IntType", + }, + }, + } + for _, tc := range testCases { + split := splitCompositeTypes(tc.Name) + if !reflect.DeepEqual(split, tc.Split) { + t.Errorf("[%v] expected %v, got %v", tc.Name, tc.Split, split) + } + } +} diff --git a/vector.go b/vector.go new file mode 100644 index 000000000..266e341e0 --- /dev/null +++ b/vector.go @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gocql + +import ( + "bytes" + "errors" + "fmt" + "math/bits" + "reflect" + "strconv" +) + +type vectorCQLType struct { + types *RegisteredTypes +} + +// Params returns nil +func (vectorCQLType) Params(int) []interface{} { + // we don't support frame params for custom types + return nil +} + +// TypeInfoFromParams builds a TypeInfo implementation for the composite type with +// the given parameters. +func (vectorCQLType) TypeInfoFromParams(int, []interface{}) (TypeInfo, error) { + return nil, errors.New("unsupported for vector") +} + +// TypeInfoFromString builds the VectorType from the given string. +func (v vectorCQLType) TypeInfoFromString(proto int, name string) (TypeInfo, error) { + params := splitCompositeTypes(name) + if len(params) != 2 { + return nil, fmt.Errorf("expected 2 params for vector, got %d", len(params)) + } + subType, err := v.types.typeInfoFromString(proto, params[0]) + if err != nil { + return nil, err + } + dim, _ := strconv.Atoi(params[1]) + return VectorType{ + SubType: subType, + Dimensions: dim, + }, nil +} + +type VectorType struct { + SubType TypeInfo + Dimensions int +} + +func (VectorType) Type() Type { + return TypeCustom +} + +// Zero returns the zero value for the vector CQL type. +func (v VectorType) Zero() interface{} { + return reflect.Zero(reflect.SliceOf(reflect.TypeOf(v.SubType.Zero()))).Interface() +} + +func (t VectorType) String() string { + return fmt.Sprintf("vector(%s, %d)", t.SubType, t.Dimensions) +} + +// Marshal marshals the value into a byte slice. +func (v VectorType) Marshal(value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + if n != v.Dimensions { + return nil, marshalErrorf("expected vector with %d dimensions, received %d", v.Dimensions, n) + } + + for i := 0; i < n; i++ { + item, err := Marshal(v.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + if isVectorVariableLengthType(v.SubType) { + writeUnsignedVInt(buf, uint64(len(item))) + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array.", value, v) +} + +// Unmarshal unmarshals the byte slice into the value. +func (v VectorType) Unmarshal(data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != v.Dimensions { + return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), v.Dimensions) + } + } else { + rv.Set(reflect.MakeSlice(t, v.Dimensions, v.Dimensions)) + } + elemSize := len(data) / v.Dimensions + for i := 0; i < v.Dimensions; i++ { + offset := 0 + if isVectorVariableLengthType(v.SubType) { + m, p, err := readUnsignedVInt(data) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(v.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", v.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: slice, array.", v, value) +} + +// isVectorVariableLengthType determines if a type requires explicit length serialization within a vector. +// Variable-length types need their length encoded before the actual data to allow proper deserialization. +// Fixed-length types, on the other hand, don't require this kind of length prefix. +func isVectorVariableLengthType(elemType TypeInfo) bool { + switch elemType.Type() { + case TypeBigInt, TypeBoolean, TypeTimestamp, TypeDouble, TypeFloat, TypeInt, + TypeTimeUUID, TypeUUID: + return false + case TypeCustom: + // vectors are special in that they rely on the underlying type + if vecType, ok := elemType.(VectorType); ok { + return isVectorVariableLengthType(vecType.SubType) + } + } + return true +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVInt(data []byte) (uint64, int, error) { + if len(data) <= 0 { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[0] + if firstByte&0x80 == 0 { + return uint64(firstByte), 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", numBytes+1, len(data)) + } + for i := 0; i < numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} diff --git a/vector_test.go b/vector_test.go index 4e52a8856..6db07c766 100644 --- a/vector_test.go +++ b/vector_test.go @@ -29,12 +29,13 @@ package gocql import ( "fmt" - "github.com/stretchr/testify/require" - "gopkg.in/inf.v0" "net" "reflect" "testing" "time" + + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" ) type person struct { @@ -107,9 +108,9 @@ func TestVector_Types(t *testing.T) { date2, _ := time.Parse("2006-01-02", "2022-03-14") date3, _ := time.Parse("2006-01-02", "2024-12-31") - time1, _ := time.Parse("15:04:05", "01:00:00") - time2, _ := time.Parse("15:04:05", "15:23:59") - time3, _ := time.Parse("15:04:05.000", "10:31:45.987") + time1 := time.Duration(time.Hour) + time2 := time.Duration(15*time.Hour + 23*time.Minute + 59*time.Second) + time3 := time.Duration(10*time.Hour + 31*time.Minute + 45*time.Second + 987*time.Millisecond) duration1 := Duration{0, 1, 1920000000000} duration2 := Duration{1, 1, 1920000000000} @@ -129,24 +130,24 @@ func TestVector_Types(t *testing.T) { value interface{} comparator func(interface{}, interface{}) }{ - {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}}, - {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}}, - {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, - {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}}, - {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}}, - {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, - {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}}, - {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}}, - {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}}, - {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}}, - {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}}, - {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, - {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}}, - {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, - {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + {name: "ascii", cqlType: "ascii", value: []string{"a", "1", "Z"}}, + {name: "bigint", cqlType: "bigint", value: []int64{1, 2, 3}}, + {name: "blob", cqlType: "blob", value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: "boolean", value: []bool{true, false, true}}, + {name: "counter", cqlType: "counter", value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: "decimal", value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: "double", value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: "float", value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: "int", value: []int32{1, 2, 3}}, + {name: "text", cqlType: "text", value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: "timestamp", value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: "uuid", value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: "varchar", value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: "varint", value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: "timeuuid", value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, { name: "inet", - cqlType: TypeInet.String(), + cqlType: "inet", value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, comparator: func(e interface{}, a interface{}) { expected := e.([]net.IP) @@ -157,11 +158,11 @@ func TestVector_Types(t *testing.T) { } }, }, - {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}}, - {name: "time", cqlType: TypeTimestamp.String(), value: []time.Time{time1, time2, time3}}, - {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, - {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, - {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, + {name: "date", cqlType: "date", value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: "time", value: []time.Duration{time1, time2, time3}}, + {name: "smallint", cqlType: "smallint", value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: "tinyint", value: []int8{127, 9, -123}}, + {name: "duration", cqlType: "duration", value: []Duration{duration1, duration2, duration3}}, {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, {name: "vector_vector_set_float", cqlType: "vector, 5>", value: [][][]float32{ {{1, 2}, {2, -1}, {3}, {0}, {-1.3}}, @@ -307,95 +308,33 @@ func TestVector_MissingDimension(t *testing.T) { require.Error(t, err, "expected vector with 3 dimensions, received 4") } -func TestVector_SubTypeParsing(t *testing.T) { +func TestReadUnsignedVInt(t *testing.T) { tests := []struct { - name string - custom string - expected TypeInfo + decodedInt uint64 + encodedVint []byte }{ - {name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}}, - {name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}}, { - name: "udt", - custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)", - expected: UDTTypeInfo{ - NativeType{typ: TypeUDT}, - "gocql_test", - "person", - []UDTField{ - UDTField{"first_name", NativeType{typ: TypeVarchar}}, - UDTField{"last_name", NativeType{typ: TypeVarchar}}, - UDTField{"age", NativeType{typ: TypeInt}}, - }, - }, + decodedInt: 0, + encodedVint: []byte{0}, }, { - name: "tuple", - custom: "org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)", - expected: TupleTypeInfo{ - NativeType{typ: TypeTuple}, - []TypeInfo{ - NativeType{typ: TypeVarchar}, - NativeType{typ: TypeInt}, - NativeType{typ: TypeVarchar}, - }, - }, + decodedInt: 100, + encodedVint: []byte{100}, }, { - name: "vector_vector_inet", - custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", - expected: VectorType{ - NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, - VectorType{ - NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, - NativeType{typ: TypeInet}, - 2, - }, - 3, - }, - }, - { - name: "map_int_vector_text", - custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))", - expected: CollectionType{ - NativeType{typ: TypeMap}, - NativeType{typ: TypeInt}, - VectorType{ - NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, - NativeType{typ: TypeVarchar}, - 10, - }, - }, - }, - { - name: "set_map_vector_text_text", - custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 10),org.apache.cassandra.db.marshal.UTF8Type))", - expected: CollectionType{ - NativeType{typ: TypeSet}, - nil, - CollectionType{ - NativeType{typ: TypeMap}, - VectorType{ - NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, - NativeType{typ: TypeInt}, - 10, - }, - NativeType{typ: TypeVarchar}, - }, - }, + decodedInt: 256000, + encodedVint: []byte{195, 232, 0}, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := newFramer(nil, 0) - f.writeShort(0) - f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) - parsedType := f.readTypeInfo() - require.IsType(t, parsedType, VectorType{}) - vectorType := parsedType.(VectorType) - assertEqual(t, "dimensions", 2, vectorType.Dimensions) - assertDeepEqual(t, "vector", test.expected, vectorType.SubType) + t.Run(fmt.Sprintf("%d", test.decodedInt), func(t *testing.T) { + actual, _, err := readUnsignedVInt(test.encodedVint) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if actual != test.decodedInt { + t.Fatalf("Expected %d, but got %d", test.decodedInt, actual) + } }) } }