Skip to content

Commit 6064d85

Browse files
committed
Protocol version negotiation doesn't work if server replies with stream id different than 0
Previously, protocol negotiation didn't work properly when C* was responding with stream id different from 0. This patch changes the way protocol negotiation works. Instead of parsing a supported protocol version from C* error response, the driver tries to connect with each supported protocol starting from the latest. Patch by Bohdan Siryk; Reviewed by <> for CASSGO-98
1 parent 0089073 commit 6064d85

File tree

7 files changed

+276
-111
lines changed

7 files changed

+276
-111
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- Prevent panic with queries during session init (CASSGO-92)
1818
- Return correct values from RowData (CASSGO-95)
1919
- Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98)
20+
- Use protocol downgrading approach during protocol negotiation (CASSGO-97)
2021

2122
## [2.0.0]
2223

conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atom
410410

411411
supported, ok := frame.(*supportedFrame)
412412
if !ok {
413-
return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
413+
return NewErrProtocol("Unknown type of response to startup frame: %T (frame=%s)", frame, frame.String())
414414
}
415415

416416
return s.startup(ctx, supported.supported, startupCompleted)

conn_test.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
10541054
addr string
10551055
protocol uint8
10561056
recvHook func(*framer)
1057+
1058+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1059+
dontFailOnProtocolMismatch bool
10571060
}
10581061

10591062
func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer {
@@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS
10781081
cancel: cancel,
10791082

10801083
onRecv: nts.recvHook,
1084+
1085+
customRequestHandler: nts.customRequestHandler,
1086+
dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
10811087
}
10821088

10831089
go srv.closeWatch()
@@ -1142,6 +1148,10 @@ type TestServer struct {
11421148

11431149
// onRecv is a hook point for tests, called in receive loop.
11441150
onRecv func(*framer)
1151+
1152+
// customRequestHandler allows overriding the default request handling for testing purposes.
1153+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1154+
dontFailOnProtocolMismatch bool
11451155
}
11461156

11471157
func (srv *TestServer) closeWatch() {
@@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() {
11621172
}
11631173

11641174
go func(conn net.Conn) {
1175+
var startupCompleted bool
1176+
var useProtoV5 bool
1177+
11651178
defer conn.Close()
11661179
for !srv.isClosed() {
1167-
framer, err := srv.readFrame(conn)
1180+
var reader io.Reader = conn
1181+
1182+
if useProtoV5 && startupCompleted {
1183+
frame, _, err := readUncompressedSegment(conn)
1184+
if err != nil {
1185+
if errors.Is(err, io.EOF) {
1186+
return
1187+
}
1188+
srv.errorLocked(err)
1189+
return
1190+
}
1191+
reader = bytes.NewReader(frame)
1192+
}
1193+
1194+
framer, err := srv.readFrame(reader)
11681195
if err != nil {
11691196
if err == io.EOF {
11701197
return
@@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() {
11771204
srv.onRecv(framer)
11781205
}
11791206

1180-
go srv.process(conn, framer)
1207+
srv.process(conn, framer, &useProtoV5, &startupCompleted)
11811208
}
11821209
}(conn)
11831210
}
@@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) {
12151242
srv.t.Error(err)
12161243
}
12171244

1218-
func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
1245+
func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, startupCompleted *bool) {
12191246
head := reqFrame.header
12201247
if head == nil {
12211248
srv.errorLocked("process frame with a nil header")
12221249
return
12231250
}
1224-
respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
1251+
respFrame := newFramer(nil, byte(head.version), GlobalTypes)
1252+
1253+
if srv.customRequestHandler != nil {
1254+
if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil {
1255+
srv.errorLocked(err)
1256+
return
1257+
}
1258+
// Dont like this but...
1259+
goto finish
1260+
}
12251261

12261262
switch head.op {
12271263
case opStartup:
@@ -1412,34 +1448,54 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
14121448
respFrame.writeString("not supported")
14131449
}
14141450

1415-
respFrame.buf[0] = srv.protocol | 0x80
1451+
finish:
1452+
1453+
respFrame.buf[0] |= 0x80
14161454

14171455
if err := respFrame.finish(); err != nil {
14181456
srv.errorLocked(err)
14191457
}
14201458

1421-
if err := respFrame.writeTo(conn); err != nil {
1422-
srv.errorLocked(err)
1459+
if *useProtoV5 && *startupCompleted {
1460+
segment, err := newUncompressedSegment(respFrame.buf, true)
1461+
if err == nil {
1462+
_, err = conn.Write(segment)
1463+
}
1464+
if err != nil {
1465+
srv.errorLocked(err)
1466+
return
1467+
}
1468+
} else {
1469+
if err := respFrame.writeTo(conn); err != nil {
1470+
srv.errorLocked(err)
1471+
}
1472+
1473+
if reqFrame.header.op == opStartup {
1474+
*startupCompleted = true
1475+
if head.version == protoVersion5 {
1476+
*useProtoV5 = true
1477+
}
1478+
}
14231479
}
14241480
}
14251481

1426-
func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
1482+
func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) {
14271483
buf := make([]byte, srv.headerSize)
1428-
head, err := readHeader(conn, buf)
1484+
head, err := readHeader(reader, buf)
14291485
if err != nil {
14301486
return nil, err
14311487
}
14321488
framer := newFramer(nil, srv.protocol, GlobalTypes)
14331489

1434-
err = framer.readFrame(conn, &head)
1490+
err = framer.readFrame(reader, &head)
14351491
if err != nil {
14361492
return nil, err
14371493
}
14381494

14391495
// should be a request frame
14401496
if head.version.response() {
14411497
return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
1442-
} else if head.version.version() != srv.protocol {
1498+
} else if !srv.dontFailOnProtocolMismatch && head.version.version() != srv.protocol {
14431499
return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
14441500
}
14451501

control.go

Lines changed: 20 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"math/rand"
3333
"net"
3434
"os"
35-
"regexp"
3635
"strconv"
3736
"sync"
3837
"sync/atomic"
@@ -202,56 +201,12 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
202201
return shuffled
203202
}
204203

205-
// this is going to be version dependant and a nightmare to maintain :(
206-
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
207-
var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used \(.*\), but USE_BETA flag is unset`)
208-
209-
func parseProtocolFromError(err error) int {
210-
errStr := err.Error()
211-
212-
var errProtocol ErrProtocol
213-
if errors.As(err, &errProtocol) {
214-
err = errProtocol.error
215-
}
216-
217-
// I really wish this had the actual info in the error frame...
218-
matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1)
219-
if len(matches) == 1 {
220-
var protoErr *protocolError
221-
if errors.As(err, &protoErr) {
222-
version := protoErr.frame.Header().version.version()
223-
if version > 0 {
224-
return int(version - 1)
225-
}
226-
}
227-
return 0
228-
}
229-
230-
matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1)
231-
if len(matches) != 1 || len(matches[0]) != 2 {
232-
var protoErr *protocolError
233-
if errors.As(err, &protoErr) {
234-
return int(protoErr.frame.Header().version.version())
235-
}
236-
return 0
237-
}
238-
239-
max, err := strconv.Atoi(matches[0][1])
240-
if err != nil {
241-
return 0
242-
}
243-
244-
return max
245-
}
246-
247-
const highestProtocolVersionSupported = 5
204+
const highestProtocolVersionSupported = protoVersion5
205+
const lowestProtocolVersionSupported = protoVersion3
248206

249207
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
250208
hosts = shuffleHosts(hosts)
251209

252-
connCfg := *c.session.connCfg
253-
connCfg.ProtoVersion = highestProtocolVersionSupported
254-
255210
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
256211
// we should never get here, but if we do it means we connected to a
257212
// host successfully which means our attempted protocol version worked
@@ -262,26 +217,27 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
262217

263218
var err error
264219
for _, host := range hosts {
265-
var conn *Conn
266-
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
267-
if conn != nil {
268-
conn.Close()
269-
}
220+
connCfg := *c.session.connCfg
221+
for proto := highestProtocolVersionSupported; proto >= lowestProtocolVersionSupported; proto-- {
222+
connCfg.ProtoVersion = proto
223+
224+
var conn *Conn
225+
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
226+
if conn != nil {
227+
conn.Close()
228+
}
270229

271-
if err == nil {
272-
c.session.logger.Debug("Discovered protocol version using host.",
273-
NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
274-
return connCfg.ProtoVersion, nil
275-
}
230+
if err == nil {
231+
c.session.logger.Debug("Discovered protocol version using host.",
232+
NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()))
233+
return connCfg.ProtoVersion, nil
234+
}
276235

277-
if proto := parseProtocolFromError(err); proto > 0 {
278-
c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.",
279-
NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
280-
return proto, nil
236+
c.session.logger.Debug("Failed to connect to the host using protocol version.",
237+
NewLogFieldIP("host_addr", host.ConnectAddress()),
238+
NewLogFieldInt("protocol_version", connCfg.ProtoVersion),
239+
NewLogFieldError("err", err))
281240
}
282-
283-
c.session.logger.Debug("Failed to discover protocol version using host.",
284-
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
285241
}
286242

287243
return 0, err

control_test.go

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) {
5757
}
5858
}
5959
}
60-
61-
func TestParseProtocol(t *testing.T) {
62-
tests := [...]struct {
63-
err error
64-
proto int
65-
}{
66-
{
67-
err: &protocolError{
68-
frame: errorFrame{
69-
code: 0x10,
70-
message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4",
71-
},
72-
},
73-
proto: 4,
74-
},
75-
{
76-
err: &protocolError{
77-
frame: errorFrame{
78-
frameHeader: frameHeader{
79-
version: 0x83,
80-
},
81-
code: 0x10,
82-
message: "Invalid or unsupported protocol version: 5",
83-
},
84-
},
85-
proto: 3,
86-
},
87-
}
88-
89-
for i, test := range tests {
90-
if proto := parseProtocolFromError(test.err); proto != test.proto {
91-
t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto)
92-
}
93-
}
94-
}

frame.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,6 +2370,14 @@ func (f *framer) writeStringMap(m map[string]string) {
23702370
}
23712371
}
23722372

2373+
func (f *framer) writeStringMultiMap(m map[string][]string) {
2374+
f.writeShort(uint16(len(m)))
2375+
for k, v := range m {
2376+
f.writeString(k)
2377+
f.writeStringList(v)
2378+
}
2379+
}
2380+
23732381
func (f *framer) writeBytesMap(m map[string][]byte) {
23742382
f.writeShort(uint16(len(m)))
23752383
for k, v := range m {

0 commit comments

Comments
 (0)