66package singleflight_test
77
88import (
9+ "bytes"
910 "context"
1011 "errors"
12+ "fmt"
13+ "runtime/pprof"
1114 "strconv"
15+ "strings"
1216 "sync"
1317 "sync/atomic"
1418 "testing"
@@ -21,7 +25,7 @@ func TestDo(t *testing.T) {
2125 var g singleflight.Group
2226
2327 want := "val"
24- got , shared , err := g .Do (context .Background (), "key" , func () (interface {}, error ) {
28+ got , shared , err := g .Do (context .Background (), "key" , func (_ context. Context ) (interface {}, error ) {
2529 return want , nil
2630 })
2731 if err != nil {
@@ -38,7 +42,7 @@ func TestDo(t *testing.T) {
3842func TestDo_error (t * testing.T ) {
3943 var g singleflight.Group
4044 wantErr := errors .New ("test error" )
41- got , _ , err := g .Do (context .Background (), "key" , func () (interface {}, error ) {
45+ got , _ , err := g .Do (context .Background (), "key" , func (_ context. Context ) (interface {}, error ) {
4246 return nil , wantErr
4347 })
4448 if err != wantErr {
@@ -64,7 +68,7 @@ func TestDo_multipleCalls(t *testing.T) {
6468 for i := 0 ; i < n ; i ++ {
6569 go func (i int ) {
6670 defer wg .Done ()
67- got [i ], shared [i ], err [i ] = g .Do (context .Background (), "key" , func () (interface {}, error ) {
71+ got [i ], shared [i ], err [i ] = g .Do (context .Background (), "key" , func (_ context. Context ) (interface {}, error ) {
6872 atomic .AddInt32 (& counter , 1 )
6973 time .Sleep (100 * time .Millisecond )
7074 return want , nil
@@ -95,7 +99,7 @@ func TestDo_callRemoval(t *testing.T) {
9599
96100 wantPrefix := "val"
97101 counter := 0
98- fn := func () (interface {}, error ) {
102+ fn := func (_ context. Context ) (interface {}, error ) {
99103 counter ++
100104 return wantPrefix + strconv .Itoa (counter ), nil
101105 }
@@ -124,6 +128,9 @@ func TestDo_callRemoval(t *testing.T) {
124128}
125129
126130func TestDo_cancelContext (t * testing.T ) {
131+ done := make (chan struct {})
132+ defer close (done )
133+
127134 var g singleflight.Group
128135
129136 want := "val"
@@ -133,8 +140,11 @@ func TestDo_cancelContext(t *testing.T) {
133140 cancel ()
134141 }()
135142 start := time .Now ()
136- got , shared , err := g .Do (ctx , "key" , func () (interface {}, error ) {
137- time .Sleep (time .Second )
143+ got , shared , err := g .Do (ctx , "key" , func (_ context.Context ) (interface {}, error ) {
144+ select {
145+ case <- time .After (time .Second ):
146+ case <- done :
147+ }
138148 return want , nil
139149 })
140150 if d := time .Since (start ); d < 100 * time .Microsecond || d > time .Second {
@@ -152,11 +162,17 @@ func TestDo_cancelContext(t *testing.T) {
152162}
153163
154164func TestDo_cancelContextSecond (t * testing.T ) {
165+ done := make (chan struct {})
166+ defer close (done )
167+
155168 var g singleflight.Group
156169
157170 want := "val"
158- fn := func () (interface {}, error ) {
159- time .Sleep (time .Second )
171+ fn := func (_ context.Context ) (interface {}, error ) {
172+ select {
173+ case <- time .After (time .Second ):
174+ case <- done :
175+ }
160176 return want , nil
161177 }
162178 go func () {
@@ -186,16 +202,22 @@ func TestDo_cancelContextSecond(t *testing.T) {
186202}
187203
188204func TestForget (t * testing.T ) {
205+ done := make (chan struct {})
206+ defer close (done )
207+
189208 var g singleflight.Group
190209
191210 wantPrefix := "val"
192211 var counter uint64
193212 firstCall := make (chan struct {})
194- fn := func () (interface {}, error ) {
213+ fn := func (_ context. Context ) (interface {}, error ) {
195214 c := atomic .AddUint64 (& counter , 1 )
196215 if c == 1 {
197216 close (firstCall )
198- time .Sleep (time .Second )
217+ select {
218+ case <- time .After (time .Second ):
219+ case <- done :
220+ }
199221 }
200222 return wantPrefix + strconv .FormatUint (c , 10 ), nil
201223 }
@@ -220,3 +242,118 @@ func TestForget(t *testing.T) {
220242 t .Errorf ("got value %v, want %v" , got , want )
221243 }
222244}
245+
246+ func TestDo_multipleCallsCanceled (t * testing.T ) {
247+ const n = 5
248+
249+ for lastCall := 0 ; lastCall < n ; lastCall ++ {
250+ lastCall := lastCall
251+ t .Run (fmt .Sprintf ("last call %v of %v" , lastCall , n ), func (t * testing.T ) {
252+ done := make (chan struct {})
253+ defer close (done )
254+
255+ var g singleflight.Group
256+
257+ var counter int32
258+
259+ fnCalled := make (chan struct {})
260+ fnErrChan := make (chan error )
261+ var mu sync.Mutex
262+ contexts := make ([]context.Context , n )
263+ cancelFuncs := make ([]context.CancelFunc , n )
264+ var wg sync.WaitGroup
265+ wg .Add (n )
266+ for i := 0 ; i < n ; i ++ {
267+ go func (i int ) {
268+ defer wg .Done ()
269+ ctx , cancel := context .WithCancel (context .Background ())
270+ mu .Lock ()
271+ contexts [i ] = ctx
272+ cancelFuncs [i ] = cancel
273+ mu .Unlock ()
274+ _ , _ , _ = g .Do (ctx , "key" , func (ctx context.Context ) (interface {}, error ) {
275+ atomic .AddInt32 (& counter , 1 )
276+ close (fnCalled )
277+ var err error
278+ select {
279+ case <- ctx .Done ():
280+ err = ctx .Err ()
281+ if err == nil {
282+ err = errors .New ("got unexpected <nil> error from context" )
283+ }
284+ case <- time .After (10 * time .Second ):
285+ err = errors .New ("unexpected timeout, context not canceled" )
286+ case <- done :
287+ }
288+
289+ fnErrChan <- err
290+
291+ return nil , nil
292+ })
293+ }(i )
294+ }
295+ select {
296+ case <- fnCalled :
297+ case <- time .After (10 * time .Second ):
298+ t .Fatal ("timeout waiting for function to be called" )
299+ }
300+
301+ // Ensure that n goroutines are waiting at the select case in Group.wait.
302+ // Update the line number on changes.
303+ waitStacks (t , "resenje.org/singleflight/singleflight.go:68" , n , 2 * time .Second )
304+
305+ // cancel all but one calls
306+ for i := 0 ; i < n ; i ++ {
307+ if i == lastCall {
308+ continue
309+ }
310+ mu .Lock ()
311+ cancelFuncs [i ]()
312+ <- contexts [i ].Done ()
313+ mu .Unlock ()
314+ }
315+
316+ select {
317+ case err := <- fnErrChan :
318+ t .Fatalf ("got unexpected error in function: %v" , err )
319+ default :
320+ }
321+
322+ // Ensure that only the last goroutine is waiting at the select case in Group.wait.
323+ // Update the line number on changes.
324+ waitStacks (t , "resenje.org/singleflight/singleflight.go:68" , 1 , 2 * time .Second )
325+
326+ mu .Lock ()
327+ cancelFuncs [lastCall ]()
328+ mu .Unlock ()
329+
330+ wg .Wait ()
331+
332+ select {
333+ case err := <- fnErrChan :
334+ if err != context .Canceled {
335+ t .Fatalf ("got unexpected error in function %v, want %v" , err , context .Canceled )
336+ }
337+ case <- time .After (10 * time .Second ):
338+ t .Fatal ("timeout waiting for the error" )
339+ }
340+ })
341+ }
342+ }
343+
344+ func waitStacks (t * testing.T , loc string , count int , timeout time.Duration ) {
345+ t .Helper ()
346+
347+ for deadline := time .Now ().Add (timeout ); time .Now ().Before (deadline ); {
348+ // Ensure that exact n goroutines are waiting at the desired stack trace.
349+ var buf bytes.Buffer
350+ if err := pprof .Lookup ("goroutine" ).WriteTo (& buf , 2 ); err != nil {
351+ t .Fatal (err )
352+ }
353+ c := strings .Count (buf .String (), loc )
354+ if c == count {
355+ break
356+ }
357+ time .Sleep (10 * time .Millisecond )
358+ }
359+ }
0 commit comments