Skip to content

Commit afb6c63

Browse files
committed
v3: Added context to execute job func for improve graceful shutdown
1 parent e843a09 commit afb6c63

File tree

5 files changed

+100
-87
lines changed

5 files changed

+100
-87
lines changed

chain.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cron
22

33
import (
4+
"context"
45
"fmt"
56
"runtime"
67
"sync"
@@ -24,9 +25,12 @@ func NewChain(c ...JobWrapper) Chain {
2425
// Then decorates the given job with all JobWrappers in the chain.
2526
//
2627
// This:
27-
// NewChain(m1, m2, m3).Then(job)
28+
//
29+
// NewChain(m1, m2, m3).Then(job)
30+
//
2831
// is equivalent to:
29-
// m1(m2(m3(job)))
32+
//
33+
// m1(m2(m3(job)))
3034
func (c Chain) Then(j Job) Job {
3135
for i := range c.wrappers {
3236
j = c.wrappers[len(c.wrappers)-i-1](j)
@@ -37,7 +41,7 @@ func (c Chain) Then(j Job) Job {
3741
// Recover panics in wrapped jobs and log them with the provided logger.
3842
func Recover(logger Logger) JobWrapper {
3943
return func(j Job) Job {
40-
return FuncJob(func() {
44+
return FuncJob(func(ctx context.Context) {
4145
defer func() {
4246
if r := recover(); r != nil {
4347
const size = 64 << 10
@@ -50,7 +54,7 @@ func Recover(logger Logger) JobWrapper {
5054
logger.Error(err, "panic", "stack", "...\n"+string(buf))
5155
}
5256
}()
53-
j.Run()
57+
j.Run(ctx)
5458
})
5559
}
5660
}
@@ -61,14 +65,14 @@ func Recover(logger Logger) JobWrapper {
6165
func DelayIfStillRunning(logger Logger) JobWrapper {
6266
return func(j Job) Job {
6367
var mu sync.Mutex
64-
return FuncJob(func() {
68+
return FuncJob(func(ctx context.Context) {
6569
start := time.Now()
6670
mu.Lock()
6771
defer mu.Unlock()
6872
if dur := time.Since(start); dur > time.Minute {
6973
logger.Info("delay", "duration", dur)
7074
}
71-
j.Run()
75+
j.Run(ctx)
7276
})
7377
}
7478
}
@@ -79,10 +83,10 @@ func SkipIfStillRunning(logger Logger) JobWrapper {
7983
var ch = make(chan struct{}, 1)
8084
ch <- struct{}{}
8185
return func(j Job) Job {
82-
return FuncJob(func() {
86+
return FuncJob(func(ctx context.Context) {
8387
select {
8488
case v := <-ch:
85-
j.Run()
89+
j.Run(ctx)
8690
ch <- v
8791
default:
8892
logger.Info("skip")

chain_test.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cron
22

33
import (
4+
"context"
45
"io/ioutil"
56
"log"
67
"reflect"
@@ -11,7 +12,7 @@ import (
1112

1213
func appendingJob(slice *[]int, value int) Job {
1314
var m sync.Mutex
14-
return FuncJob(func() {
15+
return FuncJob(func(ctx context.Context) {
1516
m.Lock()
1617
*slice = append(*slice, value)
1718
m.Unlock()
@@ -20,9 +21,9 @@ func appendingJob(slice *[]int, value int) Job {
2021

2122
func appendingWrapper(slice *[]int, value int) JobWrapper {
2223
return func(j Job) Job {
23-
return FuncJob(func() {
24-
appendingJob(slice, value).Run()
25-
j.Run()
24+
return FuncJob(func(ctx context.Context) {
25+
appendingJob(slice, value).Run(ctx)
26+
j.Run(ctx)
2627
})
2728
}
2829
}
@@ -35,14 +36,14 @@ func TestChain(t *testing.T) {
3536
append3 = appendingWrapper(&nums, 3)
3637
append4 = appendingJob(&nums, 4)
3738
)
38-
NewChain(append1, append2, append3).Then(append4).Run()
39+
NewChain(append1, append2, append3).Then(append4).Run(context.Background())
3940
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
4041
t.Error("unexpected order of calls:", nums)
4142
}
4243
}
4344

4445
func TestChainRecover(t *testing.T) {
45-
panickingJob := FuncJob(func() {
46+
panickingJob := FuncJob(func(ctx context.Context) {
4647
panic("panickingJob panics")
4748
})
4849

@@ -53,19 +54,19 @@ func TestChainRecover(t *testing.T) {
5354
}
5455
}()
5556
NewChain().Then(panickingJob).
56-
Run()
57+
Run(context.Background())
5758
})
5859

5960
t.Run("Recovering JobWrapper recovers", func(t *testing.T) {
6061
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
6162
Then(panickingJob).
62-
Run()
63+
Run(context.Background())
6364
})
6465

6566
t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) {
6667
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
6768
Then(panickingJob).
68-
Run()
69+
Run(context.Background())
6970
})
7071
}
7172

@@ -76,7 +77,7 @@ type countJob struct {
7677
delay time.Duration
7778
}
7879

79-
func (j *countJob) Run() {
80+
func (j *countJob) Run(context.Context) {
8081
j.m.Lock()
8182
j.started++
8283
j.m.Unlock()
@@ -103,7 +104,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
103104
t.Run("runs immediately", func(t *testing.T) {
104105
var j countJob
105106
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
106-
go wrappedJob.Run()
107+
go wrappedJob.Run(context.Background())
107108
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
108109
if c := j.Done(); c != 1 {
109110
t.Errorf("expected job run once, immediately, got %d", c)
@@ -114,9 +115,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
114115
var j countJob
115116
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
116117
go func() {
117-
go wrappedJob.Run()
118+
go wrappedJob.Run(context.Background())
118119
time.Sleep(time.Millisecond)
119-
go wrappedJob.Run()
120+
go wrappedJob.Run(context.Background())
120121
}()
121122
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
122123
if c := j.Done(); c != 2 {
@@ -129,9 +130,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
129130
j.delay = 10 * time.Millisecond
130131
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
131132
go func() {
132-
go wrappedJob.Run()
133+
go wrappedJob.Run(context.Background())
133134
time.Sleep(time.Millisecond)
134-
go wrappedJob.Run()
135+
go wrappedJob.Run(context.Background())
135136
}()
136137

137138
// After 5ms, the first job is still in progress, and the second job was
@@ -157,7 +158,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
157158
t.Run("runs immediately", func(t *testing.T) {
158159
var j countJob
159160
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
160-
go wrappedJob.Run()
161+
go wrappedJob.Run(context.Background())
161162
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
162163
if c := j.Done(); c != 1 {
163164
t.Errorf("expected job run once, immediately, got %d", c)
@@ -168,9 +169,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
168169
var j countJob
169170
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
170171
go func() {
171-
go wrappedJob.Run()
172+
go wrappedJob.Run(context.Background())
172173
time.Sleep(time.Millisecond)
173-
go wrappedJob.Run()
174+
go wrappedJob.Run(context.Background())
174175
}()
175176
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
176177
if c := j.Done(); c != 2 {
@@ -183,9 +184,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
183184
j.delay = 10 * time.Millisecond
184185
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
185186
go func() {
186-
go wrappedJob.Run()
187+
go wrappedJob.Run(context.Background())
187188
time.Sleep(time.Millisecond)
188-
go wrappedJob.Run()
189+
go wrappedJob.Run(context.Background())
189190
}()
190191

191192
// After 5ms, the first job is still in progress, and the second job was
@@ -209,7 +210,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
209210
j.delay = 10 * time.Millisecond
210211
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
211212
for i := 0; i < 11; i++ {
212-
go wrappedJob.Run()
213+
go wrappedJob.Run(context.Background())
213214
}
214215
time.Sleep(200 * time.Millisecond)
215216
done := j.Done()

cron.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ type Cron struct {
2424
parser Parser
2525
nextID EntryID
2626
jobWaiter sync.WaitGroup
27+
ctx context.Context
28+
cancel context.CancelFunc
2729
}
2830

2931
// Job is an interface for submitted cron jobs.
3032
type Job interface {
31-
Run()
33+
Run(ctx context.Context)
3234
}
3335

3436
// Schedule describes a job's duty cycle.
@@ -92,20 +94,21 @@ func (s byTime) Less(i, j int) bool {
9294
//
9395
// Available Settings
9496
//
95-
// Time Zone
96-
// Description: The time zone in which schedules are interpreted
97-
// Default: time.Local
97+
// Time Zone
98+
// Description: The time zone in which schedules are interpreted
99+
// Default: time.Local
98100
//
99-
// Parser
100-
// Description: Parser converts cron spec strings into cron.Schedules.
101-
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
101+
// Parser
102+
// Description: Parser converts cron spec strings into cron.Schedules.
103+
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
102104
//
103-
// Chain
104-
// Description: Wrap submitted jobs to customize behavior.
105-
// Default: A chain that recovers panics and logs them to stderr.
105+
// Chain
106+
// Description: Wrap submitted jobs to customize behavior.
107+
// Default: A chain that recovers panics and logs them to stderr.
106108
//
107109
// See "cron.With*" to modify the default behavior.
108110
func New(opts ...Option) *Cron {
111+
ctx, cancel := context.WithCancel(context.Background())
109112
c := &Cron{
110113
entries: nil,
111114
chain: NewChain(),
@@ -118,6 +121,8 @@ func New(opts ...Option) *Cron {
118121
logger: DefaultLogger,
119122
location: time.Local,
120123
parser: standardParser,
124+
ctx: ctx,
125+
cancel: cancel,
121126
}
122127
for _, opt := range opts {
123128
opt(c)
@@ -126,14 +131,14 @@ func New(opts ...Option) *Cron {
126131
}
127132

128133
// FuncJob is a wrapper that turns a func() into a cron.Job
129-
type FuncJob func()
134+
type FuncJob func(ctx context.Context)
130135

131-
func (f FuncJob) Run() { f() }
136+
func (f FuncJob) Run(ctx context.Context) { f(ctx) }
132137

133138
// AddFunc adds a func to the Cron to be run on the given schedule.
134139
// The spec is parsed using the time zone of this Cron instance as the default.
135140
// An opaque ID is returned that can be used to later remove it.
136-
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
141+
func (c *Cron) AddFunc(spec string, cmd func(ctx context.Context)) (EntryID, error) {
137142
return c.AddJob(spec, FuncJob(cmd))
138143
}
139144

@@ -304,7 +309,7 @@ func (c *Cron) startJob(j Job) {
304309
c.jobWaiter.Add(1)
305310
go func() {
306311
defer c.jobWaiter.Done()
307-
j.Run()
312+
j.Run(c.ctx)
308313
}()
309314
}
310315

@@ -319,6 +324,7 @@ func (c *Cron) Stop() context.Context {
319324
c.runningMu.Lock()
320325
defer c.runningMu.Unlock()
321326
if c.running {
327+
c.cancel()
322328
c.stop <- struct{}{}
323329
c.running = false
324330
}

0 commit comments

Comments
 (0)