diff --git a/client.go b/client.go index be9230ab..dee6dd2d 100644 --- a/client.go +++ b/client.go @@ -15,6 +15,7 @@ type singleClient struct { stop uint32 cmd Builder retry bool + hasLftm bool DisableCache bool } @@ -32,11 +33,11 @@ func newSingleClient(opt *ClientOption, prev conn, connFn connFn, retryer retryH if err := conn.Dial(); err != nil { return nil, err } - return newSingleClientWithConn(conn, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer), nil + return newSingleClientWithConn(conn, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer, opt.ConnLifetime > 0), nil } -func newSingleClientWithConn(conn conn, builder Builder, retry, disableCache bool, retryer retryHandler) *singleClient { - return &singleClient{cmd: builder, conn: conn, retry: retry, retryHandler: retryer, DisableCache: disableCache} +func newSingleClientWithConn(conn conn, builder Builder, retry, disableCache bool, retryer retryHandler, hasLftm bool) *singleClient { + return &singleClient{cmd: builder, conn: conn, retry: retry, retryHandler: retryer, hasLftm: hasLftm, DisableCache: disableCache} } func (c *singleClient) B() Builder { @@ -47,6 +48,9 @@ func (c *singleClient) Do(ctx context.Context, cmd Completed) (resp RedisResult) attempts := 1 retry: resp = c.conn.Do(ctx, cmd) + if resp.Error() == errConnExpired { + goto retry + } if c.retry && cmd.IsReadOnly() && c.isRetryable(resp.Error(), ctx) { shouldRetry := c.retryHandler.WaitOrSkipRetry( ctx, attempts, cmd, resp.Error(), @@ -86,6 +90,22 @@ func (c *singleClient) DoMulti(ctx context.Context, multi ...Completed) (resps [ attempts := 1 retry: resps = c.conn.DoMulti(ctx, multi...).s + if c.hasLftm { + var ml []Completed + recover: + ml = ml[:0] + for i, resp := range resps { + if resp.Error() == errConnExpired { + ml = multi[i:] + break + } + } + if len(ml) > 0 { + rs := c.conn.DoMulti(ctx, ml...).s + resps = append(resps[:len(resps)-len(rs)], rs...) + goto recover + } + } if c.retry && allReadOnly(multi) { for i, resp := range resps { if c.isRetryable(resp.Error(), ctx) { @@ -114,6 +134,22 @@ func (c *singleClient) DoMultiCache(ctx context.Context, multi ...CacheableTTL) attempts := 1 retry: resps = c.conn.DoMultiCache(ctx, multi...).s + if c.hasLftm { + var ml []CacheableTTL + recover: + ml = ml[:0] + for i, resp := range resps { + if resp.Error() == errConnExpired { + ml = multi[i:] + break + } + } + if len(ml) > 0 { + rs := c.conn.DoMultiCache(ctx, ml...).s + resps = append(resps[:len(resps)-len(rs)], rs...) + goto recover + } + } if c.retry { for i, resp := range resps { if c.isRetryable(resp.Error(), ctx) { @@ -139,6 +175,9 @@ func (c *singleClient) DoCache(ctx context.Context, cmd Cacheable, ttl time.Dura attempts := 1 retry: resp = c.conn.DoCache(ctx, cmd, ttl) + if resp.Error() == errConnExpired { + goto retry + } if c.retry && c.isRetryable(resp.Error(), ctx) { shouldRetry := c.retryHandler.WaitOrSkipRetry(ctx, attempts, Completed(cmd), resp.Error()) if shouldRetry { @@ -156,6 +195,9 @@ func (c *singleClient) Receive(ctx context.Context, subscribe Completed, fn func attempts := 1 retry: err = c.conn.Receive(ctx, subscribe, fn) + if err == errConnExpired { + goto retry + } if c.retry { if _, ok := err.(*RedisError); !ok && c.isRetryable(err, ctx) { shouldRetry := c.retryHandler.WaitOrSkipRetry(ctx, attempts, subscribe, err) diff --git a/client_test.go b/client_test.go index 37d83a17..1b29e70b 100644 --- a/client_test.go +++ b/client_test.go @@ -1424,6 +1424,130 @@ func TestSingleClientLoadingRetry(t *testing.T) { }) } +func TestSingleClientConnLifetime(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + + setup := func() (*singleClient, *mockConn) { + m := &mockConn{} + client, err := newSingleClient( + &ClientOption{InitAddress: []string{""}, ConnLifetime: 5 * time.Second}, + m, + func(dst string, opt *ClientOption) conn { return m }, + newRetryer(defaultRetryDelayFn), + ) + if err != nil { + t.Fatalf("unexpected err %v", err) + } + return client, m + } + + t.Run("Do", func(t *testing.T) { + client, m := setup() + m.DoFn = func(cmd Completed) RedisResult { + return newResult(strmsg('+', "OK"), nil) + } + if v, err := client.Do(context.Background(), client.B().Get().Key("Do").Build()).ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti", func(t *testing.T) { + client, m := setup() + m.DoMultiFn = func(multi ...Completed) *redisresults { + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + if v, err := client.DoMulti(context.Background(), client.B().Get().Key("Do").Build())[0].ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti ConnLifetime - at the head of processing", func(t *testing.T) { + client, m := setup() + attempts := 0 + m.DoMultiFn = func(multi ...Completed) *redisresults { + attempts++ + if attempts == 1 { + return &redisresults{s: []RedisResult{newErrResult(errConnExpired)}} + } + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + if v, err := client.DoMulti(context.Background(), client.B().Get().Key("Do").Build())[0].ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMulti ConnLifetime in the middle of processing", func(t *testing.T) { + client, m := setup() + attempts := 0 + m.DoMultiFn = func(multi ...Completed) *redisresults { + attempts++ + if attempts == 1 { + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil), newErrResult(errConnExpired)}} + } + // recover the failure of the first call + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + resps := client.DoMulti(context.Background(), client.B().Get().Key("Do").Build(), client.B().Get().Key("Do").Build()) + if len(resps) != 2 { + t.Errorf("unexpected response length %v", len(resps)) + } + for _, resp := range resps { + if v, err := resp.ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) + + t.Run("DoMultiCache", func(t *testing.T) { + client, m := setup() + m.DoMultiCacheFn = func(multi ...CacheableTTL) *redisresults { + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + cmd := client.B().Get().Key("Do").Cache() + if v, err := client.DoMultiCache(context.Background(), CT(cmd, 0))[0].ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache ConnLifetime - at the head of processing", func(t *testing.T) { + client, m := setup() + attempts := 0 + m.DoMultiCacheFn = func(multi ...CacheableTTL) *redisresults { + attempts++ + if attempts == 1 { + return &redisresults{s: []RedisResult{newErrResult(errConnExpired)}} + } + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + cmd := client.B().Get().Key("Do").Cache() + if v, err := client.DoMultiCache(context.Background(), CT(cmd, 0))[0].ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + }) + + t.Run("DoMultiCache ConnLifetime in the middle of processing", func(t *testing.T) { + client, m := setup() + attempts := 0 + m.DoMultiCacheFn = func(multi ...CacheableTTL) *redisresults { + attempts++ + if attempts == 1 { + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil), newErrResult(errConnExpired)}} + } + // recover the failure of the first call + return &redisresults{s: []RedisResult{newResult(strmsg('+', "OK"), nil)}} + } + resps := client.DoMultiCache(context.Background(), CT(client.B().Get().Key("Do").Cache(), 0), CT(client.B().Get().Key("Do").Cache(), 0)) + if len(resps) != 2 { + t.Errorf("unexpected response length %v", len(resps)) + } + for _, resp := range resps { + if v, err := resp.ToString(); err != nil || v != "OK" { + t.Fatalf("unexpected response %v %v", v, err) + } + } + }) +} + func BenchmarkSingleClient_DoCache(b *testing.B) { ctx := context.Background() client, err := NewClient(ClientOption{InitAddress: []string{"127.0.0.1:6379"}, Dialer: net.Dialer{KeepAlive: -1}}) diff --git a/cluster.go b/cluster.go index 918b075b..cf9466e9 100644 --- a/cluster.go +++ b/cluster.go @@ -1199,7 +1199,7 @@ func (c *clusterClient) Nodes() map[string]Client { disableCache := c.opt != nil && c.opt.DisableCache for addr, cc := range c.conns { if !cc.hidden { - _nodes[addr] = newSingleClientWithConn(cc.conn, c.cmd, c.retry, disableCache, c.retryHandler) + _nodes[addr] = newSingleClientWithConn(cc.conn, c.cmd, c.retry, disableCache, c.retryHandler, false) } } c.mu.RUnlock() diff --git a/mux_test.go b/mux_test.go index a8a8e970..f119a56e 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1131,6 +1131,8 @@ type mockWire struct { VersionFn func() int ErrorFn func() error CloseFn func() + StopTimerFn func() bool + ResetTimerFn func() bool CleanSubscriptionsFn func() SetPubSubHooksFn func(hooks PubSubHooks) <-chan error @@ -1205,6 +1207,20 @@ func (m *mockWire) SetOnCloseHook(fn func(error)) { } } +func (m *mockWire) StopTimer() bool { + if m.StopTimerFn != nil { + return m.StopTimerFn() + } + return true +} + +func (m *mockWire) ResetTimer() bool { + if m.ResetTimerFn != nil { + return m.ResetTimerFn() + } + return true +} + func (m *mockWire) Info() map[string]RedisMessage { if m.InfoFn != nil { return m.InfoFn() diff --git a/pipe.go b/pipe.go index eb776ca8..8f975c9d 100644 --- a/pipe.go +++ b/pipe.go @@ -55,6 +55,8 @@ type wire interface { CleanSubscriptions() SetPubSubHooks(hooks PubSubHooks) <-chan error SetOnCloseHook(fn func(error)) + StopTimer() bool + ResetTimer() bool } var _ wire = (*pipe)(nil) @@ -77,11 +79,13 @@ type pipe struct { psubs *subs // pubsub pmessage subscriptions pingTimer *time.Timer // timer for background ping info map[string]RedisMessage + lftmTimer *time.Timer // lifetime timer timeout time.Duration pinggap time.Duration maxFlushDelay time.Duration - r2mu sync.Mutex wrCounter atomic.Uint64 + lftm time.Duration // lifetime + r2mu sync.Mutex version int32 blcksig int32 state int32 @@ -328,6 +332,10 @@ func _newPipe(ctx context.Context, connFn func(context.Context) (net.Conn, error p.backgroundPing() } } + if option.ConnLifetime > 0 { + p.lftm = option.ConnLifetime + p.lftmTimer = time.AfterFunc(option.ConnLifetime, p.expired) + } return p, nil } @@ -344,6 +352,7 @@ func (p *pipe) _exit(err error) { p.error.CompareAndSwap(nil, &errs{error: err}) atomic.CompareAndSwapInt32(&p.state, 1, 2) // stop accepting new requests _ = p.conn.Close() // force both read & write goroutine to exit + p.StopTimer() p.clhks.Load().(func(error))(err) } @@ -495,6 +504,9 @@ func (p *pipe) _backgroundRead() (err error) { defer func() { resp := newErrResult(err) + if e := p.Error(); e == errConnExpired { + resp = newErrResult(e) + } if err != nil && ff < len(multi) { for ; ff < len(resps); ff++ { resps[ff] = resp @@ -1633,6 +1645,25 @@ func (p *pipe) Close() { p.r2mu.Unlock() } +func (p *pipe) StopTimer() bool { + if p.lftmTimer == nil { + return true + } + return p.lftmTimer.Stop() +} + +func (p *pipe) ResetTimer() bool { + if p.lftmTimer == nil || p.Error() != nil { + return true + } + return p.lftmTimer.Reset(p.lftm) +} + +func (p *pipe) expired() { + p.error.CompareAndSwap(nil, errExpired) + p.Close() +} + type pshks struct { hooks PubSubHooks close chan error @@ -1672,6 +1703,9 @@ const ( ) var cacheMark = &(RedisMessage{}) -var errClosing = &errs{error: ErrClosing} +var ( + errClosing = &errs{error: ErrClosing} + errExpired = &errs{error: errConnExpired} +) type errs struct{ error } diff --git a/pipe_test.go b/pipe_test.go index 7289aac3..c79dfa40 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -2773,6 +2773,66 @@ func TestOnInvalidations(t *testing.T) { } } +func TestConnLifetime(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + + t.Run("Enabled ConnLifetime", func(t *testing.T) { + p, _, _, closeConn := setup(t, ClientOption{ + ConnLifetime: 50 * time.Millisecond, + }) + defer closeConn() + + if p.Error() != nil { + t.Fatalf("unexpected error %v", p.Error()) + } + time.Sleep(60 * time.Millisecond) + if p.Error() != errConnExpired { + t.Fatalf("unexpected error, expected: %v, got: %v", errConnExpired, p.Error()) + } + }) + + t.Run("Disabled ConnLifetime", func(t *testing.T) { + p, _, _, closeConn := setup(t, ClientOption{}) + defer closeConn() + + time.Sleep(60 * time.Millisecond) + if p.Error() != nil { + t.Fatalf("unexpected error %v", p.Error()) + } + }) + + t.Run("StopTimer", func(t *testing.T) { + p, _, _, closeConn := setup(t, ClientOption{ + ConnLifetime: 50 * time.Millisecond, + }) + defer closeConn() + + p.StopTimer() + time.Sleep(60 * time.Millisecond) + if p.Error() != nil { + t.Fatalf("unexpected error %v", p.Error()) + } + }) + + t.Run("ResetTimer", func(t *testing.T) { + p, _, _, closeConn := setup(t, ClientOption{ + ConnLifetime: 50 * time.Millisecond, + }) + defer closeConn() + + time.Sleep(20 * time.Millisecond) + p.ResetTimer() + time.Sleep(40 * time.Millisecond) + if p.Error() != nil { + t.Fatalf("unexpected error %v", p.Error()) + } + time.Sleep(20 * time.Millisecond) + if p.Error() != errConnExpired { + t.Fatalf("unexpected error, expected: %v, got: %v", errConnExpired, p.Error()) + } + }) +} + func TestMultiHalfErr(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) p, mock, _, closeConn := setup(t, ClientOption{}) diff --git a/pool.go b/pool.go index cbc49f57..3d954924 100644 --- a/pool.go +++ b/pool.go @@ -81,6 +81,7 @@ retry: // allowing others to make wires concurrently instead of waiting in line p.cond.L.Unlock() v = p.make(ctx) + v.StopTimer() return v } @@ -88,7 +89,7 @@ retry: v = p.list[i] p.list[i] = nil p.list = p.list[:i] - if v.Error() != nil { + if !v.StopTimer() || v.Error() != nil { p.size-- v.Close() goto retry @@ -102,6 +103,7 @@ func (p *pool) Store(v wire) { if !p.down && v.Error() == nil { p.list = append(p.list, v) p.startTimerIfNeeded() + v.ResetTimer() } else { p.size-- v.Close() diff --git a/pool_test.go b/pool_test.go index 14bb50f1..2c66ddf4 100644 --- a/pool_test.go +++ b/pool_test.go @@ -332,6 +332,83 @@ func TestPoolWithIdleTTL(t *testing.T) { }) } +func TestPoolWithConnLifetime(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + setup := func(wires []wire) *pool { + var count int32 + return newPool(len(wires), dead, 0, 0, func(ctx context.Context) wire { + idx := atomic.AddInt32(&count, 1) - 1 + return wires[idx] + }) + } + + t.Run("Reuse without expired connections", func(t *testing.T) { + stopTimerCall := 0 + wires := []wire{ + &mockWire{}, + &mockWire{ + StopTimerFn: func() bool { + stopTimerCall++ + return false + }, // connection lifetime timer is already fired + }, + } + conn := make([]wire, 0, len(wires)) + pool := setup(wires) + for i := 0; i < len(wires); i++ { + conn = append(conn, pool.Acquire(context.Background())) + } + for i := 0; i < len(conn); i++ { + pool.Store(conn[i]) + } + + if stopTimerCall != 1 { + t.Errorf("StopTimer must be called when making wire") + } + + pool.cond.L.Lock() + if pool.size != 2 { + t.Errorf("size must be equal to 2, actual: %d", pool.size) + } + if len(pool.list) != 2 { + t.Errorf("list len must equal to 2, actual: %d", len(pool.list)) + } + pool.cond.L.Unlock() + + // stop timer failed, so drop the expired connection + pool.Store(pool.Acquire(context.Background())) + + if stopTimerCall != 2 { + t.Errorf("StopTimer must be called when acquiring from pool") + } + + pool.cond.L.Lock() + if pool.size != 1 { + t.Errorf("size must be equal to 1, actual: %d", pool.size) + } + if len(pool.list) != 1 { + t.Errorf("list len must equal to 1, actual: %d", len(pool.list)) + } + pool.cond.L.Unlock() + }) + + t.Run("Reset timer when storing to pool", func(t *testing.T) { + call := false + w := &mockWire{ + ResetTimerFn: func() bool { + call = true + return true + }, + } + pool := setup([]wire{w}) + pool.Store(pool.Acquire(context.Background())) + + if !call { + t.Error("ResetTimer must be called when storing") + } + }) +} + func TestPoolWithAcquireCtx(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) setup := func(size int, delay time.Duration) *pool { diff --git a/rueidis.go b/rueidis.go index 60d888e4..a87faed0 100644 --- a/rueidis.go +++ b/rueidis.go @@ -192,6 +192,10 @@ type ClientOption struct { // This default is ClientOption.Dialer.KeepAlive * (9+1), where 9 is the default of tcp_keepalive_probes on Linux. ConnWriteTimeout time.Duration + // ConnLiftime is lifetime for each connection. If specified, + // connections will close after passing lifetime. Note that the connection which dedicated client and blocking use is not closed. + ConnLifetime time.Duration + // MaxFlushDelay when greater than zero pauses pipeline write loop for some time (not larger than MaxFlushDelay) // after each flushing of data to the connection. This gives pipeline a chance to collect more commands to send // to Redis. Adding this delay increases latency, reduces throughput – but in most cases may significantly reduce @@ -505,3 +509,8 @@ func dial(ctx context.Context, dst string, opt *ClientOption) (conn net.Conn, er } const redisErrMsgCommandNotAllow = "command is not allowed" + +var ( + // errConnExpired means wrong connection that ClientOption.ConnLifetime had passed since connecting + errConnExpired = errors.New("connection is expired") +) diff --git a/sentinel.go b/sentinel.go index e6bb6234..8b43a915 100644 --- a/sentinel.go +++ b/sentinel.go @@ -215,7 +215,7 @@ func (c *sentinelClient) Dedicate() (DedicatedClient, func()) { func (c *sentinelClient) Nodes() map[string]Client { conn := c.mConn.Load().(conn) disableCache := c.mOpt != nil && c.mOpt.DisableCache - return map[string]Client{conn.Addr(): newSingleClientWithConn(conn, c.cmd, c.retry, disableCache, c.retryHandler)} + return map[string]Client{conn.Addr(): newSingleClientWithConn(conn, c.cmd, c.retry, disableCache, c.retryHandler, false)} } func (c *sentinelClient) Mode() ClientMode { diff --git a/standalone.go b/standalone.go index 83a231a8..60c2cd29 100644 --- a/standalone.go +++ b/standalone.go @@ -18,7 +18,7 @@ func newStandaloneClient(opt *ClientOption, connFn connFn, retryer retryHandler) } s := &standalone{ toReplicas: opt.SendToReplicas, - primary: newSingleClientWithConn(p, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer), + primary: newSingleClientWithConn(p, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer, false), replicas: make([]*singleClient, len(opt.Standalone.ReplicaAddress)), } opt.ReplicaOnly = true @@ -31,7 +31,7 @@ func newStandaloneClient(opt *ClientOption, connFn connFn, retryer retryHandler) } return nil, err } - s.replicas[i] = newSingleClientWithConn(replicaConn, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer) + s.replicas[i] = newSingleClientWithConn(replicaConn, cmds.NewBuilder(cmds.NoSlot), !opt.DisableRetry, opt.DisableCache, retryer, false) } return s, nil }