diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 98fb6b791..e8f1379a6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,7 @@ jobs: go: [ '1.19', '1.20' ] cassandra_version: [ '4.0.8', '4.1.1' ] auth: [ "false" ] - compressor: [ "snappy" ] + compressor: [ "lz4" ] tags: [ "cassandra", "integration", "ccm" ] steps: - uses: actions/checkout@v2 @@ -101,7 +101,7 @@ jobs: ccm status ccm node1 nodetool status - args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." + args="-gocql.timeout=60s -runssl -proto=5 -rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV @@ -127,7 +127,7 @@ jobs: matrix: go: [ '1.19', '1.20' ] cassandra_version: [ '4.0.8' ] - compressor: [ "snappy" ] + compressor: [ "lz4" ] tags: [ "integration" ] steps: @@ -190,7 +190,7 @@ jobs: ccm status ccm node1 nodetool status - args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." + args="-gocql.timeout=60s -runssl -proto=5 -rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." 
echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..1a5b7e1bd 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -35,6 +35,7 @@ import ( "io" "math" "math/big" + "math/rand" "net" "reflect" "strconv" @@ -3288,3 +3289,40 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestLargeSizeQuery(t *testing.T) { + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.large_size_query(id int, text_col text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + defer session.Close() + + longString := randomString(2_000_000) + + err := session.Query("INSERT INTO gocql_test.large_size_query (id, text_col) VALUES (?, ?)", "1", longString).Exec() + if err != nil { + t.Fatal(err) + } + + var result string + err = session.Query("SELECT text_col FROM gocql_test.large_size_query").Scan(&result) + if err != nil { + t.Fatal(err) + } + + assertEqual(t, "result should equal inserted longString", longString, result) +} + +func randomString(n int) string { + source := rand.NewSource(time.Now().UnixMilli()) + r := rand.New(source) + var aplhabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + buf := make([]byte, n) + for i := 0; i < n; i++ { + buf[i] = aplhabet[r.Intn(len(aplhabet))] + } + return string(buf) +} diff --git a/common_test.go b/common_test.go index a5edb03c6..01c42bb52 100644 --- a/common_test.go +++ b/common_test.go @@ -111,6 +111,8 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { switch *flagCompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} + case "lz4": + cluster.Compressor = &LZ4Compressor{} case "": default: panic("invalid compressor: " + *flagCompressTest) diff --git a/compressor.go b/compressor.go index f3d451a9f..c56452154 100644 --- a/compressor.go +++ b/compressor.go @@ -25,13 +25,17 @@ package gocql import ( + 
"encoding/binary" + "fmt" "github.com/golang/snappy" + "github.com/pierrec/lz4/v4" ) type Compressor interface { Name() string Encode(data []byte) ([]byte, error) Decode(data []byte) ([]byte, error) + DecodeSized(data []byte, size uint32) ([]byte, error) } // SnappyCompressor implements the Compressor interface and can be used to @@ -50,3 +54,51 @@ func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { return snappy.Decode(nil, data) } + +func (s SnappyCompressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + buf := make([]byte, size) + return snappy.Decode(buf, data) +} + +type LZ4Compressor struct{} + +func (s LZ4Compressor) Name() string { + return "lz4" +} + +func (s LZ4Compressor) Encode(data []byte) ([]byte, error) { + buf := make([]byte, lz4.CompressBlockBound(len(data)+4)) + var compressor lz4.Compressor + n, err := compressor.CompressBlock(data, buf[4:]) + // According to lz4.CompressBlock doc, it doesn't fail as long as the dst + // buffer length is at least lz4.CompressBlockBound(len(data))) bytes, but + // we check for error anyway just to be thorough. 
+ if err != nil { + return nil, err + } + binary.BigEndian.PutUint32(buf, uint32(len(data))) + return buf[:n+4], nil +} + +func (s LZ4Compressor) Decode(data []byte) ([]byte, error) { + if len(data) < 4 { + return nil, fmt.Errorf("cassandra lz4 block size should be >4, got=%d", len(data)) + } + uncompressedLength := binary.BigEndian.Uint32(data) + if uncompressedLength == 0 { + return nil, nil + } + buf := make([]byte, uncompressedLength) + n, err := lz4.UncompressBlock(data[4:], buf) + return buf[:n], err +} + +func (s LZ4Compressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + buf := make([]byte, size) + _, err := lz4.UncompressBlock(data, buf) + if err != nil { + return nil, err + } + + return buf, nil +} diff --git a/conn.go b/conn.go index 3daca6250..ac81bfb72 100644 --- a/conn.go +++ b/conn.go @@ -26,6 +26,7 @@ package gocql import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -215,6 +216,14 @@ type Conn struct { host *HostInfo isSchemaV2 bool + // Only for proto v5+. + // Indicates if Conn is ready to use Native Protocol V5. + // github.com/apache/cassandra/blob/trunk/doc/native_protocol_v5.spec + // 2.3.1 Initial Handshake + // In order to support both v5 and earlier formats, the v5 framing format is not + // applied to message exchanges before an initial handshake is completed. + connReady bool + session *Session // true if connection close process for the connection started. 
@@ -474,8 +483,12 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ case error: return v case *readyFrame: + // Connection is successfully set up and ready to use Native Protocol v5 + s.conn.connReady = true return nil case *authenticateFrame: + // Connection is successfully set up and ready to use Native Protocol v5 + s.conn.connReady = true return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) @@ -593,8 +606,8 @@ func (c *Conn) serve(ctx context.Context) { c.closeWithError(err) } -func (c *Conn) discardFrame(head frameHeader) error { - _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) +func (c *Conn) discardFrame(r io.Reader, head frameHeader) error { + _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) if err != nil { return err } @@ -660,6 +673,16 @@ func (c *Conn) heartBeat(ctx context.Context) { } func (c *Conn) recv(ctx context.Context) error { + // If native proto v5+ is used and conn is set up, then we should + // unwrap payload body from v5 compressed/uncompressed frame + if c.version > protoVersion4 && c.connReady { + return c.recvProtoV5Frame(ctx) + } + + return c.processFrame(ctx, c) +} + +func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { // not safe for concurrent reads // read a full header, ignore timeouts, as this is being ran in a loop @@ -670,7 +693,7 @@ func (c *Conn) recv(ctx context.Context) error { headStartTime := time.Now() // were just reading headers over and over and copy bodies - head, err := readHeader(c.r, c.headerBuf[:]) + head, err := readHeader(r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err @@ -694,7 +717,7 @@ func (c *Conn) recv(ctx context.Context) error { } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + if err := 
framer.readFrame(r, &head); err != nil { return err } go c.session.handleEvent(framer) @@ -727,14 +750,14 @@ func (c *Conn) recv(ctx context.Context) error { c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) - return c.discardFrame(head) + return c.discardFrame(r, head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } framer := newFramer(c.compressor, c.version) - err = framer.readFrame(c, &head) + err = framer.readFrame(r, &head) if err != nil { // only net errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. @@ -777,6 +800,48 @@ func (c *Conn) handleTimeout() { } } +func (c *Conn) recvProtoV5Frame(ctx context.Context) error { + var ( + payload []byte + isSelfContained bool + err error + ) + + // Read frame based on compression + if c.compressor != nil { + payload, isSelfContained, err = readCompressedFrame(c.r, c.compressor) + } else { + payload, isSelfContained, err = readUncompressedFrame(c.r) + } + if err != nil { + return err + } + + if isSelfContained { + // TODO handle case when there are more than 1 envelop inside the frame + return c.processFrame(ctx, bytes.NewBuffer(payload)) + } + + head, err := readHeader(bytes.NewBuffer(payload), c.headerBuf[:]) + if err != nil { + return err + } + + const envelopeHeaderLength = 9 + buf := bytes.NewBuffer(make([]byte, 0, head.length+envelopeHeaderLength)) + buf.Write(payload) + + // Computing how many bytes of message left to read + bytesToRead := head.length - len(payload) + envelopeHeaderLength + + err = c.recvLastsFrames(buf, bytesToRead) + if err != nil { + return err + } + + return c.processFrame(ctx, buf) +} + type callReq struct { // resp will receive the frame that was sent as a response to this stream. 
// preparedStatment holds what is kept about a statement after it has been
// prepared on a connection.
type preparedStatment struct {
	// id is a copy of the prepared statement ID returned by the server.
	id []byte
	// metadataID is a copy of the ID carried by the prepared metadata. It is
	// only parsed for native protocol versions above 4 and is sent back
	// with execute requests on those versions.
	metadataID []byte
	// request is the metadata parsed from the prepare response.
	request preparedMetadata
	// response is the result metadata associated with the statement.
	response resultMetadata
}
request: x.reqMeta, @@ -1394,9 +1483,10 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) frame = &writeExecuteFrame{ - preparedID: info.id, - params: params, - customPayload: qry.customPayload, + preparedID: info.id, + preparedMetadataID: info.metadataID, + params: params, + customPayload: qry.customPayload, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata @@ -1756,6 +1846,32 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) } +// recvLastsFrames reads proto v5 frames from Conn.r and writes decoded payload to dst. +// It reads data until the bytesToRead is reached. +// If Conn.compressor is not nil, it processes Compressed Format frames. +func (c *Conn) recvLastsFrames(dst *bytes.Buffer, bytesToRead int) error { + var read int + var segment []byte + var err error + for read != bytesToRead { + // Read frame based on compression + if c.compressor != nil { + segment, _, err = readCompressedFrame(c.r, c.compressor) + } else { + segment, _, err = readUncompressedFrame(c.r) + } + if err != nil { + return fmt.Errorf("gocql: failed to read non self-contained frame: %w", err) + } + + // Write the segment to the destination writer + n, _ := dst.Write(segment) + read += n + } + + return nil +} + var ( ErrQueryArgLength = errors.New("gocql: query argument length mismatch") ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") diff --git a/control.go b/control.go index b30b44ea3..a2ce62a5f 100644 --- a/control.go +++ b/control.go @@ -216,7 +216,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) connCfg := *c.session.connCfg - connCfg.ProtoVersion = 4 // TODO: define maxProtocol + connCfg.ProtoVersion = 5 // TODO: define 
var (
	// Initial CRC32 bytes: 0xFA, 0x2D, 0x55, 0xCA
	initialCRC32Bytes = []byte{0xfa, 0x2d, 0x55, 0xca}

	// initialCRC32 is the CRC32 value after consuming initialCRC32Bytes.
	// It is computed once so that ChecksumIEEE does not have to allocate a
	// hash and re-process the prefix on every call.
	initialCRC32 = crc32.ChecksumIEEE(initialCRC32Bytes)
)

// ChecksumIEEE calculates the CRC32 (IEEE polynomial) checksum of
// initialCRC32Bytes followed by the given byte slice.
func ChecksumIEEE(b []byte) uint32 {
	return crc32.Update(initialCRC32, crc32.IEEETable, b)
}

const (
	crc24Init = 0x875060  // Initial value for CRC24 calculation
	crc24Poly = 0x1974F0B // Polynomial for CRC24 calculation
)

// KoopmanChecksum calculates the CRC24 checksum using the Koopman polynomial.
// The result always fits in the low 24 bits of the returned value.
func KoopmanChecksum(buf []byte) uint32 {
	// uint32 is wide enough: the register never exceeds 25 bits, and bit 24
	// is cleared again by the polynomial XOR right after it is set.
	crc := uint32(crc24Init)
	for _, b := range buf {
		crc ^= uint32(b) << 16 // Fold the next byte into the top of the register

		for i := 0; i < 8; i++ { // Process each bit in the byte
			crc <<= 1
			if crc&0x1000000 != 0 { // Bit 24 set: reduce by the polynomial
				crc ^= crc24Poly
			}
		}
	}

	return crc
}
preparedID []byte + preparedMetadataID []byte + params queryParams // v4+ customPayload map[string][]byte @@ -1611,16 +1621,21 @@ func (e *writeExecuteFrame) String() string { } func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) + return fr.writeExecuteFrame(streamID, e.preparedID, e.preparedMetadataID, &e.params, &e.customPayload) } -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { +func (f *framer) writeExecuteFrame(streamID int, preparedID, preparedMetadataID []byte, params *queryParams, customPayload *map[string][]byte) error { if len(*customPayload) > 0 { f.payload() } f.writeHeader(f.flags, opExecute, streamID) f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) + + if f.proto > protoVersion4 { + f.writeShortBytes(preparedMetadataID) + } + if f.proto > protoVersion1 { f.writeQueryParams(params) } else { @@ -2070,3 +2085,251 @@ func (f *framer) writeBytesMap(m map[string][]byte) { f.writeBytes(v) } } + +func (f *framer) prepareModernLayout() error { + // Ensure protocol version is V5 or higher + if f.proto < protoVersion5 { + panic("Modern layout is not supported with version V4 or less") + } + + selfContained := true + + var ( + adjustedBuf []byte + tempBuf []byte + err error + ) + + // Process the buffer in chunks if it exceeds the max payload size + for len(f.buf) > maxPayloadSize { + if f.compres != nil { + tempBuf, err = newCompressedFrame(f.buf[:maxPayloadSize], false, f.compres) + } else { + tempBuf, err = newUncompressedFrame(f.buf[:maxPayloadSize], false) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) 
+ f.buf = f.buf[maxPayloadSize:] + selfContained = false + } + + // Process the remaining buffer + if f.compres != nil { + tempBuf, err = newCompressedFrame(f.buf, selfContained, f.compres) + } else { + tempBuf, err = newUncompressedFrame(f.buf, selfContained) + } + if err != nil { + return err + } + + adjustedBuf = append(adjustedBuf, tempBuf...) + f.buf = adjustedBuf + + return nil +} + +func readUncompressedFrame(r io.Reader) ([]byte, bool, error) { + const headerSize = 6 + header := [headerSize + 1]byte{} + + // Read the frame header + if _, err := io.ReadFull(r, header[:headerSize]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame, err: %w", err) + } + + // Compute and verify the header CRC24 + computedHeaderCRC24 := KoopmanChecksum(header[:3]) + readHeaderCRC24 := binary.LittleEndian.Uint32(header[3:]) & 0xFFFFFF + if computedHeaderCRC24 != readHeaderCRC24 { + return nil, false, fmt.Errorf("gocql: header crc24 mismatch, computed: %d, got: %d", computedHeaderCRC24, readHeaderCRC24) + } + + // Extract the payload length and self-contained flag + headerInt := binary.LittleEndian.Uint32(header[:4]) + payloadLen := int(headerInt & 0x1FFFF) + isSelfContained := (headerInt & (1 << 17)) != 0 + + // Read the payload + payload := make([]byte, payloadLen) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame payload, err: %w", err) + } + + // Read and verify the payload CRC32 + if _, err := io.ReadFull(r, header[:4]); err != nil { + return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err) + } + + computedPayloadCRC32 := ChecksumIEEE(payload) + readPayloadCRC32 := binary.LittleEndian.Uint32(header[:4]) + if computedPayloadCRC32 != readPayloadCRC32 { + return nil, false, fmt.Errorf("gocql: payload crc32 mismatch, computed: %d, got: %d", computedPayloadCRC32, readPayloadCRC32) + } + + return payload, isSelfContained, nil +} + +const 
maxPayloadSize = 128*1024 - 1 + +func newUncompressedFrame(payload []byte, isSelfContained bool) ([]byte, error) { + const ( + headerSize = 6 + selfContainedBit = 1 << 17 + ) + + payloadLen := len(payload) + if payloadLen > maxPayloadSize { + return nil, fmt.Errorf("payload length (%d) exceeds maximum size of 128 KiB", payloadLen) + } + + header := make([]byte, headerSize) + + // First 3 bytes: payload length and self-contained flag + headerInt := uint32(payloadLen) & 0x1FFFF + if isSelfContained { + headerInt |= selfContainedBit // Set the self-contained flag + } + + // Encode the first 3 bytes as a single little-endian integer + header[0] = byte(headerInt) + header[1] = byte(headerInt >> 8) + header[2] = byte(headerInt >> 16) + + // Calculate CRC24 for the first 3 bytes of the header + crc := KoopmanChecksum(header[:3]) + + // Encode CRC24 into the next 3 bytes of the header + header[3] = byte(crc) + header[4] = byte(crc >> 8) + header[5] = byte(crc >> 16) + + // Create the frame + frameSize := headerSize + payloadLen + 4 // 4 bytes for CRC32 + frame := make([]byte, frameSize) + copy(frame, header) // Copy the header to the frame + copy(frame[headerSize:], payload) // Copy the payload to the frame + + // Calculate CRC32 for the payload + payloadCRC32 := ChecksumIEEE(payload) + binary.LittleEndian.PutUint32(frame[headerSize+payloadLen:], payloadCRC32) + + return frame, nil +} + +func newCompressedFrame(uncompressedPayload []byte, isSelfContained bool, compressor Compressor) ([]byte, error) { + uncompressedLen := len(uncompressedPayload) + if uncompressedLen > maxPayloadSize { + return nil, fmt.Errorf("uncompressed compressed payload length exceedes max size of frame payload %d/%d", uncompressedLen, maxPayloadSize) + } + + compressedPayload, err := compressor.Encode(uncompressedPayload) + if err != nil { + return nil, err + } + + // Skip the first 4 bytes because the size of the uncompressed payload is written in the frame header, not in the + // body of the 
compressed envelope + compressedPayload = compressedPayload[4:] + + compressedLen := len(compressedPayload) + + // Compression is not worth it + if uncompressedLen < compressedLen { + // native_protocol_v5.spec + // 2.2 + // An uncompressed length of 0 signals that the compressed payload + // should be used as-is and not decompressed. + compressedPayload = uncompressedPayload + compressedLen = uncompressedLen + uncompressedLen = 0 + } + + // Combine compressed and uncompressed lengths and set the self-contained flag if needed + combined := uint64(compressedLen) | uint64(uncompressedLen)<<17 + if isSelfContained { + combined |= 1 << 34 + } + + var headerBuf [8]byte + + // Write the combined value into the header buffer + binary.LittleEndian.PutUint64(headerBuf[:], combined) + + // Create a buffer with enough capacity to hold the header, compressed payload, and checksums + buf := bytes.NewBuffer(make([]byte, 0, 8+compressedLen+4)) + + // Write the first 5 bytes of the header (compressed and uncompressed sizes) + buf.Write(headerBuf[:5]) + + // Compute and write the CRC24 checksum of the first 5 bytes + headerChecksum := KoopmanChecksum(headerBuf[:5]) + binary.LittleEndian.PutUint32(headerBuf[:], headerChecksum) + buf.Write(headerBuf[:3]) + buf.Write(compressedPayload) + + // Compute and write the CRC32 checksum of the payload + payloadChecksum := ChecksumIEEE(compressedPayload) + binary.LittleEndian.PutUint32(headerBuf[:], payloadChecksum) + buf.Write(headerBuf[:4]) + + return buf.Bytes(), nil +} + +func readCompressedFrame(r io.Reader, compressor Compressor) ([]byte, bool, error) { + var ( + headerBuf [8]byte + err error + ) + + if _, err = io.ReadFull(r, headerBuf[:]); err != nil { + return nil, false, err + } + + // Reading checksum from frame header + readHeaderChecksum := uint32(headerBuf[5]) | uint32(headerBuf[6])<<8 | uint32(headerBuf[7])<<16 + if computedHeaderChecksum := KoopmanChecksum(headerBuf[:5]); computedHeaderChecksum != readHeaderChecksum { + return 
nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, read: %d, computed: %d", readHeaderChecksum, computedHeaderChecksum) + } + + // First 17 bits - payload size after compression + compressedLen := uint32(headerBuf[0]) | uint32(headerBuf[1])<<8 | uint32(headerBuf[2]&0x1)<<16 + + // The next 17 bits - payload size before compression + uncompressedLen := (uint32(headerBuf[2]) >> 1) | uint32(headerBuf[3])<<7 | uint32(headerBuf[4]&0b11)<<15 + + // Self-contained flag + selfContained := (headerBuf[4] & 0b100) != 0 + + compressedPayload := make([]byte, compressedLen) + if _, err = io.ReadFull(r, compressedPayload); err != nil { + return nil, false, err + } + + if _, err = io.ReadFull(r, headerBuf[:4]); err != nil { + return nil, false, err + } + + // Ensuring if payload checksum matches + readPayloadChecksum := binary.LittleEndian.Uint32(headerBuf[:4]) + if computedPayloadChecksum := ChecksumIEEE(compressedPayload); readPayloadChecksum != computedPayloadChecksum { + return nil, false, fmt.Errorf("gocql: crc32 mismatch in payload, read: %d, computed: %d", readPayloadChecksum, computedPayloadChecksum) + } + + var uncompressedPayload []byte + if uncompressedLen > 0 { + if uncompressedPayload, err = compressor.DecodeSized(compressedPayload, uncompressedLen); err != nil { + return nil, false, err + } + if uint32(len(uncompressedPayload)) != uncompressedLen { + return nil, false, fmt.Errorf("gocql: length mismatch after payload decompression, got %d, expected %d", len(uncompressedPayload), uncompressedLen) + } + } else { + uncompressedPayload = compressedPayload + } + + return uncompressedPayload, selfContained, nil +} diff --git a/go.mod b/go.mod index 0aea881ec..ef0253c7d 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,8 @@ require ( github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/pierrec/lz4/v4 v4.1.21 // indirect 
+ github.com/stretchr/testify v1.9.0 // indirect gopkg.in/inf.v0 v0.9.1 ) diff --git a/go.sum b/go.sum index 2e3892bcb..27cc59371 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,9 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= @@ -13,10 +14,23 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= +github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lz4/lz4.go b/lz4/lz4.go index 049fdc0bb..eb8329ee6 100644 --- a/lz4/lz4.go +++ b/lz4/lz4.go @@ -73,3 +73,13 @@ func (s LZ4Compressor) Decode(data []byte) ([]byte, error) { n, err := lz4.UncompressBlock(data[4:], buf) return buf[:n], err } + +func (s LZ4Compressor) DecodeSized(data []byte, size uint32) ([]byte, error) { + buf := make([]byte, size) + _, err := lz4.UncompressBlock(data, buf) + if err != nil { + return nil, err + 
} + + return buf, nil +}