Skip to content

Commit ab314c0

Browse files
committed
impl of compressed frames
1 parent 85b9598 commit ab314c0

File tree

7 files changed

+197
-32
lines changed

7 files changed

+197
-32
lines changed

cassandra_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"errors"
1010
"fmt"
11+
"github.com/gocql/gocql/lz4"
1112
"io"
1213
"math"
1314
"math/big"
@@ -3373,3 +3374,35 @@ func randomString(n int) string {
33733374
}
33743375
return string(buf)
33753376
}
3377+
3378+
func TestCompressedQuery(t *testing.T) {
3379+
session := createSession(t, func(config *ClusterConfig) {
3380+
config.ProtoVersion = 5
3381+
config.Compressor = lz4.LZ4Compressor{}
3382+
config.Timeout = time.Hour
3383+
config.ConnectTimeout = time.Hour
3384+
config.WriteTimeout = time.Hour
3385+
})
3386+
defer session.Close()
3387+
3388+
if err := createTable(session, "CREATE TABLE gocql_test.native_v5_query_compressed(id int, text_col text, PRIMARY KEY (id))"); err != nil {
3389+
t.Fatal(err)
3390+
}
3391+
3392+
defer session.Close()
3393+
3394+
str := randomString(20)
3395+
3396+
err := session.Query("INSERT INTO gocql_test.native_v5_query_compressed (id, text_col) VALUES (?, ?)", "1", str).Exec()
3397+
if err != nil {
3398+
t.Fatal(err)
3399+
}
3400+
3401+
var result string
3402+
err = session.Query("SELECT text_col FROM gocql_test.native_v5_query_compressed").Scan(&result)
3403+
if err != nil {
3404+
t.Fatal(err)
3405+
}
3406+
3407+
assertEqual(t, "result should equal inserted str", str, result)
3408+
}

common_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
105105
// TODO: tb.Helper()
106106
c := *cluster
107107
c.Keyspace = "system"
108-
c.Timeout = 30 * time.Second
108+
c.Timeout = 30 * time.Hour
109109
session, err := c.CreateSession()
110110
if err != nil {
111111
panic(err)

compressor.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ type Compressor interface {
88
Name() string
99
Encode(data []byte) ([]byte, error)
1010
Decode(data []byte) ([]byte, error)
11+
DecodeSized(data []byte, size uint32) ([]byte, error)
1112
}
1213

1314
// SnappyCompressor implements the Compressor interface and can be used to
@@ -26,3 +27,8 @@ func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
2627
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
2728
return snappy.Decode(nil, data)
2829
}
30+
31+
func (s SnappyCompressor) DecodeSized(data []byte, size uint32) ([]byte, error) {
32+
buf := make([]byte, size)
33+
return snappy.Decode(buf, data)
34+
}

conn.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ func (c *Conn) recvV5Frame(ctx context.Context) error {
591591
var err error
592592

593593
if c.compressor != nil {
594-
// TODO implement reading of compressed frames
594+
payload, isSelfContained, err = readCompressedFrame(c.r, c.compressor)
595595
} else {
596596
payload, isSelfContained, err = readUncompressedFrame(c.r)
597597
}
@@ -1815,8 +1815,14 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
18151815

18161816
func (c *Conn) recvMultiFrame(ctx context.Context, src io.Writer, bytesToRead int) error {
18171817
var read int
1818+
var segment []byte
1819+
var err error
18181820
for read != bytesToRead {
1819-
segment, _, err := readUncompressedFrame(c.r)
1821+
if c.compressor != nil {
1822+
segment, _, err = readCompressedFrame(c.r, c.compressor)
1823+
} else {
1824+
segment, _, err = readUncompressedFrame(c.r)
1825+
}
18201826
if err != nil {
18211827
return fmt.Errorf("failed to read multi-frame frame: %w", err)
18221828
}

crc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ var (
77
initialCRC32Bytes = []byte{0xfa, 0x2d, 0x55, 0xca}
88
)
99

10-
func cassandraCrc32(b []byte) uint32 {
10+
func Checksum(b []byte) uint32 {
1111
crc := crc32.NewIEEE()
1212
crc.Reset()
1313
crc.Write(initialCRC32Bytes)
@@ -20,7 +20,7 @@ const (
2020
crc24Poly = 0x1974F0B
2121
)
2222

23-
func cassandraCrc24(buf []byte) uint32 {
23+
func KoopmanChecksum(buf []byte) uint32 {
2424
crc := crc24Init
2525
for _, b := range buf {
2626
crc ^= int(b) << 16

frame.go

Lines changed: 137 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error {
507507
return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err)
508508
}
509509

510-
if head.flags&flagCompress == flagCompress {
510+
if f.proto < protoVersion5 && head.flags&flagCompress == flagCompress {
511511
if f.compres == nil {
512512
return NewErrProtocol("no compressor available with compressed frame body")
513513
}
@@ -523,21 +523,6 @@ func (f *framer) readFrame(r io.Reader, head *frameHeader) error {
523523
return nil
524524
}
525525

526-
func (f *framer) readMultiFrame(r io.Reader) error {
527-
buf, _, err := readUncompressedFrame(r)
528-
if err != nil {
529-
return err
530-
}
531-
532-
n, err := io.ReadFull(bytes.NewBuffer(buf), f.buf[f.read:])
533-
if err != nil {
534-
return fmt.Errorf("unable to read multi frame body: read %d/%d bytes: %v", n, f.read, err)
535-
}
536-
537-
f.read += n
538-
return nil
539-
}
540-
541526
func (f *framer) parseFrame() (frame frame, err error) {
542527
defer func() {
543528
if r := recover(); r != nil {
@@ -2091,37 +2076,49 @@ func (f *framer) prepareModernLayout() error {
20912076
selfContained := true
20922077

20932078
var adjustedBuf []byte
2079+
var tempBuf []byte
2080+
var err error
2081+
20942082
for len(f.buf) > maxPayloadSize {
2095-
frame, err := newUncompressedFrame(f.buf[:maxPayloadSize], false)
2083+
if f.compres != nil {
2084+
tempBuf, err = newCompressedFrame(f.buf[:maxPayloadSize], false, f.compres)
2085+
} else {
2086+
tempBuf, err = newUncompressedFrame(f.buf[:maxPayloadSize], false)
2087+
}
20962088
if err != nil {
20972089
return err
20982090
}
20992091

2100-
adjustedBuf = append(adjustedBuf, frame...)
2092+
adjustedBuf = append(adjustedBuf, tempBuf...)
21012093
f.buf = f.buf[maxPayloadSize:]
21022094
selfContained = false
21032095
}
21042096

2105-
frame, err := newUncompressedFrame(f.buf, selfContained)
2097+
if f.compres != nil {
2098+
tempBuf, err = newCompressedFrame(f.buf, selfContained, f.compres)
2099+
} else {
2100+
tempBuf, err = newUncompressedFrame(f.buf, selfContained)
2101+
}
21062102
if err != nil {
21072103
return err
21082104
}
2109-
adjustedBuf = append(adjustedBuf, frame...)
2105+
2106+
adjustedBuf = append(adjustedBuf, tempBuf...)
21102107
f.buf = adjustedBuf
21112108

21122109
return nil
21132110
}
21142111

21152112
func readUncompressedFrame(r io.Reader) ([]byte, bool, error) {
2116-
const uncompressedFrameHeaderSize = 6
2117-
header := [uncompressedFrameHeaderSize + 1]byte{}
2113+
const headerSize = 6
2114+
header := [headerSize + 1]byte{}
21182115

2119-
_, err := io.ReadFull(r, header[:uncompressedFrameHeaderSize])
2116+
_, err := io.ReadFull(r, header[:headerSize])
21202117
if err != nil {
21212118
return nil, false, fmt.Errorf("gocql: failed to read uncompressed frame, err: %w", err)
21222119
}
21232120

2124-
computedHeaderCRC24 := cassandraCrc24(header[:3])
2121+
computedHeaderCRC24 := KoopmanChecksum(header[:3])
21252122
readHeaderCRC24 := binary.LittleEndian.Uint32(header[3:]) & 0xFFFFFF
21262123
if computedHeaderCRC24 != readHeaderCRC24 {
21272124
return nil, false, fmt.Errorf("gocql: header crc24 mismatch, computed: %d, got: %d", computedHeaderCRC24, readHeaderCRC24)
@@ -2144,7 +2141,7 @@ func readUncompressedFrame(r io.Reader) ([]byte, bool, error) {
21442141
return nil, false, fmt.Errorf("gocql: failed to read payload crc32, err: %w", err)
21452142
}
21462143

2147-
computedPayloadCRC32 := cassandraCrc32(payload)
2144+
computedPayloadCRC32 := Checksum(payload)
21482145
readPayloadCRC32 := binary.LittleEndian.Uint32(header[:4])
21492146
if computedPayloadCRC32 != readPayloadCRC32 {
21502147
return nil, false, fmt.Errorf("gocql: payload crc32 mismatch, computed: %d, got: %d", computedPayloadCRC32, readPayloadCRC32)
@@ -2181,7 +2178,7 @@ func newUncompressedFrame(payload []byte, isSelfContained bool) ([]byte, error)
21812178
header[2] = byte(headerInt >> 16)
21822179

21832180
// Calculate CRC24 for the first 3 bytes of the header
2184-
crc := cassandraCrc24(header[:3])
2181+
crc := KoopmanChecksum(header[:3])
21852182

21862183
// Encode CRC24 into the next 3 bytes of the header
21872184
header[3] = byte(crc)
@@ -2195,8 +2192,121 @@ func newUncompressedFrame(payload []byte, isSelfContained bool) ([]byte, error)
21952192
copy(frame[headerSize:], payload)
21962193

21972194
// Calculate CRC32 for the payload
2198-
payloadCRC32 := cassandraCrc32(payload)
2195+
payloadCRC32 := Checksum(payload)
21992196
binary.LittleEndian.PutUint32(frame[headerSize+payloadLen:], payloadCRC32)
22002197

22012198
return frame, nil
22022199
}
2200+
2201+
func newCompressedFrame(uncompressedPayload []byte, isSelfContained bool, compressor Compressor) ([]byte, error) {
2202+
compressedPayload, err := compressor.Encode(uncompressedPayload)
2203+
if err != nil {
2204+
return nil, err
2205+
}
2206+
// skipping first 4 bytes because size of uncompressed payload now is written in frame header,
2207+
// not in the body of compressed envelope
2208+
compressedPayload = compressedPayload[4:]
2209+
2210+
compressedLen := len(compressedPayload)
2211+
uncompressedLen := len(uncompressedPayload)
2212+
2213+
if compressedLen > maxPayloadSize {
2214+
return nil, fmt.Errorf("compressed payload length exceedes max size of frame payload %d/%d", compressedLen, maxPayloadSize)
2215+
}
2216+
2217+
if uncompressedLen > maxPayloadSize {
2218+
return nil, fmt.Errorf("uncompressed compressed payload length exceedes max size of frame payload %d/%d", uncompressedLen, maxPayloadSize)
2219+
}
2220+
2221+
combined := uint64(compressedLen) | uint64(uncompressedLen)<<17
2222+
if isSelfContained {
2223+
combined |= 1 << 34
2224+
}
2225+
2226+
var headerBuf [8]byte
2227+
2228+
binary.LittleEndian.PutUint64(headerBuf[:], combined)
2229+
2230+
// 8 - size of header, 4 - size of crc32 for payload
2231+
buf := bytes.NewBuffer(make([]byte, 0, 8+compressedLen+4))
2232+
2233+
// writing compressed and uncompressed sizes
2234+
buf.Write(headerBuf[:5])
2235+
2236+
// writing crc24 of first 5 bytes
2237+
headerChecksum := KoopmanChecksum(headerBuf[:5])
2238+
binary.LittleEndian.PutUint32(headerBuf[:], headerChecksum)
2239+
buf.Write(headerBuf[:3])
2240+
2241+
// writing compressed payload
2242+
buf.Write(compressedPayload)
2243+
2244+
// writing checksum of payload
2245+
payloadChecksum := Checksum(compressedPayload)
2246+
binary.LittleEndian.PutUint32(headerBuf[:], payloadChecksum)
2247+
buf.Write(headerBuf[:4])
2248+
2249+
return buf.Bytes(), err
2250+
}
2251+
2252+
func readCompressedFrame(r io.Reader, compressor Compressor) ([]byte, bool, error) {
2253+
var headerBuf [8]byte
2254+
_, err := io.ReadFull(r, headerBuf[:])
2255+
if err != nil {
2256+
return nil, false, err
2257+
}
2258+
2259+
// reading checksum from frame header
2260+
readHeaderChecksum := uint32(headerBuf[5]) | uint32(headerBuf[6])<<8 | uint32(headerBuf[7])<<16
2261+
computedHeaderChecksum := KoopmanChecksum(headerBuf[:5])
2262+
if computedHeaderChecksum != readHeaderChecksum {
2263+
return nil, false, fmt.Errorf("gocql: crc24 mismatch in frame header, read: %d, computed: %d", readHeaderChecksum, computedHeaderChecksum)
2264+
}
2265+
2266+
// first 17 bits - payload size after compression
2267+
compressedLen := uint32(headerBuf[0]) |
2268+
uint32(headerBuf[1])<<8 |
2269+
uint32(headerBuf[2]&0x1)<<16
2270+
2271+
// the next 17 bits - payload size before compression
2272+
uncompressedLen := (uint32(headerBuf[2]) >> 1) |
2273+
uint32(headerBuf[3])<<7 |
2274+
uint32(headerBuf[4]&0b11)<<15
2275+
2276+
// self-contained flag
2277+
selfContained := (headerBuf[4] & 0b100) != 0
2278+
2279+
compressedPayload := make([]byte, compressedLen)
2280+
_, err = io.ReadFull(r, compressedPayload)
2281+
if err != nil {
2282+
return nil, false, err
2283+
}
2284+
2285+
_, err = io.ReadFull(r, headerBuf[:4])
2286+
if err != nil {
2287+
return nil, false, err
2288+
}
2289+
2290+
// ensuring if payload checksum matches
2291+
readPayloadChecksum := binary.LittleEndian.Uint32(headerBuf[:4])
2292+
computedPayloadChecksum := Checksum(compressedPayload)
2293+
if readPayloadChecksum != computedPayloadChecksum {
2294+
return nil, false, fmt.Errorf("gocql: crc32 mismatch in payload, read: %d, computed: %d", readPayloadChecksum, computedPayloadChecksum)
2295+
}
2296+
2297+
var uncompressedPayload []byte
2298+
if uncompressedLen > 0 {
2299+
uncompressedPayload, err = compressor.DecodeSized(compressedPayload, uncompressedLen)
2300+
if err != nil {
2301+
return nil, false, err
2302+
}
2303+
2304+
if uint32(len(uncompressedPayload)) != uncompressedLen {
2305+
return nil, false, fmt.Errorf("gocql: length mismatch after payload decompression, got %d, read %d", len(uncompressedPayload), uncompressedLen)
2306+
}
2307+
} else {
2308+
uncompressedPayload = compressedPayload
2309+
}
2310+
2311+
return uncompressedPayload, selfContained, nil
2312+
}

lz4/lz4.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,13 @@ func (s LZ4Compressor) Decode(data []byte) ([]byte, error) {
4949
n, err := lz4.UncompressBlock(data[4:], buf)
5050
return buf[:n], err
5151
}
52+
53+
func (s LZ4Compressor) DecodeSized(data []byte, size uint32) ([]byte, error) {
54+
buf := make([]byte, size)
55+
_, err := lz4.UncompressBlock(data, buf)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
return buf, nil
61+
}

0 commit comments

Comments
 (0)