Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions inner_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ type innerJob struct {
entryGetter entryGetter
key string
spec string
deriveContext DeriveContext
ctxBefore BeforeContextFunc
before BeforeFunc
run RunFunc
ctxAfter AfterContextFunc
after AfterFunc
retryTimes int
retryInterval RetryInterval
Expand Down Expand Up @@ -75,9 +78,20 @@ func (j *innerJob) Run() {
ctx, cancel := context.WithDeadline(context.WithValue(parentCtx, keyContextTask, task), nextAt)
defer cancel()

if j.before != nil && j.before(task) {
task.Skipped = true
atomic.AddInt64(&j.statistics.SkippedTask, 1)
if j.deriveContext != nil {
ctx = j.deriveContext(ctx, task)
Comment on lines 78 to +82
}
Comment thread
wolfogre marked this conversation as resolved.

Comment thread
wolfogre marked this conversation as resolved.
if j.ctxBefore != nil {
if j.ctxBefore(ctx, task) {
task.Skipped = true
atomic.AddInt64(&j.statistics.SkippedTask, 1)
}
} else if j.before != nil {
if j.before(task) {
task.Skipped = true
atomic.AddInt64(&j.statistics.SkippedTask, 1)
}
}

if !task.Skipped {
Expand Down Expand Up @@ -128,7 +142,9 @@ func (j *innerJob) Run() {
}
}

if j.after != nil {
if j.ctxAfter != nil {
j.ctxAfter(ctx, task)
} else if j.after != nil {
j.after(task)
}

Expand Down
80 changes: 80 additions & 0 deletions inner_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ func Test_innerJob_Run(t *testing.T) {
entryGetter entryGetter
key string
spec string
deriveContext DeriveContext
ctxBefore BeforeContextFunc
before BeforeFunc
run RunFunc
ctxAfter AfterContextFunc
after AfterFunc
retryTimes int
retryInterval RetryInterval
Expand Down Expand Up @@ -329,6 +332,80 @@ func Test_innerJob_Run(t *testing.T) {
RetriedRun: 0,
},
},
{
name: "ctx_before_skip",
fields: fields{
cron: NewCron(WithAtomic(atomic)),
entryID: 1,
entryGetter: mockEntryGetter,
ctxBefore: func(ctx context.Context, task Task) (skip bool) {
return true
},
run: func(ctx context.Context) error {
return nil
},
ctxAfter: func(ctx context.Context, task Task) {
if !task.Skipped {
t.Fatal("expected task to be skipped")
}
},
retryTimes: 1,
},
statistics: Statistics{
TotalTask: 1,
SkippedTask: 1,
},
},
{
name: "ctx_before_no_skip",
fields: fields{
cron: NewCron(WithAtomic(atomic)),
entryID: 1,
entryGetter: mockEntryGetter,
ctxBefore: func(ctx context.Context, task Task) (skip bool) {
return false
},
run: func(ctx context.Context) error {
return nil
},
ctxAfter: func(ctx context.Context, task Task) {
if task.Return != nil {
t.Fatal(task.Return)
}
},
retryTimes: 1,
},
statistics: Statistics{
TotalTask: 1,
PassedTask: 1,
TotalRun: 1,
PassedRun: 1,
},
},
{
name: "derive_context",
fields: fields{
cron: NewCron(WithAtomic(atomic)),
entryID: 1,
entryGetter: mockEntryGetter,
deriveContext: func(ctx context.Context, task Task) context.Context {
return context.WithValue(ctx, ctxKey("test_key"), "test_value")
},
run: func(ctx context.Context) error {
if ctx.Value(ctxKey("test_key")) != "test_value" {
return errors.New("context value not found")
}
return nil
},
retryTimes: 1,
},
statistics: Statistics{
TotalTask: 1,
PassedTask: 1,
TotalRun: 1,
PassedRun: 1,
},
},
{
name: "panic by runtime",
fields: fields{
Expand Down Expand Up @@ -370,8 +447,11 @@ func Test_innerJob_Run(t *testing.T) {
entryGetter: tt.fields.entryGetter,
key: tt.fields.key,
spec: tt.fields.spec,
deriveContext: tt.fields.deriveContext,
ctxBefore: tt.fields.ctxBefore,
before: tt.fields.before,
run: tt.fields.run,
ctxAfter: tt.fields.ctxAfter,
after: tt.fields.after,
retryTimes: tt.fields.retryTimes,
retryInterval: tt.fields.retryInterval,
Expand Down
37 changes: 37 additions & 0 deletions job_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,58 @@ import (
type JobOption func(job *innerJob)

// BeforeFunc represents the function could be called before Run.
// Deprecated: Use BeforeContextFunc instead, it will be ignored if BeforeContextFunc is set.
type BeforeFunc func(task Task) (skip bool)

// BeforeContextFunc represents the function could be called before Run with the given Task.
type BeforeContextFunc func(ctx context.Context, task Task) (skip bool)

// RunFunc represents the function could be called by a cron.
type RunFunc func(ctx context.Context) error

// AfterFunc represents the function could be called after Run.
// Deprecated: Use AfterContextFunc instead, it will be ignored if AfterContextFunc is set.
type AfterFunc func(task Task)

// AfterContextFunc represents the function could be called after Run with the given Task.
type AfterContextFunc func(ctx context.Context, task Task)

// RetryInterval indicates how long should delay before retrying when run failed `triedTimes` times.
type RetryInterval func(triedTimes int) time.Duration

// DeriveContext indicates how to derive a new context from the job's base context and the current Task.
type DeriveContext func(ctx context.Context, task Task) context.Context

// WithBeforeFunc specifies what to do before Run.
// Deprecated: Use WithBeforeContextFunc instead.
func WithBeforeFunc(before BeforeFunc) JobOption {
return func(job *innerJob) {
job.before = before
}
}

// WithBeforeContextFunc specifies what to do before Run with the given Task.
func WithBeforeContextFunc(before BeforeContextFunc) JobOption {
return func(job *innerJob) {
job.ctxBefore = before
}
}
Comment thread
wolfogre marked this conversation as resolved.

// WithAfterFunc specifies what to do after Run.
// Deprecated: Use WithAfterContextFunc instead.
func WithAfterFunc(after AfterFunc) JobOption {
return func(job *innerJob) {
job.after = after
}
}

// WithAfterContextFunc specifies what to do after Run with the given Task.
func WithAfterContextFunc(after AfterContextFunc) JobOption {
return func(job *innerJob) {
job.ctxAfter = after
}
}

// WithRetryTimes specifies max times to retry,
// retryTimes will be set as 1 if it is less than 1.
func WithRetryTimes(retryTimes int) JobOption {
Expand Down Expand Up @@ -63,3 +90,13 @@ func WithGroup(group Group) JobOption {
job.group = group
}
}

// WithDeriveContext specifies how to derive a new context for the entire job execution, including
// before/after hooks, Run, and retry logic. The returned context must derive from the provided ctx
// to preserve the deadline, cancellation signal, and the embedded Task value (accessible via TaskFromContext).
// Returning a detached context (e.g., context.Background()) will break deadline enforcement and retry timeout logic.
func WithDeriveContext(deriveContext DeriveContext) JobOption {
return func(job *innerJob) {
job.deriveContext = deriveContext
}
}
108 changes: 105 additions & 3 deletions job_option_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package dcron

import (
"context"
"fmt"
"testing"
"time"
)

func TestWithAfterFunc(t *testing.T) {
after := func(task Task) {

}
after := func(task Task) {}

type args struct {
after AfterFunc
Expand Down Expand Up @@ -164,3 +163,106 @@ func TestWithNoMutex(t *testing.T) {
})
}
}

func TestWithBeforeContextFunc(t *testing.T) {
before := func(ctx context.Context, task Task) (skip bool) {
return false
}

type args struct {
before BeforeContextFunc
}
tests := []struct {
name string
args args
check func(t *testing.T, option JobOption)
}{
{
name: "regular",
args: args{
before: before,
},
check: func(t *testing.T, option JobOption) {
j := &innerJob{}
option(j)
if fmt.Sprintf("%p", j.ctxBefore) != fmt.Sprintf("%p", before) {
t.Fatal()
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := WithBeforeContextFunc(tt.args.before)
tt.check(t, got)
})
}
}

func TestWithAfterContextFunc(t *testing.T) {
after := func(ctx context.Context, task Task) {}

type args struct {
after AfterContextFunc
}
tests := []struct {
name string
args args
check func(t *testing.T, option JobOption)
}{
{
name: "regular",
args: args{
after: after,
},
check: func(t *testing.T, option JobOption) {
j := &innerJob{}
option(j)
if fmt.Sprintf("%p", j.ctxAfter) != fmt.Sprintf("%p", after) {
t.Fatal()
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := WithAfterContextFunc(tt.args.after)
tt.check(t, got)
})
}
}

func TestWithDeriveContext(t *testing.T) {
deriveContext := func(ctx context.Context, task Task) context.Context {
return ctx
}

type args struct {
deriveContext DeriveContext
}
tests := []struct {
name string
args args
check func(t *testing.T, option JobOption)
}{
{
name: "regular",
args: args{
deriveContext: deriveContext,
},
check: func(t *testing.T, option JobOption) {
j := &innerJob{}
option(j)
if fmt.Sprintf("%p", j.deriveContext) != fmt.Sprintf("%p", deriveContext) {
t.Fatal()
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := WithDeriveContext(tt.args.deriveContext)
tt.check(t, got)
})
}
}
Loading