diff --git a/CHANGELOG.md b/CHANGELOG.md index cd327073b..872ebdd1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Prevent panic with queries during session init (CASSGO-92) - Return correct values from RowData (CASSGO-95) - Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98) +- Use protocol downgrading approach during protocol negotiation (CASSGO-97) ## [2.0.0] diff --git a/conn.go b/conn.go index 40044565d..fbbed090c 100644 --- a/conn.go +++ b/conn.go @@ -410,7 +410,7 @@ func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atom supported, ok := frame.(*supportedFrame) if !ok { - return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + return NewErrProtocol("Unknown type of response to startup frame: %T (frame=%s)", frame, frame.String()) } return s.startup(ctx, supported.supported, startupCompleted) diff --git a/conn_test.go b/conn_test.go index 60e4a2a8a..ad4e66e54 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1054,6 +1054,9 @@ type newTestServerOpts struct { addr string protocol uint8 recvHook func(*framer) + + customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error + dontFailOnProtocolMismatch bool } func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer { @@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS cancel: cancel, onRecv: nts.recvHook, + + customRequestHandler: nts.customRequestHandler, + dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch, } go srv.closeWatch() @@ -1142,6 +1148,10 @@ type TestServer struct { // onRecv is a hook point for tests, called in receive loop. onRecv func(*framer) + + // customRequestHandler allows overriding the default request handling for testing purposes. + customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error + dontFailOnProtocolMismatch bool } func (srv *TestServer) closeWatch() { @@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() { } go func(conn net.Conn) { + var startupCompleted bool + var useProtoV5 bool + defer conn.Close() for !srv.isClosed() { - framer, err := srv.readFrame(conn) + var reader io.Reader = conn + + if useProtoV5 && startupCompleted { + frame, _, err := readUncompressedSegment(conn) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + srv.errorLocked(err) + return + } + reader = bytes.NewReader(frame) + } + + framer, err := srv.readFrame(reader) if err != nil { if err == io.EOF { return @@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() { srv.onRecv(framer) } - go srv.process(conn, framer) + srv.process(conn, framer, &useProtoV5, &startupCompleted) } }(conn) } @@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) { srv.t.Error(err) } -func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { +func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, startupCompleted *bool) { head := reqFrame.header if head == nil { srv.errorLocked("process frame with a nil header") return } - respFrame := newFramer(nil, reqFrame.proto, GlobalTypes) + respFrame := newFramer(nil, byte(head.version), GlobalTypes) + + if srv.customRequestHandler != nil { + if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil { + srv.errorLocked(err) + return + } + // Dont like this but... + goto finish + } switch head.op { case opStartup: @@ -1412,26 +1448,46 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { respFrame.writeString("not supported") } - respFrame.buf[0] = srv.protocol | 0x80 +finish: + + respFrame.buf[0] |= 0x80 if err := respFrame.finish(); err != nil { srv.errorLocked(err) } - if err := respFrame.writeTo(conn); err != nil { - srv.errorLocked(err) + if *useProtoV5 && *startupCompleted { + segment, err := newUncompressedSegment(respFrame.buf, true) + if err == nil { + _, err = conn.Write(segment) + } + if err != nil { + srv.errorLocked(err) + return + } + } else { + if err := respFrame.writeTo(conn); err != nil { + srv.errorLocked(err) + } + + if reqFrame.header.op == opStartup { + *startupCompleted = true + if head.version == protoVersion5 { + *useProtoV5 = true + } + } } } -func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { +func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) { buf := make([]byte, srv.headerSize) - head, err := readHeader(conn, buf) + head, err := readHeader(reader, buf) if err != nil { return nil, err } framer := newFramer(nil, srv.protocol, GlobalTypes) - err = framer.readFrame(conn, &head) + err = framer.readFrame(reader, &head) if err != nil { return nil, err } @@ -1439,7 +1495,7 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { // should be a request frame if head.version.response() { return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version) - } else if head.version.version() != srv.protocol { + } else if !srv.dontFailOnProtocolMismatch && head.version.version() != srv.protocol { return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version()) } diff --git a/control.go b/control.go index e59acb402..8f09f9003 100644 --- a/control.go +++ b/control.go @@ -32,7 +32,6 @@ import ( "math/rand" "net" "os" - "regexp" "strconv" "sync" "sync/atomic" @@ -202,56 +201,12 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo { return shuffled } -// this is going to be version dependant and a nightmare to maintain :( -var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`) -var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used \(.*\), but USE_BETA flag is unset`) - -func parseProtocolFromError(err error) int { - errStr := err.Error() - - var errProtocol ErrProtocol - if errors.As(err, &errProtocol) { - err = errProtocol.error - } - - // I really wish this had the actual info in the error frame... - matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1) - if len(matches) == 1 { - var protoErr *protocolError - if errors.As(err, &protoErr) { - version := protoErr.frame.Header().version.version() - if version > 0 { - return int(version - 1) - } - } - return 0 - } - - matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1) - if len(matches) != 1 || len(matches[0]) != 2 { - var protoErr *protocolError - if errors.As(err, &protoErr) { - return int(protoErr.frame.Header().version.version()) - } - return 0 - } - - max, err := strconv.Atoi(matches[0][1]) - if err != nil { - return 0 - } - - return max -} - -const highestProtocolVersionSupported = 5 +const highestProtocolVersionSupported = protoVersion5 +const lowestProtocolVersionSupported = protoVersion3 func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) - connCfg := *c.session.connCfg - connCfg.ProtoVersion = highestProtocolVersionSupported - handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a // host successfully which means our attempted protocol version worked @@ -262,26 +217,27 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { var err error for _, host := range hosts { - var conn *Conn - conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler) - if conn != nil { - conn.Close() - } + connCfg := *c.session.connCfg + for proto := highestProtocolVersionSupported; proto >= lowestProtocolVersionSupported; proto-- { + connCfg.ProtoVersion = proto + + var conn *Conn + conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler) + if conn != nil { + conn.Close() + } - if err == nil { - c.session.logger.Debug("Discovered protocol version using host.", - NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) - return connCfg.ProtoVersion, nil - } + if err == nil { + c.session.logger.Debug("Discovered protocol version using host.", + NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress())) + return connCfg.ProtoVersion, nil + } - if proto := parseProtocolFromError(err); proto > 0 { - c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.", - NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) - return proto, nil + c.session.logger.Debug("Failed to connect to the host using protocol version.", + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("protocol_version", connCfg.ProtoVersion), + NewLogFieldError("err", err)) } - - c.session.logger.Debug("Failed to discover protocol version using host.", - NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err)) } return 0, err diff --git a/control_test.go b/control_test.go index 9f83ec955..7d9311a68 100644 --- a/control_test.go +++ b/control_test.go @@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) { } } } - -func TestParseProtocol(t *testing.T) { - tests := [...]struct { - err error - proto int - }{ - { - err: &protocolError{ - frame: errorFrame{ - code: 0x10, - message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4", - }, - }, - proto: 4, - }, - { - err: &protocolError{ - frame: errorFrame{ - frameHeader: frameHeader{ - version: 0x83, - }, - code: 0x10, - message: "Invalid or unsupported protocol version: 5", - }, - }, - proto: 3, - }, - } - - for i, test := range tests { - if proto := parseProtocolFromError(test.err); proto != test.proto { - t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto) - } - } -} diff --git a/frame.go b/frame.go index e86c538c8..467047c10 100644 --- a/frame.go +++ b/frame.go @@ -2370,6 +2370,14 @@ func (f *framer) writeStringMap(m map[string]string) { } } +func (f *framer) writeStringMultiMap(m map[string][]string) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeStringList(v) + } +} + func (f *framer) writeBytesMap(m map[string][]byte) { f.writeShort(uint16(len(m))) for k, v := range m { diff --git a/protocol_negotiation_test.go b/protocol_negotiation_test.go new file mode 100644 index 000000000..e500f66a7 --- /dev/null +++ b/protocol_negotiation_test.go @@ -0,0 +1,179 @@ +//go:build all || unit +// +build all unit + +package gocql + +import ( + "context" + "encoding/binary" + "fmt" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type requestHandlerForProtocolNegotiationTest struct { + supportedProtocolVersions []protoVersion + supportedBetaProtocols []protoVersion +} + +func (r *requestHandlerForProtocolNegotiationTest) supportsBetaProtocol(version protoVersion) bool { + return slices.Contains(r.supportedBetaProtocols, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version protoVersion) bool { + return slices.Contains(r.supportedProtocolVersions, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header *frameHeader) bool { + return header.flags&flagBetaProtocol == flagBetaProtocol +} + +func (r *requestHandlerForProtocolNegotiationTest) createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string { + return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), but USE_BETA flag is unset", version, version) +} + +func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, reqFrame, respFrame *framer) error { + // If a client uses beta protocol, but the USE_BETA flag is not set, we respond with an error + if r.supportsBetaProtocol(reqFrame.header.version) && !r.hasBetaFlag(reqFrame.header) { + respFrame.writeHeader(0, opError, reqFrame.header.stream) + respFrame.writeInt(ErrCodeProtocol) + respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version)) + return nil + } + + // if a client uses an unsupported protocol version, we respond with an error + if !r.supportsProtocol(reqFrame.header.version) { + respFrame.writeHeader(0, opError, reqFrame.header.stream) + respFrame.writeInt(ErrCodeProtocol) + respFrame.writeString(fmt.Sprintf("NEGOTITATION TEST: Unsupported protocol version %d", reqFrame.header.version)) + return nil + } + + stream := reqFrame.header.stream + + switch reqFrame.header.op { + case opStartup, opRegister: + respFrame.writeHeader(0, opReady, stream) + case opOptions: + respFrame.writeHeader(0, opSupported, stream) + var supportedVersionsWithDesc []string + for _, supportedVersion := range r.supportedProtocolVersions { + supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion, supportedVersion)) + } + for _, betaProtocol := range r.supportedBetaProtocols { + supportedVersionsWithDesc = append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol, betaProtocol)) + } + supported := map[string][]string{ + "PROTOCOL_VERSIONS": supportedVersionsWithDesc, + } + respFrame.writeStringMultiMap(supported) + case opQuery: + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindRows) + respFrame.writeInt(int32(flagGlobalTableSpec)) + respFrame.writeInt(1) + respFrame.writeString("system") + respFrame.writeString("local") + respFrame.writeString("rack") + respFrame.writeShort(uint16(TypeVarchar)) + respFrame.writeInt(1) + respFrame.writeInt(int32(len("rack-1"))) + respFrame.writeString("rack-1") + case opPrepare: + // This doesn't really make any sense, but it's enough to test the protocol negotiation + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindPrepared) + // + respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 111)) + if respFrame.proto >= protoVersion5 { + respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222)) + } + // + respFrame.writeInt(0) // + respFrame.writeInt(0) // + if reqFrame.header.version >= protoVersion4 { + respFrame.writeInt(0) // + } + // + respFrame.writeInt(int32(flagGlobalTableSpec)) // + respFrame.writeInt(1) // + // + respFrame.writeString("system") + respFrame.writeString("keyspaces") + // + respFrame.writeString("col0") // + respFrame.writeShort(uint16(TypeBoolean)) // + case opExecute: + // This doesn't really make any sense, but it's enough to test the protocol negotiation + respFrame.writeHeader(0, opResult, stream) + respFrame.writeInt(resultKindRows) + // + respFrame.writeInt(0) // + respFrame.writeInt(0) // + // + respFrame.writeInt(0) + } + + return nil +} + +func TestProtocolNegotiation(t *testing.T) { + testCases := []struct { + name string + supportedVersions []protoVersion + supportedBetaVersions []protoVersion + expectedVersion protoVersion + }{ + { + name: "all supported versions", + supportedVersions: []protoVersion{protoVersion3, protoVersion4, protoVersion5}, + expectedVersion: protoVersion5, + }, + { + name: "v5-beta is supported", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + supportedBetaVersions: []protoVersion{protoVersion5}, + expectedVersion: protoVersion4, + }, + { + name: "v5 is unsupported", + supportedVersions: []protoVersion{protoVersion3, protoVersion4}, + expectedVersion: protoVersion4, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler := &requestHandlerForProtocolNegotiationTest{ + supportedProtocolVersions: tc.supportedVersions, + supportedBetaProtocols: tc.supportedBetaVersions, + } + + srv := newTestServerOpts{ + addr: "127.0.0.1:0", + protocol: 5, + customRequestHandler: handler.handle, + dontFailOnProtocolMismatch: true, + }.newServer(t, context.Background()) + + go srv.serve() + defer srv.Stop() + + cluster := NewCluster(srv.Address) + cluster.Compressor = nil + cluster.ProtoVersion = 0 + cluster.Logger = NewLogger(LogLevelDebug) + cluster.ConnectTimeout = time.Hour + cluster.Timeout = time.Hour + cluster.DisableInitialHostLookup = true + + s, err := cluster.CreateSession() + require.NoError(t, err) + + require.Equal(t, tc.expectedVersion, protoVersion(s.cfg.ProtoVersion)) + }) + } +}