Skip to content

Commit e9dd80f

Browse files
authored
Call OnFatalError for workers using Start (#823)
Fixes #822
1 parent 17c0144 commit e9dd80f

File tree

3 files changed

+79
-43
lines changed

3 files changed

+79
-43
lines changed

internal/internal_worker.go

+39-22
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ type (
192192
// WorkerStopChannel is a read only channel listen on worker close. The worker will close the channel before exit.
193193
WorkerStopChannel <-chan struct{}
194194

195-
// WorkerFatalErrorChannel is a channel for fatal errors that should stop
196-
// the worker. This is sent to asynchronously, so it should be buffered.
197-
WorkerFatalErrorChannel chan<- error
195+
// WorkerFatalErrorCallback is a callback for fatal errors that should stop
196+
// the worker.
197+
WorkerFatalErrorCallback func(error)
198198

199199
// SessionResourceID is a unique identifier of the resource the session will consume
200200
SessionResourceID string
@@ -288,7 +288,7 @@ func newWorkflowTaskWorkerInternal(
288288
identity: params.Identity,
289289
workerType: "WorkflowWorker",
290290
stopTimeout: params.WorkerStopTimeout,
291-
fatalErrCh: params.WorkerFatalErrorChannel},
291+
fatalErrCb: params.WorkerFatalErrorCallback},
292292
params.Logger,
293293
params.MetricsHandler,
294294
nil,
@@ -312,7 +312,7 @@ func newWorkflowTaskWorkerInternal(
312312
identity: params.Identity,
313313
workerType: "LocalActivityWorker",
314314
stopTimeout: params.WorkerStopTimeout,
315-
fatalErrCh: params.WorkerFatalErrorChannel},
315+
fatalErrCb: params.WorkerFatalErrorCallback},
316316
params.Logger,
317317
params.MetricsHandler,
318318
nil,
@@ -428,7 +428,7 @@ func newActivityTaskWorker(taskHandler ActivityTaskHandler, service workflowserv
428428
identity: workerParams.Identity,
429429
workerType: "ActivityWorker",
430430
stopTimeout: workerParams.WorkerStopTimeout,
431-
fatalErrCh: workerParams.WorkerFatalErrorChannel,
431+
fatalErrCb: workerParams.WorkerFatalErrorCallback,
432432
userContextCancel: workerParams.UserContextCancel},
433433
workerParams.Logger,
434434
workerParams.MetricsHandler,
@@ -856,8 +856,8 @@ type AggregatedWorker struct {
856856
logger log.Logger
857857
registry *registry
858858
stopC chan struct{}
859-
fatalErrCh chan error
860-
fatalErrCb func(error)
859+
fatalErr error
860+
fatalErrLock sync.Mutex
861861
}
862862

863863
// RegisterWorkflow registers workflow implementation with the AggregatedWorker
@@ -1026,14 +1026,11 @@ func (aw *AggregatedWorker) Run(interruptCh <-chan interface{}) error {
10261026
case s := <-interruptCh:
10271027
aw.logger.Info("Worker has been stopped.", "Signal", s)
10281028
aw.Stop()
1029-
case err := <-aw.fatalErrCh:
1030-
// Fatal error will already have been logged where it is set
1031-
if aw.fatalErrCb != nil {
1032-
aw.fatalErrCb(err)
1033-
}
1034-
aw.Stop()
1035-
return err
10361029
case <-aw.stopC:
1030+
aw.fatalErrLock.Lock()
1031+
defer aw.fatalErrLock.Unlock()
1032+
// This may be nil if this wasn't stopped due to fatal error
1033+
return aw.fatalErr
10371034
}
10381035
return nil
10391036
}
@@ -1311,9 +1308,30 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
13111308
panic("cannot set MaxConcurrentWorkflowTaskPollers to 1")
13121309
}
13131310

1314-
// We need this buffered since the sender will be sending async and we only
1315-
// need the first fatal error
1316-
fatalErrCh := make(chan error, 1)
1311+
// Need reference to result for fatal error handler
1312+
var aw *AggregatedWorker
1313+
fatalErrorCallback := func(err error) {
1314+
// Set the fatal error if not already set
1315+
aw.fatalErrLock.Lock()
1316+
alreadySet := aw.fatalErr != nil
1317+
if !alreadySet {
1318+
aw.fatalErr = err
1319+
}
1320+
aw.fatalErrLock.Unlock()
1321+
// Only do the rest if not already set
1322+
if !alreadySet {
1323+
// Invoke the callback if present
1324+
if options.OnFatalError != nil {
1325+
options.OnFatalError(err)
1326+
}
1327+
// Stop the worker if not already stopped
1328+
select {
1329+
case <-aw.stopC:
1330+
default:
1331+
aw.Stop()
1332+
}
1333+
}
1334+
}
13171335

13181336
cache := NewWorkerCache()
13191337
workerParams := workerExecutionParameters{
@@ -1337,7 +1355,7 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
13371355
WorkflowPanicPolicy: options.WorkflowPanicPolicy,
13381356
DataConverter: client.dataConverter,
13391357
WorkerStopTimeout: options.WorkerStopTimeout,
1340-
WorkerFatalErrorChannel: fatalErrCh,
1358+
WorkerFatalErrorCallback: fatalErrorCallback,
13411359
ContextPropagators: client.contextPropagators,
13421360
DeadlockDetectionTimeout: options.DeadlockDetectionTimeout,
13431361
DefaultHeartbeatThrottleInterval: options.DefaultHeartbeatThrottleInterval,
@@ -1393,17 +1411,16 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
13931411
})
13941412
}
13951413

1396-
return &AggregatedWorker{
1414+
aw = &AggregatedWorker{
13971415
client: client,
13981416
workflowWorker: workflowWorker,
13991417
activityWorker: activityWorker,
14001418
sessionWorker: sessionWorker,
14011419
logger: workerParams.Logger,
14021420
registry: registry,
14031421
stopC: make(chan struct{}),
1404-
fatalErrCh: fatalErrCh,
1405-
fatalErrCb: options.OnFatalError,
14061422
}
1423+
return aw
14071424
}
14081425

14091426
func processTestTags(wOptions *WorkerOptions, ep *workerExecutionParameters) {

internal/internal_worker_base.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ type (
148148
identity string
149149
workerType string
150150
stopTimeout time.Duration
151-
fatalErrCh chan<- error
151+
fatalErrCb func(error)
152152
userContextCancel context.CancelFunc
153153
}
154154

@@ -172,7 +172,7 @@ type (
172172

173173
pollerRequestCh chan struct{}
174174
taskQueueCh chan interface{}
175-
fatalErrCh chan<- error
175+
fatalErrCb func(error)
176176
sessionTokenBucket *sessionTokenBucket
177177

178178
lastPollTaskErrMessage string
@@ -214,7 +214,7 @@ func newBaseWorker(
214214
taskSlotsAvailable: int32(options.maxConcurrentTask),
215215
pollerRequestCh: make(chan struct{}, options.maxConcurrentTask),
216216
taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched.
217-
fatalErrCh: options.fatalErrCh,
217+
fatalErrCb: options.fatalErrCb,
218218

219219
limiterContext: ctx,
220220
limiterContextCancel: cancel,
@@ -317,10 +317,8 @@ func (bw *baseWorker) pollTask() {
317317
if err != nil {
318318
if isNonRetriableError(err) {
319319
bw.logger.Error("Worker received non-retriable error. Shutting down.", tagError, err)
320-
// Set the error and assume it is buffered with room
321-
select {
322-
case bw.fatalErrCh <- err:
323-
default:
320+
if bw.fatalErrCb != nil {
321+
bw.fatalErrCb(err)
324322
}
325323
return
326324
}

test/integration_test.go

+35-14
Original file line numberDiff line numberDiff line change
@@ -2157,7 +2157,15 @@ func (ts *IntegrationTestSuite) TestLargeHistoryReplay() {
21572157
ts.Contains(err.Error(), "intentional panic")
21582158
}
21592159

2160-
func (ts *IntegrationTestSuite) TestWorkerFatalError() {
2160+
func (ts *IntegrationTestSuite) TestWorkerFatalErrorOnRun() {
2161+
ts.testWorkerFatalError(true)
2162+
}
2163+
2164+
func (ts *IntegrationTestSuite) TestWorkerFatalErrorOnStart() {
2165+
ts.testWorkerFatalError(false)
2166+
}
2167+
2168+
func (ts *IntegrationTestSuite) testWorkerFatalError(useWorkerRun bool) {
21612169
// Make a new client that will fail a poll with a namespace not found
21622170
c, err := client.Dial(client.Options{
21632171
HostPort: ts.config.ServiceAddr,
@@ -2175,6 +2183,8 @@ func (ts *IntegrationTestSuite) TestWorkerFatalError() {
21752183
opts ...grpc.CallOption,
21762184
) error {
21772185
if method == "/temporal.api.workflowservice.v1.WorkflowService/PollWorkflowTaskQueue" {
2186+
// We sleep a bit to let all internal workers start
2187+
time.Sleep(1 * time.Second)
21782188
return serviceerror.NewNamespaceNotFound(ts.config.Namespace)
21792189
}
21802190
return invoker(ctx, method, req, reply, cc, opts...)
@@ -2186,22 +2196,33 @@ func (ts *IntegrationTestSuite) TestWorkerFatalError() {
21862196
defer c.Close()
21872197

21882198
// Create a worker that uses that client
2189-
var lastErr error
2190-
w := worker.New(c, "ignored-task-queue", worker.Options{OnFatalError: func(err error) { lastErr = err }})
2199+
callbackErrCh := make(chan error, 1)
2200+
w := worker.New(c, "ignored-task-queue", worker.Options{OnFatalError: func(err error) { callbackErrCh <- err }})
2201+
2202+
// Do run-based or start-based worker
21912203
runErrCh := make(chan error, 1)
2204+
if useWorkerRun {
2205+
go func() { runErrCh <- w.Run(nil) }()
2206+
} else {
2207+
ts.NoError(w.Start())
2208+
}
21922209

2193-
// Run it and confirm it fails
2194-
go func() { runErrCh <- w.Run(nil) }()
2195-
var runErr error
2196-
select {
2197-
case <-time.After(10 * time.Second):
2198-
ts.Fail("timeout")
2199-
case runErr = <-runErrCh:
2210+
// Wait for done
2211+
var callbackErr, runErr error
2212+
for callbackErr == nil || (useWorkerRun && runErr == nil) {
2213+
select {
2214+
case <-time.After(10 * time.Second):
2215+
ts.Fail("timeout")
2216+
case callbackErr = <-callbackErrCh:
2217+
case runErr = <-runErrCh:
2218+
}
2219+
}
2220+
2221+
// Check error
2222+
ts.IsType(&serviceerror.NamespaceNotFound{}, callbackErr)
2223+
if runErr != nil {
2224+
ts.Equal(callbackErr, runErr)
22002225
}
2201-
ts.Error(lastErr)
2202-
ts.Error(runErr)
2203-
ts.Equal(lastErr, runErr)
2204-
ts.IsType(&serviceerror.NamespaceNotFound{}, runErr)
22052226
}
22062227

22072228
func (ts *IntegrationTestSuite) TestNonDeterminismFailureCauseBadStateMachine() {

0 commit comments

Comments
 (0)