diff --git a/transport/cluster.go b/transport/cluster.go index 15f94530..f7a724b5 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -167,14 +167,12 @@ func (c *Cluster) NewControl(ctx context.Context) (*Conn, error) { if err := conn.RegisterEventHandler(ctx, c.handleEvent, c.handledEvents...); err == nil { return conn, nil } else { + conn.Close() errs = append(errs, fmt.Sprintf("%s failed to register for events: %s", conn, err)) } } else { errs = append(errs, fmt.Sprintf("%s failed to connect: %s", addr, err)) } - if conn != nil { - conn.Close() - } } return nil, fmt.Errorf("couldn't open control connection to any known host:\n%s", strings.Join(errs, "\n")) diff --git a/transport/conn.go b/transport/conn.go index 290fefb1..7792f612 100644 --- a/transport/conn.go +++ b/transport/conn.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "syscall" "time" "unicode" @@ -377,6 +378,17 @@ const ( comprBufferSize = 64 * 1024 // 64 Kb ) +/* +Checks if this error indicates that a chosen source port/address cannot be bound. + +This is caused by one of the following: + - The source address is already used by another socket, + - The source address is reserved and the process does not have sufficient privileges to use it. +*/ +func isAddrUnavailableForUseErr(err error) bool { + return errors.Is(err, syscall.EADDRINUSE) || errors.Is(err, syscall.EPERM) +} + // OpenShardConn opens connection mapped to a specific shard on Scylla node. func OpenShardConn(ctx context.Context, addr string, si ShardInfo, cfg ConnConfig) (*Conn, error) { it := ShardPortIterator(si) @@ -385,15 +397,16 @@ func OpenShardConn(ctx context.Context, addr string, si ShardInfo, cfg ConnConfi conn, err := OpenLocalPortConn(ctx, addr, it(), cfg) if err != nil { cfg.Logger.Infof("%s dial error: %s (try %d/%d)", addr, err, i, maxTries) - if conn != nil { - conn.Close() + if isAddrUnavailableForUseErr(err) { + continue } - continue + + return nil, fmt.Errorf("failed to open connection to shard: %w", err) } return conn, nil } - return nil, fmt.Errorf("failed to open connection on shard %d: all local ports are busy", si.Shard) + return nil, fmt.Errorf("failed to open connection on shard %d: all local ports are unavailable for use", si.Shard) } // OpenLocalPortConn opens connection on a given local port. @@ -506,7 +519,8 @@ func WrapConn(ctx context.Context, conn net.Conn, cfg ConnConfig) (*Conn, error) go c.r.loop(ctx) if err := c.init(ctx); err != nil { - return c, err + c.Close() + return nil, err } return c, nil diff --git a/transport/conn_integration_test.go b/transport/conn_integration_test.go index a161ca61..875c740a 100644 --- a/transport/conn_integration_test.go +++ b/transport/conn_integration_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "math/rand" + "net" "os/signal" "strconv" "sync" @@ -278,3 +279,65 @@ func testCompression(ctx context.Context, t *testing.T, c frame.Compression, toS } } } + +func TestConnectedToNonCqlServer(t *testing.T) { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGABRT, syscall.SIGTERM) + defer cancel() + + t.Logf("%+v", testingConnConfig) + testCases := []struct { + name string + response []byte + }{ + { + name: "non-cql response", + response: []byte("0"), + }, + { + name: "non supported cql response", + response: func() []byte { + var buf frame.Buffer + frame := frame.Header{ + Version: frame.CQLv4, + OpCode: frame.OpReady, + } + + frame.WriteTo(&buf) + return buf.Bytes() + }(), + }, + } + + for i := 0; i < len(testCases); i++ { + tc := testCases[i] + t.Run(tc.name, func(t *testing.T) { + server, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + defer server.Close() + go func() { + conn, err := server.Accept() + if err != nil { + t.Log(err) + t.Fail() + return + } + go func(conn net.Conn) { + defer conn.Close() + conn.Write(tc.response) + }(conn) + }() + + addr := server.Addr().String() + conn, err := OpenConn(ctx, addr, nil, testingConnConfig) + if err == nil { + t.Fatal("connecting to non-cql server should fail") + } + t.Log(err) + if conn != nil { + t.Fatal("connecting to non-cql server should return a nil-conn") + } + }) + } +} diff --git a/transport/pool.go b/transport/pool.go index b15ec62a..39480551 100644 --- a/transport/pool.go +++ b/transport/pool.go @@ -141,9 +141,6 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error { conn, err := OpenConn(ctx, host, nil, r.cfg) span.stop() if err != nil { - if conn != nil { - conn.Close() - } return err } @@ -236,7 +233,9 @@ func (r *PoolRefiller) fill(ctx context.Context) { if r.pool.loadConn(i) != nil { continue } - + if ctx.Err() != nil { + return + } si.Shard = uint16(i) span := startSpan() conn, err := OpenShardConn(ctx, r.addr, si, r.cfg) @@ -245,9 +244,6 @@ func (r *PoolRefiller) fill(ctx context.Context) { if r.pool.connObs != nil { r.pool.connObs.OnConnect(ConnectEvent{ConnEvent: ConnEvent{Addr: r.addr, Shard: si.Shard}, span: span, Err: err}) } - if conn != nil { - conn.Close() - } continue } if r.pool.connObs != nil {