Skip to content

Commit f856ab8

Browse files
committed
Pass a context to the calling function for cancellation
1 parent d1b997f commit f856ab8

2 files changed

Lines changed: 174 additions & 20 deletions

File tree

singleflight.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,43 @@ type Group struct {
2020
mu sync.Mutex // protects calls
2121
}
2222

23-
// Do executes and returns the results of the given function, making
24-
// sure that only one execution is in-flight for a given key at a
25-
// time. If a duplicate comes in, the duplicate caller waits for the
26-
// original to complete and receives the same results.
27-
// Passed context terminates the execution of Do function, not the passed
28-
// function fn. If there are multiple callers, context passed to one caller
29-
// does not effect the execution and returned values of others.
23+
// Do executes and returns the results of the given function, making sure that
24+
// only one execution is in-flight for a given key at a time. If a duplicate
25+
// comes in, the duplicate caller waits for the original to complete and
26+
// receives the same results.
27+
//
28+
// The context passed to the fn function is a new context which is canceled when
29+
// contexts from all callers are canceled, so that no caller is expecting the
30+
// result. If there are multiple callers, context passed to one caller does not
31+
// effect the execution and returned values of others.
32+
//
3033
// The return value shared indicates whether v was given to multiple callers.
31-
func (g *Group) Do(ctx context.Context, key string, fn func() (interface{}, error)) (v interface{}, shared bool, err error) {
34+
func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, shared bool, err error) {
3235
g.mu.Lock()
3336
if g.calls == nil {
3437
g.calls = make(map[string]*call)
3538
}
3639

3740
if c, ok := g.calls[key]; ok {
3841
c.shared = true
42+
c.counter++
3943
g.mu.Unlock()
4044

4145
return g.wait(ctx, key, c)
4246
}
4347

48+
callCtx, cancel := context.WithCancel(context.Background())
49+
4450
c := &call{
45-
done: make(chan struct{}),
51+
done: make(chan struct{}),
52+
cancel: cancel,
53+
counter: 1,
4654
}
4755
g.calls[key] = c
4856
g.mu.Unlock()
4957

5058
go func() {
51-
c.val, c.err = fn()
59+
c.val, c.err = fn(callCtx)
5260
close(c.done)
5361
}()
5462

@@ -65,6 +73,10 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s
6573
err = ctx.Err()
6674
}
6775
g.mu.Lock()
76+
c.counter--
77+
if c.counter == 0 {
78+
c.cancel()
79+
}
6880
if !c.forgotten {
6981
delete(g.calls, key)
7082
}
@@ -99,4 +111,9 @@ type call struct {
99111

100112
// shared indicates if results val and err are passed to multiple callers.
101113
shared bool
114+
115+
// Number of callers that are waiting for the result.
116+
counter int
117+
// Cancel function for the context passed to the executing function.
118+
cancel context.CancelFunc
102119
}

singleflight_test.go

Lines changed: 147 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
package singleflight_test
77

88
import (
9+
"bytes"
910
"context"
1011
"errors"
12+
"fmt"
13+
"runtime/pprof"
1114
"strconv"
15+
"strings"
1216
"sync"
1317
"sync/atomic"
1418
"testing"
@@ -21,7 +25,7 @@ func TestDo(t *testing.T) {
2125
var g singleflight.Group
2226

2327
want := "val"
24-
got, shared, err := g.Do(context.Background(), "key", func() (interface{}, error) {
28+
got, shared, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
2529
return want, nil
2630
})
2731
if err != nil {
@@ -38,7 +42,7 @@ func TestDo(t *testing.T) {
3842
func TestDo_error(t *testing.T) {
3943
var g singleflight.Group
4044
wantErr := errors.New("test error")
41-
got, _, err := g.Do(context.Background(), "key", func() (interface{}, error) {
45+
got, _, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
4246
return nil, wantErr
4347
})
4448
if err != wantErr {
@@ -64,7 +68,7 @@ func TestDo_multipleCalls(t *testing.T) {
6468
for i := 0; i < n; i++ {
6569
go func(i int) {
6670
defer wg.Done()
67-
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func() (interface{}, error) {
71+
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
6872
atomic.AddInt32(&counter, 1)
6973
time.Sleep(100 * time.Millisecond)
7074
return want, nil
@@ -95,7 +99,7 @@ func TestDo_callRemoval(t *testing.T) {
9599

96100
wantPrefix := "val"
97101
counter := 0
98-
fn := func() (interface{}, error) {
102+
fn := func(_ context.Context) (interface{}, error) {
99103
counter++
100104
return wantPrefix + strconv.Itoa(counter), nil
101105
}
@@ -124,6 +128,9 @@ func TestDo_callRemoval(t *testing.T) {
124128
}
125129

126130
func TestDo_cancelContext(t *testing.T) {
131+
done := make(chan struct{})
132+
defer close(done)
133+
127134
var g singleflight.Group
128135

129136
want := "val"
@@ -133,8 +140,11 @@ func TestDo_cancelContext(t *testing.T) {
133140
cancel()
134141
}()
135142
start := time.Now()
136-
got, shared, err := g.Do(ctx, "key", func() (interface{}, error) {
137-
time.Sleep(time.Second)
143+
got, shared, err := g.Do(ctx, "key", func(_ context.Context) (interface{}, error) {
144+
select {
145+
case <-time.After(time.Second):
146+
case <-done:
147+
}
138148
return want, nil
139149
})
140150
if d := time.Since(start); d < 100*time.Microsecond || d > time.Second {
@@ -152,11 +162,17 @@ func TestDo_cancelContext(t *testing.T) {
152162
}
153163

154164
func TestDo_cancelContextSecond(t *testing.T) {
165+
done := make(chan struct{})
166+
defer close(done)
167+
155168
var g singleflight.Group
156169

157170
want := "val"
158-
fn := func() (interface{}, error) {
159-
time.Sleep(time.Second)
171+
fn := func(_ context.Context) (interface{}, error) {
172+
select {
173+
case <-time.After(time.Second):
174+
case <-done:
175+
}
160176
return want, nil
161177
}
162178
go func() {
@@ -186,16 +202,22 @@ func TestDo_cancelContextSecond(t *testing.T) {
186202
}
187203

188204
func TestForget(t *testing.T) {
205+
done := make(chan struct{})
206+
defer close(done)
207+
189208
var g singleflight.Group
190209

191210
wantPrefix := "val"
192211
var counter uint64
193212
firstCall := make(chan struct{})
194-
fn := func() (interface{}, error) {
213+
fn := func(_ context.Context) (interface{}, error) {
195214
c := atomic.AddUint64(&counter, 1)
196215
if c == 1 {
197216
close(firstCall)
198-
time.Sleep(time.Second)
217+
select {
218+
case <-time.After(time.Second):
219+
case <-done:
220+
}
199221
}
200222
return wantPrefix + strconv.FormatUint(c, 10), nil
201223
}
@@ -220,3 +242,118 @@ func TestForget(t *testing.T) {
220242
t.Errorf("got value %v, want %v", got, want)
221243
}
222244
}
245+
246+
func TestDo_multipleCallsCanceled(t *testing.T) {
247+
const n = 5
248+
249+
for lastCall := 0; lastCall < n; lastCall++ {
250+
lastCall := lastCall
251+
t.Run(fmt.Sprintf("last call %v of %v", lastCall, n), func(t *testing.T) {
252+
done := make(chan struct{})
253+
defer close(done)
254+
255+
var g singleflight.Group
256+
257+
var counter int32
258+
259+
fnCalled := make(chan struct{})
260+
fnErrChan := make(chan error)
261+
var mu sync.Mutex
262+
contexts := make([]context.Context, n)
263+
cancelFuncs := make([]context.CancelFunc, n)
264+
var wg sync.WaitGroup
265+
wg.Add(n)
266+
for i := 0; i < n; i++ {
267+
go func(i int) {
268+
defer wg.Done()
269+
ctx, cancel := context.WithCancel(context.Background())
270+
mu.Lock()
271+
contexts[i] = ctx
272+
cancelFuncs[i] = cancel
273+
mu.Unlock()
274+
_, _, _ = g.Do(ctx, "key", func(ctx context.Context) (interface{}, error) {
275+
atomic.AddInt32(&counter, 1)
276+
close(fnCalled)
277+
var err error
278+
select {
279+
case <-ctx.Done():
280+
err = ctx.Err()
281+
if err == nil {
282+
err = errors.New("got unexpected <nil> error from context")
283+
}
284+
case <-time.After(10 * time.Second):
285+
err = errors.New("unexpected timeout, context not canceled")
286+
case <-done:
287+
}
288+
289+
fnErrChan <- err
290+
291+
return nil, nil
292+
})
293+
}(i)
294+
}
295+
select {
296+
case <-fnCalled:
297+
case <-time.After(10 * time.Second):
298+
t.Fatal("timeout waiting for function to be called")
299+
}
300+
301+
// Ensure that n goroutines are waiting at the select case in Group.wait.
302+
// Update the line number on changes.
303+
waitStacks(t, "resenje.org/singleflight/singleflight.go:68", n, 2*time.Second)
304+
305+
// cancel all but one calls
306+
for i := 0; i < n; i++ {
307+
if i == lastCall {
308+
continue
309+
}
310+
mu.Lock()
311+
cancelFuncs[i]()
312+
<-contexts[i].Done()
313+
mu.Unlock()
314+
}
315+
316+
select {
317+
case err := <-fnErrChan:
318+
t.Fatalf("got unexpected error in function: %v", err)
319+
default:
320+
}
321+
322+
// Ensure that only the last goroutine is waiting at the select case in Group.wait.
323+
// Update the line number on changes.
324+
waitStacks(t, "resenje.org/singleflight/singleflight.go:68", 1, 2*time.Second)
325+
326+
mu.Lock()
327+
cancelFuncs[lastCall]()
328+
mu.Unlock()
329+
330+
wg.Wait()
331+
332+
select {
333+
case err := <-fnErrChan:
334+
if err != context.Canceled {
335+
t.Fatalf("got unexpected error in function %v, want %v", err, context.Canceled)
336+
}
337+
case <-time.After(10 * time.Second):
338+
t.Fatal("timeout waiting for the error")
339+
}
340+
})
341+
}
342+
}
343+
344+
func waitStacks(t *testing.T, loc string, count int, timeout time.Duration) {
345+
t.Helper()
346+
347+
for deadline := time.Now().Add(timeout); time.Now().Before(deadline); {
348+
// Ensure that exact n goroutines are waiting at the desired stack trace.
349+
var buf bytes.Buffer
350+
if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil {
351+
t.Fatal(err)
352+
}
353+
c := strings.Count(buf.String(), loc)
354+
if c == count {
355+
break
356+
}
357+
time.Sleep(10 * time.Millisecond)
358+
}
359+
}

0 commit comments

Comments
 (0)