Skip to content

Commit 685c6fd

Browse files
authored
close Pool when failed to Connect it (#64)
* return initial connection errors and wait for error handlers to complete * close pool if Connect errors * fix how we collect and return errors
1 parent dd5cee4 commit 685c6fd

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

pool.go

+32-5
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,38 @@ func (p *Pool) handleError(err error) {
5353
return
5454
}
5555

56-
go p.Opts.ErrorHandler(err)
56+
p.wg.Add(1)
57+
go func() {
58+
defer p.wg.Done()
59+
p.Opts.ErrorHandler(err)
60+
}()
5761
}
5862

5963
// Connect creates poll of connections by calling Factory method and connect them all
6064
func (p *Pool) Connect() error {
65+
// We need to close pool (with all potentially running goroutines) if
66+
// connection creation fails. Example of such situation is when we
67+
// successfully created 2 connections, but 3rd failed and minimum
68+
// connections is 3.
69+
// Because `Close` uses same mutex as `Connect` we need to unlock it
70+
// before calling `Close`. That's why we use `connectErr` variable and
71+
// `defer` statement here, before the next `defer` which unlocks mutex.
72+
var connectErr error
73+
defer func() {
74+
if connectErr != nil {
75+
p.Close()
76+
}
77+
}()
78+
6179
p.mu.Lock()
6280
defer p.mu.Unlock()
6381
if p.isClosed {
6482
return errors.New("pool is closed")
6583
}
6684

85+
// errors from initial connections creation
86+
var errs []error
87+
6788
// build connections
6889
for _, addr := range p.Addrs {
6990
conn, err := p.Factory(addr)
@@ -76,7 +97,8 @@ func (p *Pool) Connect() error {
7697

7798
err = conn.Connect()
7899
if err != nil {
79-
p.handleError(fmt.Errorf("connecting to %s: %w", conn.addr, err))
100+
errs = append(errs, fmt.Errorf("connecting to %s: %w", addr, err))
101+
p.handleError(fmt.Errorf("failed to connect to %s: %w", conn.addr, err))
80102
p.wg.Add(1)
81103
go p.recreateConnection(conn)
82104
continue
@@ -85,11 +107,16 @@ func (p *Pool) Connect() error {
85107
p.connections = append(p.connections, conn)
86108
}
87109

88-
if len(p.connections) < p.Opts.MinConnections {
89-
return fmt.Errorf("minimum %d connections is required, established: %d", p.Opts.MinConnections, len(p.connections))
110+
if len(p.connections) >= p.Opts.MinConnections {
111+
return nil
90112
}
91113

92-
return nil
114+
if len(errs) == 0 {
115+
connectErr = fmt.Errorf("minimum %d connections is required, established: %d", p.Opts.MinConnections, len(p.connections))
116+
} else {
117+
connectErr = fmt.Errorf("minimum %d connections is required, established: %d, errors: %w", p.Opts.MinConnections, len(p.connections), errors.Join(errs...))
118+
}
119+
return connectErr
93120
}
94121

95122
// Connections returns copy of all connections from the pool

pool_options.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ type PoolOptions struct {
2020
MaxReconnectWait time.Duration
2121

2222
// ErrorHandler is called in a goroutine with the errors that can't be
23-
// returned to the caller
23+
// returned to the caller. Don't block in this function as it will
24+
// block the connection pool Close() method.
2425
ErrorHandler func(err error)
2526

2627
// MinConnections is the number of connections required to be established when

pool_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ func TestPool(t *testing.T) {
230230
require.NoError(t, err)
231231

232232
err = pool.Connect()
233-
defer pool.Close()
234-
235-
require.EqualError(t, err, "minimum 3 connections is required, established: 2")
233+
require.Error(t, err)
234+
require.Contains(t, err.Error(), "minimum 3 connections is required, established: 2")
235+
require.Zero(t, len(pool.Connections()), "all connections should be closed")
236236
})
237237

238238
t.Run("Connect() returns no error when established >= MinConnections and later re-establish failed connections", func(t *testing.T) {

0 commit comments

Comments
 (0)