Skip to content

Commit c368a43

Browse files
authored
Support more options with context (#27)
* feat: support DeriveContext * feat: support WithBeforeContextFunc & WithAfterContextFunc * test: add more cases * docs: comment for WithDeriveContext * test: fix lint
1 parent 11a2718 commit c368a43

4 files changed

Lines changed: 242 additions & 7 deletions

File tree

inner_job.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ type innerJob struct {
2626
entryGetter entryGetter
2727
key string
2828
spec string
29+
deriveContext DeriveContext
30+
ctxBefore BeforeContextFunc
2931
before BeforeFunc
3032
run RunFunc
33+
ctxAfter AfterContextFunc
3134
after AfterFunc
3235
retryTimes int
3336
retryInterval RetryInterval
@@ -75,9 +78,20 @@ func (j *innerJob) Run() {
7578
ctx, cancel := context.WithDeadline(context.WithValue(parentCtx, keyContextTask, task), nextAt)
7679
defer cancel()
7780

78-
if j.before != nil && j.before(task) {
79-
task.Skipped = true
80-
atomic.AddInt64(&j.statistics.SkippedTask, 1)
81+
if j.deriveContext != nil {
82+
ctx = j.deriveContext(ctx, task)
83+
}
84+
85+
if j.ctxBefore != nil {
86+
if j.ctxBefore(ctx, task) {
87+
task.Skipped = true
88+
atomic.AddInt64(&j.statistics.SkippedTask, 1)
89+
}
90+
} else if j.before != nil {
91+
if j.before(task) {
92+
task.Skipped = true
93+
atomic.AddInt64(&j.statistics.SkippedTask, 1)
94+
}
8195
}
8296

8397
if !task.Skipped {
@@ -128,7 +142,9 @@ func (j *innerJob) Run() {
128142
}
129143
}
130144

131-
if j.after != nil {
145+
if j.ctxAfter != nil {
146+
j.ctxAfter(ctx, task)
147+
} else if j.after != nil {
132148
j.after(task)
133149
}
134150

inner_job_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,11 @@ func Test_innerJob_Run(t *testing.T) {
9090
entryGetter entryGetter
9191
key string
9292
spec string
93+
deriveContext DeriveContext
94+
ctxBefore BeforeContextFunc
9395
before BeforeFunc
9496
run RunFunc
97+
ctxAfter AfterContextFunc
9598
after AfterFunc
9699
retryTimes int
97100
retryInterval RetryInterval
@@ -329,6 +332,80 @@ func Test_innerJob_Run(t *testing.T) {
329332
RetriedRun: 0,
330333
},
331334
},
335+
{
336+
name: "ctx_before_skip",
337+
fields: fields{
338+
cron: NewCron(WithAtomic(atomic)),
339+
entryID: 1,
340+
entryGetter: mockEntryGetter,
341+
ctxBefore: func(ctx context.Context, task Task) (skip bool) {
342+
return true
343+
},
344+
run: func(ctx context.Context) error {
345+
return nil
346+
},
347+
ctxAfter: func(ctx context.Context, task Task) {
348+
if !task.Skipped {
349+
t.Fatal("expected task to be skipped")
350+
}
351+
},
352+
retryTimes: 1,
353+
},
354+
statistics: Statistics{
355+
TotalTask: 1,
356+
SkippedTask: 1,
357+
},
358+
},
359+
{
360+
name: "ctx_before_no_skip",
361+
fields: fields{
362+
cron: NewCron(WithAtomic(atomic)),
363+
entryID: 1,
364+
entryGetter: mockEntryGetter,
365+
ctxBefore: func(ctx context.Context, task Task) (skip bool) {
366+
return false
367+
},
368+
run: func(ctx context.Context) error {
369+
return nil
370+
},
371+
ctxAfter: func(ctx context.Context, task Task) {
372+
if task.Return != nil {
373+
t.Fatal(task.Return)
374+
}
375+
},
376+
retryTimes: 1,
377+
},
378+
statistics: Statistics{
379+
TotalTask: 1,
380+
PassedTask: 1,
381+
TotalRun: 1,
382+
PassedRun: 1,
383+
},
384+
},
385+
{
386+
name: "derive_context",
387+
fields: fields{
388+
cron: NewCron(WithAtomic(atomic)),
389+
entryID: 1,
390+
entryGetter: mockEntryGetter,
391+
deriveContext: func(ctx context.Context, task Task) context.Context {
392+
return context.WithValue(ctx, ctxKey("test_key"), "test_value")
393+
},
394+
run: func(ctx context.Context) error {
395+
if ctx.Value(ctxKey("test_key")) != "test_value" {
396+
return errors.New("context value not found")
397+
}
398+
return nil
399+
},
400+
retryTimes: 1,
401+
},
402+
statistics: Statistics{
403+
TotalTask: 1,
404+
PassedTask: 1,
405+
TotalRun: 1,
406+
PassedRun: 1,
407+
},
408+
},
332409
{
333410
name: "panic by runtime",
334411
fields: fields{
@@ -370,8 +447,11 @@ func Test_innerJob_Run(t *testing.T) {
370447
entryGetter: tt.fields.entryGetter,
371448
key: tt.fields.key,
372449
spec: tt.fields.spec,
450+
deriveContext: tt.fields.deriveContext,
451+
ctxBefore: tt.fields.ctxBefore,
373452
before: tt.fields.before,
374453
run: tt.fields.run,
454+
ctxAfter: tt.fields.ctxAfter,
375455
after: tt.fields.after,
376456
retryTimes: tt.fields.retryTimes,
377457
retryInterval: tt.fields.retryInterval,

job_option.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,58 @@ import (
99
type JobOption func(job *innerJob)
1010

1111
// BeforeFunc represents the function could be called before Run.
12+
// Deprecated: Use BeforeContextFunc instead, it will be ignored if BeforeContextFunc is set.
1213
type BeforeFunc func(task Task) (skip bool)
1314

15+
// BeforeContextFunc represents the function could be called before Run with the given Task.
16+
type BeforeContextFunc func(ctx context.Context, task Task) (skip bool)
17+
1418
// RunFunc represents the function could be called by a cron.
1519
type RunFunc func(ctx context.Context) error
1620

1721
// AfterFunc represents the function could be called after Run.
22+
// Deprecated: Use AfterContextFunc instead, it will be ignored if AfterContextFunc is set.
1823
type AfterFunc func(task Task)
1924

25+
// AfterContextFunc represents the function could be called after Run with the given Task.
26+
type AfterContextFunc func(ctx context.Context, task Task)
27+
2028
// RetryInterval indicates how long should delay before retrying when run failed `triedTimes` times.
2129
type RetryInterval func(triedTimes int) time.Duration
2230

31+
// DeriveContext indicates how to derive a new context from the job's base context and the current Task.
32+
type DeriveContext func(ctx context.Context, task Task) context.Context
33+
2334
// WithBeforeFunc specifies what to do before Run.
35+
// Deprecated: Use WithBeforeContextFunc instead.
2436
func WithBeforeFunc(before BeforeFunc) JobOption {
2537
return func(job *innerJob) {
2638
job.before = before
2739
}
2840
}
2941

42+
// WithBeforeContextFunc specifies what to do before Run with the given Task.
43+
func WithBeforeContextFunc(before BeforeContextFunc) JobOption {
44+
return func(job *innerJob) {
45+
job.ctxBefore = before
46+
}
47+
}
48+
3049
// WithAfterFunc specifies what to do after Run.
50+
// Deprecated: Use WithAfterContextFunc instead.
3151
func WithAfterFunc(after AfterFunc) JobOption {
3252
return func(job *innerJob) {
3353
job.after = after
3454
}
3555
}
3656

57+
// WithAfterContextFunc specifies what to do after Run with the given Task.
58+
func WithAfterContextFunc(after AfterContextFunc) JobOption {
59+
return func(job *innerJob) {
60+
job.ctxAfter = after
61+
}
62+
}
63+
3764
// WithRetryTimes specifies max times to retry,
3865
// retryTimes will be set as 1 if it is less than 1.
3966
func WithRetryTimes(retryTimes int) JobOption {
@@ -63,3 +90,13 @@ func WithGroup(group Group) JobOption {
6390
job.group = group
6491
}
6592
}
93+
94+
// WithDeriveContext specifies how to derive a new context for the entire job execution, including
95+
// before/after hooks, Run, and retry logic. The returned context must derive from the provided ctx
96+
// to preserve the deadline, cancellation signal, and the embedded Task value (accessible via TaskFromContext).
97+
// Returning a detached context (e.g., context.Background()) will break deadline enforcement and retry timeout logic.
98+
func WithDeriveContext(deriveContext DeriveContext) JobOption {
99+
return func(job *innerJob) {
100+
job.deriveContext = deriveContext
101+
}
102+
}

job_option_test.go

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package dcron
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
78
)
89

910
func TestWithAfterFunc(t *testing.T) {
10-
after := func(task Task) {
11-
12-
}
11+
after := func(task Task) {}
1312

1413
type args struct {
1514
after AfterFunc
@@ -164,3 +163,106 @@ func TestWithNoMutex(t *testing.T) {
164163
})
165164
}
166165
}
166+
167+
func TestWithBeforeContextFunc(t *testing.T) {
168+
before := func(ctx context.Context, task Task) (skip bool) {
169+
return false
170+
}
171+
172+
type args struct {
173+
before BeforeContextFunc
174+
}
175+
tests := []struct {
176+
name string
177+
args args
178+
check func(t *testing.T, option JobOption)
179+
}{
180+
{
181+
name: "regular",
182+
args: args{
183+
before: before,
184+
},
185+
check: func(t *testing.T, option JobOption) {
186+
j := &innerJob{}
187+
option(j)
188+
if fmt.Sprintf("%p", j.ctxBefore) != fmt.Sprintf("%p", before) {
189+
t.Fatal()
190+
}
191+
},
192+
},
193+
}
194+
for _, tt := range tests {
195+
t.Run(tt.name, func(t *testing.T) {
196+
got := WithBeforeContextFunc(tt.args.before)
197+
tt.check(t, got)
198+
})
199+
}
200+
}
201+
202+
func TestWithAfterContextFunc(t *testing.T) {
203+
after := func(ctx context.Context, task Task) {}
204+
205+
type args struct {
206+
after AfterContextFunc
207+
}
208+
tests := []struct {
209+
name string
210+
args args
211+
check func(t *testing.T, option JobOption)
212+
}{
213+
{
214+
name: "regular",
215+
args: args{
216+
after: after,
217+
},
218+
check: func(t *testing.T, option JobOption) {
219+
j := &innerJob{}
220+
option(j)
221+
if fmt.Sprintf("%p", j.ctxAfter) != fmt.Sprintf("%p", after) {
222+
t.Fatal()
223+
}
224+
},
225+
},
226+
}
227+
for _, tt := range tests {
228+
t.Run(tt.name, func(t *testing.T) {
229+
got := WithAfterContextFunc(tt.args.after)
230+
tt.check(t, got)
231+
})
232+
}
233+
}
234+
235+
func TestWithDeriveContext(t *testing.T) {
236+
deriveContext := func(ctx context.Context, task Task) context.Context {
237+
return ctx
238+
}
239+
240+
type args struct {
241+
deriveContext DeriveContext
242+
}
243+
tests := []struct {
244+
name string
245+
args args
246+
check func(t *testing.T, option JobOption)
247+
}{
248+
{
249+
name: "regular",
250+
args: args{
251+
deriveContext: deriveContext,
252+
},
253+
check: func(t *testing.T, option JobOption) {
254+
j := &innerJob{}
255+
option(j)
256+
if fmt.Sprintf("%p", j.deriveContext) != fmt.Sprintf("%p", deriveContext) {
257+
t.Fatal()
258+
}
259+
},
260+
},
261+
}
262+
for _, tt := range tests {
263+
t.Run(tt.name, func(t *testing.T) {
264+
got := WithDeriveContext(tt.args.deriveContext)
265+
tt.check(t, got)
266+
})
267+
}
268+
}

0 commit comments

Comments
 (0)