Skip to content

Commit 0198bbf

Browse files
authored
RSDK-10618: Fix race contition in managedProcess (#432)
1 parent f2d4345 commit 0198bbf

File tree

3 files changed

+149
-108
lines changed

3 files changed

+149
-108
lines changed

pexec/managed_process.go

Lines changed: 144 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ import (
1818

1919
var errAlreadyStopped = errors.New("already stopped")
2020

21+
// UnexpectedExitHandler is the signature for functions that can optionally be
22+
// provided to run when a managed process unexpectedly exits. The return value
23+
// indicates whether pexec should continue with its own attempt to restart the
24+
// process: true means pexec will attempt its own restart, false means the
25+
// no restart will be attempted and the process will remain dead.
26+
type UnexpectedExitHandler = func(exitCode int) bool
27+
2128
// A ManagedProcess controls the lifecycle of a single system process. Based on
2229
// its configuration, it will ensure the process is revived if it every unexpectedly
2330
// perishes.
@@ -101,8 +108,7 @@ type managedProcess struct {
101108
shouldLog bool
102109
cmd *exec.Cmd
103110

104-
stopped bool
105-
onUnexpectedExit func(int) bool
111+
onUnexpectedExit UnexpectedExitHandler
106112
managingCh chan struct{}
107113
killCh chan struct{}
108114
stopSig syscall.Signal
@@ -159,7 +165,11 @@ func (p *managedProcess) validateCWD() error {
159165
func (p *managedProcess) Start(ctx context.Context) error {
160166
p.mu.Lock()
161167
defer p.mu.Unlock()
168+
return p.start(ctx)
169+
}
162170

171+
// This internal version of start must be called with the process lock held.
172+
func (p *managedProcess) start(ctx context.Context) error {
163173
// In the event this Start happened from a restart but a
164174
// stop happened while we were acquiring the lock, we may
165175
// need to return early.
@@ -268,85 +278,24 @@ func (p *managedProcess) manage(stdOut, stdErr io.ReadCloser) {
268278
}
269279
}()
270280

271-
// This block here logs as much as possible if it's requested until the
272-
// pipes are closed.
273-
stopLogging := make(chan struct{})
274-
var activeLoggers sync.WaitGroup
275-
if p.shouldLog || p.logWriter != nil {
276-
logPipe := func(name string, pipe io.ReadCloser, isErr bool, logger utils.ZapCompatibleLogger) {
277-
defer activeLoggers.Done()
278-
pipeR := bufio.NewReader(pipe)
279-
logWriterError := false
280-
for {
281-
select {
282-
case <-stopLogging:
283-
return
284-
default:
285-
}
286-
line, _, err := pipeR.ReadLine()
287-
if err != nil {
288-
if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
289-
p.logger.Errorw("error reading output", "name", name, "error", err)
290-
}
291-
return
292-
}
293-
if p.shouldLog {
294-
if isErr {
295-
logger.Error("\n\\_ " + string(line))
296-
} else {
297-
logger.Info("\n\\_ " + string(line))
298-
}
299-
}
300-
if p.logWriter != nil && !logWriterError {
301-
_, err := p.logWriter.Write(line)
302-
if err == nil {
303-
_, err = p.logWriter.Write([]byte("\n"))
304-
}
305-
if err != nil {
306-
if !errors.Is(err, io.ErrClosedPipe) {
307-
p.logger.Debugw("error writing process output to log writer", "name", name, "error", err)
308-
}
309-
if !p.shouldLog {
310-
return
311-
}
312-
logWriterError = true
313-
}
314-
}
315-
}
316-
}
317-
activeLoggers.Add(2)
318-
utils.PanicCapturingGo(func() {
319-
name := "StdOut"
320-
var logger utils.ZapCompatibleLogger
321-
if p.stdoutLogger != nil {
322-
logger = p.stdoutLogger
323-
} else {
324-
logger = utils.Sublogger(p.logger, name)
325-
}
326-
logPipe(name, stdOut, false, logger)
327-
})
328-
utils.PanicCapturingGo(func() {
329-
name := "StdErr"
330-
var logger utils.ZapCompatibleLogger
331-
if p.stderrLogger != nil {
332-
logger = p.stderrLogger
333-
} else {
334-
logger = utils.Sublogger(p.logger, name)
335-
}
336-
logPipe(name, stdErr, true, logger)
337-
})
338-
}
281+
stopAndDrainLoggers := p.startLoggers(stdOut, stdErr)
339282

340283
err := p.cmd.Wait()
341284
// This is safe to write to because it is only read in Stop which
342285
// is waiting for us to stop managing.
343-
if err == nil {
344-
p.lastWaitErr = nil
345-
} else {
346-
p.lastWaitErr = err
347-
}
348-
close(stopLogging)
349-
activeLoggers.Wait()
286+
p.lastWaitErr = err
287+
288+
stopAndDrainLoggers()
289+
290+
// Take the lock to prevent a race where a crashed process restarts even
291+
// though Stop is called.
292+
p.mu.Lock()
293+
locked := true
294+
defer func() {
295+
if locked {
296+
p.mu.Unlock()
297+
}
298+
}()
350299

351300
// It's possible that Stop was called and is the reason why Wait returned.
352301
select {
@@ -356,10 +305,23 @@ func (p *managedProcess) manage(stdOut, stdErr io.ReadCloser) {
356305
}
357306

358307
// Run onUnexpectedExit if it exists. Do not attempt restart if
359-
// onUnexpectedExit returns false.
360-
if p.onUnexpectedExit != nil &&
361-
!p.onUnexpectedExit(p.cmd.ProcessState.ExitCode()) {
362-
return
308+
// onUnexpectedExit returns false. Drop the lock to avoid deadlocking other
309+
// goroutines that my try to call Stop while we're handling a crash.
310+
if p.onUnexpectedExit != nil {
311+
p.mu.Unlock()
312+
locked = false
313+
if !p.onUnexpectedExit(p.cmd.ProcessState.ExitCode()) {
314+
return
315+
}
316+
p.mu.Lock()
317+
locked = true
318+
// Dropped the lock to call the oue handler, check if we were stopped during
319+
// that time.
320+
select {
321+
case <-p.killCh:
322+
return
323+
default:
324+
}
363325
}
364326

365327
// Otherwise, let's try restarting the process.
@@ -387,40 +349,123 @@ func (p *managedProcess) manage(stdOut, stdErr io.ReadCloser) {
387349
return
388350
}
389351

390-
err = p.Start(context.Background())
352+
// Use the internal version of start since we already hold the lock.
353+
err = p.start(context.Background())
391354
if err != nil {
392355
if !errors.Is(err, errAlreadyStopped) {
393-
// MAYBE(erd): add retry
394356
p.logger.Errorw("error restarting process", "error", err)
395357
}
396358
return
397359
}
398360
restarted = true
399361
}
400362

363+
// This helper function is only meant to be called from manage. If logging is
364+
// enabled it creates goroutines that log as much as possible until the pipes
365+
// are closed. It returns a function that terminates logging and blocks until
366+
// the loggers drain.
367+
func (p *managedProcess) startLoggers(stdOut, stdErr io.ReadCloser) func() {
368+
if !p.shouldLog && p.logWriter == nil {
369+
// No logging enabled, return a noop func so the caller can unconditionally
370+
// invoke it.
371+
return func() {}
372+
}
373+
374+
stopLogging := make(chan struct{})
375+
var activeLoggers sync.WaitGroup
376+
activeLoggers.Add(2)
377+
logPipe := func(name string, pipe io.ReadCloser, isErr bool, logger utils.ZapCompatibleLogger) {
378+
defer activeLoggers.Done()
379+
pipeR := bufio.NewReader(pipe)
380+
logWriterError := false
381+
for {
382+
select {
383+
case <-stopLogging:
384+
return
385+
default:
386+
}
387+
line, _, err := pipeR.ReadLine()
388+
if err != nil {
389+
if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
390+
p.logger.Errorw("error reading output", "name", name, "error", err)
391+
}
392+
return
393+
}
394+
if p.shouldLog {
395+
if isErr {
396+
logger.Error("\n\\_ " + string(line))
397+
} else {
398+
logger.Info("\n\\_ " + string(line))
399+
}
400+
}
401+
if p.logWriter != nil && !logWriterError {
402+
_, err := p.logWriter.Write(line)
403+
if err == nil {
404+
_, err = p.logWriter.Write([]byte("\n"))
405+
}
406+
if err != nil {
407+
if !errors.Is(err, io.ErrClosedPipe) {
408+
p.logger.Debugw("error writing process output to log writer", "name", name, "error", err)
409+
}
410+
if !p.shouldLog {
411+
return
412+
}
413+
logWriterError = true
414+
}
415+
}
416+
}
417+
}
418+
419+
utils.PanicCapturingGo(func() {
420+
name := "StdOut"
421+
var logger utils.ZapCompatibleLogger
422+
if p.stdoutLogger != nil {
423+
logger = p.stdoutLogger
424+
} else {
425+
logger = utils.Sublogger(p.logger, name)
426+
}
427+
logPipe(name, stdOut, false, logger)
428+
})
429+
utils.PanicCapturingGo(func() {
430+
name := "StdErr"
431+
var logger utils.ZapCompatibleLogger
432+
if p.stderrLogger != nil {
433+
logger = p.stderrLogger
434+
} else {
435+
logger = utils.Sublogger(p.logger, name)
436+
}
437+
logPipe(name, stdErr, true, logger)
438+
})
439+
440+
return func() {
441+
close(stopLogging)
442+
activeLoggers.Wait()
443+
}
444+
}
445+
401446
func (p *managedProcess) Stop() error {
402-
// Minimally hold a lock here so that we can signal the
403-
// management goroutine to stop. If we were to hold the
404-
// lock for the duration of the function, we would possibly
405-
// deadlock with manage trying to restart.
406447
p.mu.Lock()
407-
if p.stopped {
448+
449+
// Return early if the process has already been killed.
450+
select {
451+
case <-p.killCh:
408452
p.mu.Unlock()
453+
<-p.managingCh
409454
return nil
455+
default:
410456
}
457+
411458
close(p.killCh)
412-
p.stopped = true
413459

460+
// All relevant methods wait on the lock we hold and will not attempt to
461+
// (re)start the process now that we closed killch, so we can safely drop the
462+
// lock to unblock other calls while we proceed with shutown.
463+
p.mu.Unlock()
464+
465+
// Nothing to do.
414466
if p.cmd == nil {
415-
p.mu.Unlock()
416467
return nil
417468
}
418-
p.mu.Unlock()
419-
420-
// Since p.cmd is mutex guarded and we just signaled the manage
421-
// goroutine to stop, no new Start can happen and therefore
422-
// p.cmd can no longer be modified rendering it safe to read
423-
// without a lock held.
424469

425470
forceKilled, err := p.kill()
426471
if err != nil {
@@ -460,9 +505,10 @@ func (p *managedProcess) KillGroup() {
460505
// management goroutine to stop. We will attempt to kill the
461506
// process even if p.stopped is true.
462507
p.mu.Lock()
463-
if !p.stopped {
508+
select {
509+
case <-p.killCh:
510+
default:
464511
close(p.killCh)
465-
p.stopped = true
466512
}
467513

468514
if p.cmd == nil {

pexec/process.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type ProcessConfig struct {
4545
//
4646
// NOTE(benjirewis): use `jsonschema:"-"` struct tag to avoid issues with
4747
// jsonschema reflection (go functions cannot be encoded to JSON).
48-
OnUnexpectedExit func(int) bool `jsonschema:"-"`
48+
OnUnexpectedExit UnexpectedExitHandler `jsonschema:"-"`
4949
// The logger to use for STDOUT of this process. If not specified, will use
5050
// a sublogger of the `logger` parameter given to `NewManagedProcess`.
5151
StdOutLogger utils.ZapCompatibleLogger

runtime.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ func PanicCapturingGoWithCallback(f func(), callback func(err interface{})) {
143143
go func() {
144144
defer func() {
145145
if err := recover(); err != nil {
146-
debug.PrintStack()
147146
golog.Global().Errorw("panic while running function", "error", err)
147+
debug.PrintStack()
148148
if callback == nil {
149149
return
150150
}
@@ -161,15 +161,10 @@ func PanicCapturingGoWithCallback(f func(), callback func(err interface{})) {
161161
// it terminates normally.
162162
func ManagedGo(f, onComplete func()) {
163163
PanicCapturingGoWithCallback(func() {
164-
defer func() {
165-
if err := recover(); err == nil && onComplete != nil {
166-
onComplete()
167-
} else if err != nil {
168-
// re-panic
169-
panic(err)
170-
}
171-
}()
172164
f()
165+
if onComplete != nil {
166+
onComplete()
167+
}
173168
}, func(_ interface{}) {
174169
ManagedGo(f, onComplete)
175170
})

0 commit comments

Comments
 (0)