Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Prevent panic with queries during session init (CASSGO-92)
- Return correct values from RowData (CASSGO-95)
- Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98)
- Use protocol downgrading approach during protocol negotiation (CASSGO-97)

## [2.0.0]

Expand Down
2 changes: 1 addition & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atom

supported, ok := frame.(*supportedFrame)
if !ok {
return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
return NewErrProtocol("Unknown type of response to startup frame: %T (frame=%s)", frame, frame.String())
}

return s.startup(ctx, supported.supported, startupCompleted)
Expand Down
78 changes: 67 additions & 11 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
addr string
protocol uint8
recvHook func(*framer)

customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
dontFailOnProtocolMismatch bool
}

func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer {
Expand All @@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS
cancel: cancel,

onRecv: nts.recvHook,

customRequestHandler: nts.customRequestHandler,
dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
}

go srv.closeWatch()
Expand Down Expand Up @@ -1142,6 +1148,10 @@ type TestServer struct {

// onRecv is a hook point for tests, called in receive loop.
onRecv func(*framer)

// customRequestHandler allows overriding the default request handling for testing purposes.
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
dontFailOnProtocolMismatch bool
}

func (srv *TestServer) closeWatch() {
Expand All @@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() {
}

go func(conn net.Conn) {
var startupCompleted bool
var useProtoV5 bool

defer conn.Close()
for !srv.isClosed() {
framer, err := srv.readFrame(conn)
var reader io.Reader = conn

if useProtoV5 && startupCompleted {
frame, _, err := readUncompressedSegment(conn)
if err != nil {
if errors.Is(err, io.EOF) {
return
}
srv.errorLocked(err)
return
}
reader = bytes.NewReader(frame)
}

framer, err := srv.readFrame(reader)
if err != nil {
if err == io.EOF {
return
Expand All @@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() {
srv.onRecv(framer)
}

go srv.process(conn, framer)
srv.process(conn, framer, &useProtoV5, &startupCompleted)
}
}(conn)
}
Expand Down Expand Up @@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) {
srv.t.Error(err)
}

func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, startupCompleted *bool) {
head := reqFrame.header
if head == nil {
srv.errorLocked("process frame with a nil header")
return
}
respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
respFrame := newFramer(nil, byte(head.version), GlobalTypes)

if srv.customRequestHandler != nil {
if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil {
srv.errorLocked(err)
return
}
// Dont like this but...
goto finish
}

switch head.op {
case opStartup:
Expand Down Expand Up @@ -1412,34 +1448,54 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
respFrame.writeString("not supported")
}

respFrame.buf[0] = srv.protocol | 0x80
finish:

respFrame.buf[0] |= 0x80

if err := respFrame.finish(); err != nil {
srv.errorLocked(err)
}

if err := respFrame.writeTo(conn); err != nil {
srv.errorLocked(err)
if *useProtoV5 && *startupCompleted {
segment, err := newUncompressedSegment(respFrame.buf, true)
if err == nil {
_, err = conn.Write(segment)
}
if err != nil {
srv.errorLocked(err)
return
}
} else {
if err := respFrame.writeTo(conn); err != nil {
srv.errorLocked(err)
}

if reqFrame.header.op == opStartup {
*startupCompleted = true
if head.version == protoVersion5 {
*useProtoV5 = true
}
}
}
}

func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) {
buf := make([]byte, srv.headerSize)
head, err := readHeader(conn, buf)
head, err := readHeader(reader, buf)
if err != nil {
return nil, err
}
framer := newFramer(nil, srv.protocol, GlobalTypes)

err = framer.readFrame(conn, &head)
err = framer.readFrame(reader, &head)
if err != nil {
return nil, err
}

// should be a request frame
if head.version.response() {
return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
} else if head.version.version() != srv.protocol {
} else if !srv.dontFailOnProtocolMismatch && head.version.version() != srv.protocol {
return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
}

Expand Down
84 changes: 20 additions & 64 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"math/rand"
"net"
"os"
"regexp"
"strconv"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -202,56 +201,12 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
return shuffled
}

// this is going to be version dependant and a nightmare to maintain :(
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used \(.*\), but USE_BETA flag is unset`)

func parseProtocolFromError(err error) int {
errStr := err.Error()

var errProtocol ErrProtocol
if errors.As(err, &errProtocol) {
err = errProtocol.error
}

// I really wish this had the actual info in the error frame...
matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1)
if len(matches) == 1 {
var protoErr *protocolError
if errors.As(err, &protoErr) {
version := protoErr.frame.Header().version.version()
if version > 0 {
return int(version - 1)
}
}
return 0
}

matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1)
if len(matches) != 1 || len(matches[0]) != 2 {
var protoErr *protocolError
if errors.As(err, &protoErr) {
return int(protoErr.frame.Header().version.version())
}
return 0
}

max, err := strconv.Atoi(matches[0][1])
if err != nil {
return 0
}

return max
}

const highestProtocolVersionSupported = 5
const highestProtocolVersionSupported = protoVersion5
const lowestProtocolVersionSupported = protoVersion3

func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
hosts = shuffleHosts(hosts)

connCfg := *c.session.connCfg
connCfg.ProtoVersion = highestProtocolVersionSupported

handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected to a
// host successfully which means our attempted protocol version worked
Expand All @@ -262,26 +217,27 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {

var err error
for _, host := range hosts {
var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
if conn != nil {
conn.Close()
}
connCfg := *c.session.connCfg
for proto := highestProtocolVersionSupported; proto >= lowestProtocolVersionSupported; proto-- {
connCfg.ProtoVersion = proto

var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
if conn != nil {
conn.Close()
}

if err == nil {
c.session.logger.Debug("Discovered protocol version using host.",
NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
return connCfg.ProtoVersion, nil
}
if err == nil {
c.session.logger.Debug("Discovered protocol version using host.",
NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any particular reason you removed the host_id log field from both calls? Is it because it's not populated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is not populated, so there is really no useful information

return connCfg.ProtoVersion, nil
}

if proto := parseProtocolFromError(err); proto > 0 {
c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.",
NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
return proto, nil
c.session.logger.Debug("Failed to connect to the host using protocol version.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only attempt to reconnect to the same host with a lower protocol version if the error is assumed to be related to unsupported protocol version:

  • t's an error to the first request (OPTIONS)
  • the error type is PROTOCOL_ERROR or SERVER_ERROR - SERVER_ERROR is for old C* versions that reported this error as SERVER_ERROR instead of PROTOCOL_ERROR

I propose we use a custom internal error type (unsupportedProtocolVersionErr for example) just for this case that we can return inside the dial method and we test for it here in discoverProtocol(). Reference code in the java driver here.

Technically the java driver does check for the error string which I didn't know but I still believe we shouldn't do that, checking the above 2 conditions is enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, we don't really have to re-try each time if the error is not protocol-related

NewLogFieldIP("host_addr", host.ConnectAddress()),
NewLogFieldInt("protocol_version", connCfg.ProtoVersion),
NewLogFieldError("err", err))
}

c.session.logger.Debug("Failed to discover protocol version using host.",
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
}

return 0, err
Expand Down
35 changes: 0 additions & 35 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) {
}
}
}

func TestParseProtocol(t *testing.T) {
tests := [...]struct {
err error
proto int
}{
{
err: &protocolError{
frame: errorFrame{
code: 0x10,
message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4",
},
},
proto: 4,
},
{
err: &protocolError{
frame: errorFrame{
frameHeader: frameHeader{
version: 0x83,
},
code: 0x10,
message: "Invalid or unsupported protocol version: 5",
},
},
proto: 3,
},
}

for i, test := range tests {
if proto := parseProtocolFromError(test.err); proto != test.proto {
t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto)
}
}
}
8 changes: 8 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -2370,6 +2370,14 @@ func (f *framer) writeStringMap(m map[string]string) {
}
}

func (f *framer) writeStringMultiMap(m map[string][]string) {
f.writeShort(uint16(len(m)))
for k, v := range m {
f.writeString(k)
f.writeStringList(v)
}
}

func (f *framer) writeBytesMap(m map[string][]byte) {
f.writeShort(uint16(len(m)))
for k, v := range m {
Expand Down
Loading