Skip to content
This repository was archived by the owner on Jul 21, 2025. It is now read-only.

Commit a443dbf

Browse files
committed
transport: WrapConn, close connection when initialization fails
WrapConn didn't close connReader and connWriter loops, leaving this responsibility to its callers, potentially leading to goroutine leaks, this is mitigated by closing those loops and never returning a non-nil Conn pointer when WrapConn fails. Fixes #289
1 parent 5369a73 commit a443dbf

File tree

4 files changed

+66
-13
lines changed

4 files changed

+66
-13
lines changed

transport/cluster.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,12 @@ func (c *Cluster) NewControl(ctx context.Context) (*Conn, error) {
167167
if err := conn.RegisterEventHandler(ctx, c.handleEvent, c.handledEvents...); err == nil {
168168
return conn, nil
169169
} else {
170+
conn.Close()
170171
errs = append(errs, fmt.Sprintf("%s failed to register for events: %s", conn, err))
171172
}
172173
} else {
173174
errs = append(errs, fmt.Sprintf("%s failed to connect: %s", addr, err))
174175
}
175-
if conn != nil {
176-
conn.Close()
177-
}
178176
}
179177

180178
return nil, fmt.Errorf("couldn't open control connection to any known host:\n%s", strings.Join(errs, "\n"))

transport/conn.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,6 @@ func OpenShardConn(ctx context.Context, addr string, si ShardInfo, cfg ConnConfi
385385
conn, err := OpenLocalPortConn(ctx, addr, it(), cfg)
386386
if err != nil {
387387
cfg.Logger.Infof("%s dial error: %s (try %d/%d)", addr, err, i, maxTries)
388-
if conn != nil {
389-
conn.Close()
390-
}
391388
continue
392389
}
393390
return conn, nil
@@ -506,7 +503,8 @@ func WrapConn(ctx context.Context, conn net.Conn, cfg ConnConfig) (*Conn, error)
506503
go c.r.loop(ctx)
507504

508505
if err := c.init(ctx); err != nil {
509-
return c, err
506+
c.Close()
507+
return nil, err
510508
}
511509

512510
return c, nil

transport/conn_integration_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"fmt"
88
"math/rand"
9+
"net"
910
"os/signal"
1011
"strconv"
1112
"sync"
@@ -278,3 +279,65 @@ func testCompression(ctx context.Context, t *testing.T, c frame.Compression, toS
278279
}
279280
}
280281
}
282+
283+
func TestConnectedToNonCqlServer(t *testing.T) {
284+
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGABRT, syscall.SIGTERM)
285+
defer cancel()
286+
287+
t.Logf("%+v", testingConnConfig)
288+
testCases := []struct {
289+
name string
290+
response []byte
291+
}{
292+
{
293+
name: "non-cql response",
294+
response: []byte("0"),
295+
},
296+
{
297+
name: "non supported cql response",
298+
response: func() []byte {
299+
var buf frame.Buffer
300+
frame := frame.Header{
301+
Version: frame.CQLv4,
302+
OpCode: frame.OpReady,
303+
}
304+
305+
frame.WriteTo(&buf)
306+
return buf.Bytes()
307+
}(),
308+
},
309+
}
310+
311+
for i := 0; i < len(testCases); i++ {
312+
tc := testCases[i]
313+
t.Run(tc.name, func(t *testing.T) {
314+
server, err := net.Listen("tcp", "127.0.0.1:")
315+
if err != nil {
316+
t.Fatal(err)
317+
}
318+
defer server.Close()
319+
go func() {
320+
conn, err := server.Accept()
321+
if err != nil {
322+
t.Log(err)
323+
t.Fail()
324+
return
325+
}
326+
go func(conn net.Conn) {
327+
defer conn.Close()
328+
conn.Write(tc.response)
329+
}(conn)
330+
}()
331+
332+
addr := server.Addr().String()
333+
conn, err := OpenConn(ctx, addr, nil, testingConnConfig)
334+
if err == nil {
335+
t.Fatal("connecting to non-cql server should fail")
336+
}
337+
t.Log(err)
338+
if conn != nil {
339+
t.Fatal("connecting to non-cql server should return a nil-conn")
340+
}
341+
})
342+
}
343+
}

transport/pool.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,6 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error {
141141
conn, err := OpenConn(ctx, host, nil, r.cfg)
142142
span.stop()
143143
if err != nil {
144-
if conn != nil {
145-
conn.Close()
146-
}
147144
return err
148145
}
149146

@@ -245,9 +242,6 @@ func (r *PoolRefiller) fill(ctx context.Context) {
245242
if r.pool.connObs != nil {
246243
r.pool.connObs.OnConnect(ConnectEvent{ConnEvent: ConnEvent{Addr: r.addr, Shard: si.Shard}, span: span, Err: err})
247244
}
248-
if conn != nil {
249-
conn.Close()
250-
}
251245
continue
252246
}
253247
if r.pool.connObs != nil {

0 commit comments

Comments
 (0)