Skip to content

Commit 7adc026

Browse files
authored
feat: Add T.Cleanup (#79)
* feat: Add T.Cleanup Adds a new T.Cleanup method to register cleanup functions. These will run after each iteration of a Check. The implementation of how cleanup functions are tracked and run is borrowed heavily from testing.T's own implementation. Namely: - [T.Cleanup puts functions in a slice](https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/testing/testing.go;l=1214-1216) - [cleanups are run in reverse order](https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/testing/testing.go;l=1433-1446) popping functions from the slice one by one - [panics are not allowed to interrupt cleanup](https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/testing/testing.go;l=1420-1427) As with `testing.T`, cleanup runs after the context is canceled. Because Rapid's context is lazily initialized, it needs an additional check to avoid providing a valid context during cleanup. Resolves #62 * doc: try to be clearer
1 parent c3c5e3c commit 7adc026

File tree

5 files changed

+277
-5
lines changed

5 files changed

+277
-5
lines changed

combinators_external_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,28 @@ func TestCustomContext(t *testing.T) {
9595
})
9696
}
9797

98+
func TestCustomCleanup(t *testing.T) {
99+
t.Parallel()
100+
101+
var open bool
102+
gen := Custom(func(t *T) int {
103+
t.Cleanup(func() { open = false })
104+
return Int().Draw(t, "")
105+
})
106+
107+
// Cleanup functions registered during a Custom generator
108+
// are run after generation, not after the Check.
109+
Check(t, func(t *T) {
110+
open = true
111+
_ = gen.Draw(t, "value")
112+
113+
// Cleanup must run after each run of the custom generator.
114+
if open {
115+
t.Fatalf("cleanup must be run")
116+
}
117+
})
118+
}
119+
98120
func TestFilter(t *testing.T) {
99121
t.Parallel()
100122

engine.go

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"runtime"
2020
"strings"
2121
"sync"
22+
"sync/atomic"
2223
"testing"
2324
"time"
2425
)
@@ -506,6 +507,8 @@ type T struct {
506507

507508
ctx context.Context
508509
cancelCtx context.CancelFunc
510+
cleanups []func()
511+
cleaning atomic.Bool
509512

510513
tbLog bool
511514
rawLog *log.Logger
@@ -545,8 +548,17 @@ func (t *T) shouldLog() bool {
545548
return t.rawLog != nil || t.tbLog
546549
}
547550

548-
// Context returns a context.Context associated with the test.
549-
// It is valid only for the duration of the rapid check.
551+
// Context returns a context.Context that is canceled
552+
// after the property function exits,
553+
// before Cleanup-registered functions are run.
554+
//
555+
// For [Check], [MakeFuzz], and similar functions,
556+
// each call to the property function gets a unique context
557+
// that is canceled after that property function exits.
558+
//
559+
// For [Custom], each time a new value is generated,
560+
// the generator function gets a unique context
561+
// that is canceled after the generator function exits.
550562
func (t *T) Context() context.Context {
551563
// Fast path: no need to lock if the context is already set.
552564
t.mu.RLock()
@@ -556,6 +568,15 @@ func (t *T) Context() context.Context {
556568
return ctx
557569
}
558570

571+
// If we're in the middle of cleaning up
572+
// and the context has already been canceled and cleared,
573+
// don't create a new one. Return a canceled context instead.
574+
if t.cleaning.Load() {
575+
ctx, cancel := context.WithCancel(context.Background())
576+
cancel()
577+
return ctx
578+
}
579+
559580
// Slow path: lock and check again, create new context if needed.
560581
t.mu.Lock()
561582
defer t.mu.Unlock()
@@ -582,17 +603,71 @@ func (t *T) Context() context.Context {
582603
return ctx
583604
}
584605

606+
// Cleanup registers a function to be called
607+
// when a property function finishes running.
608+
//
609+
// For [Check], [MakeFuzz], and similar functions,
610+
// each call to the property function registers its cleanup functions,
611+
// which are called after the property function exits.
612+
//
613+
// For [Custom], each time a new value is generated,
614+
// the generator function registers its cleanup functions,
615+
// which are called after the generator function exits.
616+
//
617+
// Cleanup functions are called in last-in, first-out order.
618+
//
619+
// If [T.Context] is used, the context is canceled
620+
// before the Cleanup functions are executed.
621+
func (t *T) Cleanup(f func()) {
622+
t.mu.Lock()
623+
defer t.mu.Unlock()
624+
625+
t.cleanups = append(t.cleanups, f)
626+
}
627+
585628
// cleanup runs any cleanup tasks associated with the property check.
586629
// It is safe to call multiple times.
587630
func (t *T) cleanup() {
588-
t.mu.Lock()
589-
defer t.mu.Unlock()
631+
t.cleaning.Store(true)
632+
defer t.cleaning.Store(false)
633+
634+
// If a cleanup function panics,
635+
// we still want to run the remaining cleanup functions.
636+
defer func() {
637+
t.mu.Lock()
638+
recurse := len(t.cleanups) > 0
639+
t.mu.Unlock()
640+
641+
if recurse {
642+
t.cleanup()
643+
}
644+
}()
590645

646+
// Context must be closed before t.Cleanup functions are run.
647+
t.mu.Lock()
591648
if t.cancelCtx != nil {
592649
t.cancelCtx()
593650
t.cancelCtx = nil
594651
t.ctx = nil
595652
}
653+
t.mu.Unlock()
654+
655+
for {
656+
var cleanup func()
657+
t.mu.Lock()
658+
if len(t.cleanups) > 0 {
659+
last := len(t.cleanups) - 1
660+
cleanup = t.cleanups[last]
661+
t.cleanups = t.cleanups[:last]
662+
}
663+
t.mu.Unlock()
664+
665+
if cleanup == nil {
666+
break
667+
}
668+
669+
cleanup()
670+
}
596671
}
597672

598673
func (t *T) Logf(format string, args ...any) {

engine_fuzz_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,20 @@ func FuzzContext(f *testing.F) {
9292
}
9393
}
9494
}
95+
96+
func FuzzCleanup(f *testing.F) {
97+
var state []bool
98+
f.Fuzz(MakeFuzz(func(t *T) {
99+
idx := len(state)
100+
state = append(state, false)
101+
t.Cleanup(func() {
102+
state[idx] = true
103+
})
104+
}))
105+
106+
for _, ok := range state {
107+
if !ok {
108+
f.Fatalf("cleanup must be called")
109+
}
110+
}
111+
}

engine_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package rapid
99
import (
1010
"context"
1111
"errors"
12+
"reflect"
1213
"strings"
1314
"testing"
1415
)
@@ -97,6 +98,8 @@ func BenchmarkCheckOverhead(b *testing.B) {
9798
}
9899

99100
func TestCheckContext(t *testing.T) {
101+
t.Parallel()
102+
100103
type key struct{}
101104

102105
var ctx context.Context
@@ -115,3 +118,144 @@ func TestCheckContext(t *testing.T) {
115118
t.Fatalf("context must have a value")
116119
}
117120
}
121+
122+
func TestCheckCleanup(t *testing.T) {
123+
t.Parallel()
124+
125+
// Each Check iteration will append a true to indicate "open",
126+
// and flip it to false on cleanup.
127+
//
128+
// After the check is done, we expect all values to be false.
129+
var state []bool
130+
131+
Check(t, func(t *T) {
132+
idx := len(state)
133+
state = append(state, true)
134+
t.Cleanup(func() {
135+
state[idx] = false
136+
})
137+
})
138+
139+
for _, v := range state {
140+
if v {
141+
t.Fatalf("expected all values to be false")
142+
}
143+
}
144+
}
145+
146+
func TestCheckCleanupMultipleOrder(t *testing.T) {
147+
t.Parallel()
148+
149+
// If multiple cleanups are appended during a Check,
150+
// they must run in reverse order.
151+
var state []int
152+
Check(t, func(t *T) {
153+
// We just want to capture the result of one iteration,
154+
// so we'll keep resetting the state.
155+
state = nil
156+
t.Cleanup(func() {
157+
state = append(state, 1)
158+
})
159+
t.Cleanup(func() {
160+
state = append(state, 2)
161+
})
162+
t.Cleanup(func() {
163+
state = append(state, 3)
164+
})
165+
})
166+
167+
if !reflect.DeepEqual(state, []int{3, 2, 1}) {
168+
t.Fatalf("expected cleanups to run in reverse order, got: %v", state)
169+
}
170+
}
171+
172+
func TestCheckCleanupPanic(t *testing.T) {
173+
t.Parallel()
174+
175+
// A Cleanup function halfway through will panic.
176+
// Deferred assertions will check that all values are false.
177+
var state []bool
178+
defer func() {
179+
for _, v := range state {
180+
if v {
181+
t.Errorf("expected all values to be false")
182+
}
183+
}
184+
}()
185+
186+
Check(ignoreErrorsTB{t}, func(t *T) {
187+
idx := len(state)
188+
state = append(state, true)
189+
t.Cleanup(func() {
190+
state[idx] = false
191+
if idx == len(state)/2 {
192+
panic("cleanup panic")
193+
}
194+
})
195+
})
196+
}
197+
198+
func TestCheckCleanupNewCleanupsDuringCleanup(t *testing.T) {
199+
t.Parallel()
200+
201+
// Cleanups can be added during cleanup.
202+
var state []bool
203+
Check(t, func(t *T) {
204+
idx := len(state)
205+
state = append(state, true)
206+
t.Cleanup(func() {
207+
// Odd numbered events will add a new cleanup.
208+
if idx%2 == 0 {
209+
state[idx] = false
210+
} else {
211+
t.Cleanup(func() {
212+
state[idx] = false
213+
})
214+
}
215+
})
216+
})
217+
}
218+
219+
func TestCheckCleanupContextIsCanceled(t *testing.T) {
220+
t.Parallel()
221+
222+
// Context created during Check is canceled by the time Cleanup is run.
223+
Check(t, func(t *T) {
224+
ctx := t.Context()
225+
t.Cleanup(func() {
226+
if err := ctx.Err(); err == nil || !errors.Is(err, context.Canceled) {
227+
t.Fatalf("expected context to be canceled, got: %v", ctx)
228+
}
229+
})
230+
})
231+
}
232+
233+
func TestCheckCleanupContextCreatedInCleanup(t *testing.T) {
234+
t.Parallel()
235+
236+
// Context created during Cleanup is already canceled.
237+
Check(t, func(t *T) {
238+
ctx := t.Context()
239+
t.Cleanup(func() {
240+
// ctx is already cleared on rapid.T by now,
241+
// so this will request a new context.
242+
newCtx := t.Context()
243+
if ctx == newCtx {
244+
t.Fatalf("expected new context")
245+
}
246+
247+
if err := newCtx.Err(); err == nil || !errors.Is(err, context.Canceled) {
248+
t.Fatalf("expected context to be canceled, got: %v", newCtx)
249+
}
250+
})
251+
})
252+
}
253+
254+
// ignoreErrorsTB is a TB that ignores all errors posted to it.
255+
type ignoreErrorsTB struct{ TB }
256+
257+
func (ignoreErrorsTB) Error(...interface{}) {}
258+
func (ignoreErrorsTB) Errorf(string, ...interface{}) {}
259+
func (ignoreErrorsTB) Fatal(...interface{}) {}
260+
func (ignoreErrorsTB) Fatalf(string, ...interface{}) {}
261+
func (ignoreErrorsTB) Fail() {}

generator_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestExampleHelper(t *testing.T) {
4646
g.Example(0)
4747
}
4848

49-
func TestExampleContext(t *testing.T) {
49+
func TestCustomExampleContext(t *testing.T) {
5050
type key struct{}
5151

5252
g := Custom(func(t *T) context.Context {
@@ -67,3 +67,17 @@ func TestExampleContext(t *testing.T) {
6767
t.Fatalf("context must have a value")
6868
}
6969
}
70+
71+
func TestCustomExampleCleanup(t *testing.T) {
72+
var state bool
73+
g := Custom(func(t *T) int {
74+
t.Cleanup(func() { state = false })
75+
return Int().Draw(t, "x")
76+
})
77+
78+
state = true
79+
_ = g.Example(0)
80+
if state {
81+
t.Fatalf("cleanup must be called")
82+
}
83+
}

0 commit comments

Comments
 (0)