diff --git a/options.go b/options.go index c509353b18..49cf378b59 100644 --- a/options.go +++ b/options.go @@ -19,7 +19,7 @@ type ProgramOption func(*Program) // cancelled it will exit with an error ErrProgramKilled. func WithContext(ctx context.Context) ProgramOption { return func(p *Program) { - p.ctx = ctx + p.externalCtx = ctx } } diff --git a/options_test.go b/options_test.go index ce8d41795a..60336c61b4 100644 --- a/options_test.go +++ b/options_test.go @@ -2,6 +2,7 @@ package tea import ( "bytes" + "context" "os" "sync/atomic" "testing" @@ -51,6 +52,16 @@ func TestOptions(t *testing.T) { } }) + t.Run("external context", func(t *testing.T) { + extCtx, extCancel := context.WithCancel(context.Background()) + defer extCancel() + + p := NewProgram(nil, WithContext(extCtx)) + if p.externalCtx != extCtx || p.externalCtx == context.Background() { + t.Errorf("expected passed in external context, got default (nil)") + } + }) + t.Run("input options", func(t *testing.T) { exercise := func(t *testing.T, opt ProgramOption, expect inputType) { p := NewProgram(nil, opt) diff --git a/tea.go b/tea.go index 0bc915e5b3..d48f832f4e 100644 --- a/tea.go +++ b/tea.go @@ -27,6 +27,9 @@ import ( "golang.org/x/sync/errgroup" ) +// ErrProgramPanic is returned by [Program.Run] when the program recovers from a panic. +var ErrProgramPanic = errors.New("program experienced a panic") + // ErrProgramKilled is returned by [Program.Run] when the program gets killed. var ErrProgramKilled = errors.New("program was killed") @@ -147,6 +150,12 @@ type Program struct { inputType inputType + // externalCtx is a context that was passed in via WithContext, otherwise defaulting + // to ctx.Background() (in case it was not), the internal context is derived from it. + externalCtx context.Context + + // ctx is the programs's internal context for signalling internal teardown. + // It is built and derived from the externalCtx in NewProgram(). ctx context.Context cancel context.CancelFunc @@ -243,11 +252,11 @@ func NewProgram(model Model, opts ...ProgramOption) *Program { // A context can be provided with a ProgramOption, but if none was provided // we'll use the default background context. - if p.ctx == nil { - p.ctx = context.Background() + if p.externalCtx == nil { + p.externalCtx = context.Background() } // Initialize context and teardown channel. - p.ctx, p.cancel = context.WithCancel(p.ctx) + p.ctx, p.cancel = context.WithCancel(p.externalCtx) // if no output was set, set it to stdout if p.output == nil { @@ -346,7 +355,11 @@ func (p *Program) handleCommands(cmds chan Cmd) chan struct{} { go func() { // Recover from panics. if !p.startupOptions.has(withoutCatchPanics) { - defer p.recoverFromPanic() + defer func() { + if r := recover(); r != nil { + p.recoverFromGoPanic(r) + } + }() } msg := cmd() // this can be long. @@ -460,7 +473,11 @@ func (p *Program) eventLoop(model Model, cmds chan Cmd) (Model, error) { case BatchMsg: for _, cmd := range msg { - cmds <- cmd + select { + case <-p.ctx.Done(): + return model, nil + case cmds <- cmd: + } } continue @@ -506,7 +523,13 @@ func (p *Program) eventLoop(model Model, cmds chan Cmd) (Model, error) { var cmd Cmd model, cmd = model.Update(msg) // run update - cmds <- cmd // process command (if any) + + select { + case <-p.ctx.Done(): + return model, nil + case cmds <- cmd: // process command (if any) + } + p.renderer.write(model.View()) // send view to renderer } } @@ -515,11 +538,15 @@ func (p *Program) eventLoop(model Model, cmds chan Cmd) (Model, error) { // Run initializes the program and runs its event loops, blocking until it gets // terminated by either [Program.Quit], [Program.Kill], or its signal handler. // Returns the final model. -func (p *Program) Run() (Model, error) { +func (p *Program) Run() (returnModel Model, returnErr error) { p.handlers = channelHandlers{} cmds := make(chan Cmd) - p.errs = make(chan error) - p.finished = make(chan struct{}, 1) + p.errs = make(chan error, 1) + + p.finished = make(chan struct{}) + defer func() { + close(p.finished) + }() defer p.cancel() @@ -568,7 +595,12 @@ func (p *Program) Run() (Model, error) { // Recover from panics. if !p.startupOptions.has(withoutCatchPanics) { - defer p.recoverFromPanic() + defer func() { + if r := recover(); r != nil { + returnErr = fmt.Errorf("%w: %w", ErrProgramKilled, ErrProgramPanic) + p.recoverFromPanic(r) + } + }() } // If no renderer is set use the standard one. @@ -645,11 +677,27 @@ func (p *Program) Run() (Model, error) { // Run event loop, handle updates and draw. model, err := p.eventLoop(model, cmds) - killed := p.ctx.Err() != nil || err != nil - if killed && err == nil { - err = fmt.Errorf("%w: %s", ErrProgramKilled, p.ctx.Err()) + + if err == nil && len(p.errs) > 0 { + err = <-p.errs // Drain a leftover error in case eventLoop crashed } - if err == nil { + + killed := p.externalCtx.Err() != nil || p.ctx.Err() != nil || err != nil + if killed { + if err == nil && p.externalCtx.Err() != nil { + // Return also as context error the cancellation of an external context. + // This is the context the user knows about and should be able to act on. + err = fmt.Errorf("%w: %w", ErrProgramKilled, p.externalCtx.Err()) + } else if err == nil && p.ctx.Err() != nil { + // Return only that the program was killed (not the internal mechanism). + // The user does not know or need to care about the internal program context. + err = ErrProgramKilled + } else { + // Return that the program was killed and also the error that caused it. + err = fmt.Errorf("%w: %w", ErrProgramKilled, err) + } + } else { + // Graceful shutdown of the program (not killed): // Ensure we rendered the final state of the model. p.renderer.write(model.View()) } @@ -704,11 +752,11 @@ func (p *Program) Quit() { p.Send(Quit()) } -// Kill stops the program immediately and restores the former terminal state. +// Kill signals the program to stop immediately and restore the former terminal state. // The final render that you would normally see when quitting will be skipped. // [program.Run] returns a [ErrProgramKilled] error. func (p *Program) Kill() { - p.shutdown(true) + p.cancel() } // Wait waits/blocks until the underlying Program finished shutting down. @@ -717,7 +765,11 @@ func (p *Program) Wait() { } // shutdown performs operations to free up resources and restore the terminal -// to its original state. +// to its original state. It is called once at the end of the program's lifetime. +// +// This method should not be called to signal the program to be killed/shutdown. +// Doing so can lead to race conditions with the eventual call at the program's end. +// As alternatives, the [Quit] or [Kill] convenience methods should be used instead. func (p *Program) shutdown(kill bool) { p.cancel() @@ -744,19 +796,30 @@ func (p *Program) shutdown(kill bool) { } _ = p.restoreTerminalState() - if !kill { - p.finished <- struct{}{} - } } // recoverFromPanic recovers from a panic, prints the stack trace, and restores // the terminal to a usable state. -func (p *Program) recoverFromPanic() { - if r := recover(); r != nil { - p.shutdown(true) - fmt.Printf("Caught panic:\n\n%s\n\nRestoring terminal...\n\n", r) - debug.PrintStack() +func (p *Program) recoverFromPanic(r interface{}) { + select { + case p.errs <- ErrProgramPanic: + default: } + p.shutdown(true) // Ok to call here, p.Run() cannot do it anymore. + fmt.Printf("Caught panic:\n\n%s\n\nRestoring terminal...\n\n", r) + debug.PrintStack() +} + +// recoverFromGoPanic recovers from a goroutine panic, prints a stack trace and +// signals for the program to be killed and terminal restored to a usable state. +func (p *Program) recoverFromGoPanic(r interface{}) { + select { + case p.errs <- ErrProgramPanic: + default: + } + p.cancel() + fmt.Printf("Caught goroutine panic:\n\n%s\n\nRestoring terminal...\n\n", r) + debug.PrintStack() } // ReleaseTerminal restores the original terminal state and cancels the input diff --git a/tea_test.go b/tea_test.go index 981851be0b..d4e29d064d 100644 --- a/tea_test.go +++ b/tea_test.go @@ -4,13 +4,20 @@ import ( "bytes" "context" "errors" + "sync" "sync/atomic" "testing" "time" ) +type ctxImplodeMsg struct { + cancel context.CancelFunc +} + type incrementMsg struct{} +type panicMsg struct{} + type testModel struct { executed atomic.Value counter atomic.Value @@ -21,7 +28,11 @@ func (m testModel) Init() Cmd { } func (m *testModel) Update(msg Msg) (Model, Cmd) { - switch msg.(type) { + switch msg := msg.(type) { + case ctxImplodeMsg: + msg.cancel() + time.Sleep(100 * time.Millisecond) + case incrementMsg: i := m.counter.Load() if i == nil { @@ -32,6 +43,9 @@ func (m *testModel) Update(msg Msg) (Model, Cmd) { case KeyMsg: return m, Quit + + case panicMsg: + panic("testing panic behavior") } return m, nil @@ -81,6 +95,106 @@ func TestTeaQuit(t *testing.T) { } } +func TestTeaWaitQuit(t *testing.T) { + var buf bytes.Buffer + var in bytes.Buffer + + progStarted := make(chan struct{}) + waitStarted := make(chan struct{}) + errChan := make(chan error, 1) + + m := &testModel{} + p := NewProgram(m, WithInput(&in), WithOutput(&buf)) + + go func() { + _, err := p.Run() + errChan <- err + }() + + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + close(progStarted) + + <-waitStarted + time.Sleep(50 * time.Millisecond) + p.Quit() + + return + } + } + }() + + <-progStarted + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + p.Wait() + wg.Done() + }() + } + close(waitStarted) + wg.Wait() + + err := <-errChan + if err != nil { + t.Fatalf("Expected nil, got %v", err) + } +} + +func TestTeaWaitKill(t *testing.T) { + var buf bytes.Buffer + var in bytes.Buffer + + progStarted := make(chan struct{}) + waitStarted := make(chan struct{}) + errChan := make(chan error, 1) + + m := &testModel{} + p := NewProgram(m, WithInput(&in), WithOutput(&buf)) + + go func() { + _, err := p.Run() + errChan <- err + }() + + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + close(progStarted) + + <-waitStarted + time.Sleep(50 * time.Millisecond) + p.Kill() + + return + } + } + }() + + <-progStarted + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + p.Wait() + wg.Done() + }() + } + close(waitStarted) + wg.Wait() + + err := <-errChan + if !errors.Is(err, ErrProgramKilled) { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } +} + func TestTeaWithFilter(t *testing.T) { testTeaWithFilter(t, 0) testTeaWithFilter(t, 1) @@ -138,9 +252,17 @@ func TestTeaKill(t *testing.T) { } }() - if _, err := p.Run(); !errors.Is(err, ErrProgramKilled) { + _, err := p.Run() + + if !errors.Is(err, ErrProgramKilled) { t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) } + + if errors.Is(err, context.Canceled) { + // The end user should not know about the program's internal context state. + // The program should only report external context cancellation as a context error. + t.Fatalf("Internal context cancellation was reported as context error!") + } } func TestTeaContext(t *testing.T) { @@ -160,6 +282,66 @@ func TestTeaContext(t *testing.T) { } }() + _, err := p.Run() + + if !errors.Is(err, ErrProgramKilled) { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } + + if !errors.Is(err, context.Canceled) { + // The end user should know that their passed in context caused the kill. + t.Fatalf("Expected %v, got %v", context.Canceled, err) + } +} + +func TestTeaContextImplodeDeadlock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var buf bytes.Buffer + var in bytes.Buffer + + m := &testModel{} + p := NewProgram(m, WithContext(ctx), WithInput(&in), WithOutput(&buf)) + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + p.Send(ctxImplodeMsg{cancel: cancel}) + return + } + } + }() + + if _, err := p.Run(); !errors.Is(err, ErrProgramKilled) { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } +} + +func TestTeaContextBatchDeadlock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var buf bytes.Buffer + var in bytes.Buffer + + inc := func() Msg { + cancel() + return incrementMsg{} + } + + m := &testModel{} + p := NewProgram(m, WithContext(ctx), WithInput(&in), WithOutput(&buf)) + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + batch := make(BatchMsg, 100) + for i := range batch { + batch[i] = inc + } + p.Send(batch) + return + } + } + }() + if _, err := p.Run(); !errors.Is(err, ErrProgramKilled) { t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) } @@ -267,3 +449,65 @@ func TestTeaNoRun(t *testing.T) { m := &testModel{} NewProgram(m, WithInput(&in), WithOutput(&buf)) } + +func TestTeaPanic(t *testing.T) { + var buf bytes.Buffer + var in bytes.Buffer + + m := &testModel{} + p := NewProgram(m, WithInput(&in), WithOutput(&buf)) + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + p.Send(panicMsg{}) + return + } + } + }() + + _, err := p.Run() + + if !errors.Is(err, ErrProgramPanic) { + t.Fatalf("Expected %v, got %v", ErrProgramPanic, err) + } + + if !errors.Is(err, ErrProgramKilled) { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } +} + +func TestTeaGoroutinePanic(t *testing.T) { + var buf bytes.Buffer + var in bytes.Buffer + + panicCmd := func() Msg { + panic("testing goroutine panic behavior") + } + + m := &testModel{} + p := NewProgram(m, WithInput(&in), WithOutput(&buf)) + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + batch := make(BatchMsg, 10) + for i := range batch { + batch[i] = panicCmd + } + p.Send(batch) + return + } + } + }() + + _, err := p.Run() + + if !errors.Is(err, ErrProgramPanic) { + t.Fatalf("Expected %v, got %v", ErrProgramPanic, err) + } + + if !errors.Is(err, ErrProgramKilled) { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } +}