@@ -3,51 +3,123 @@ package fn
3
3
import (
4
4
"context"
5
5
"sync"
6
+ "sync/atomic"
6
7
)
7
8
8
9
// GoroutineManager is used to launch goroutines until context expires or the
9
10
// manager is stopped. The Stop method blocks until all started goroutines stop.
10
11
type GoroutineManager struct {
11
- wg sync.WaitGroup
12
- mu sync.Mutex
13
- ctx context.Context
14
- cancel func ()
12
+ // id is used to generate unique ids for each goroutine.
13
+ id atomic.Uint32
14
+
15
+ // cancelFns is a map of cancel functions that can be used to cancel the
16
+ // context of a goroutine. The mutex must be held when accessing this
17
+ // map. The key is the id of the goroutine.
18
+ cancelFns map [uint32 ]context.CancelFunc
19
+
20
+ mu sync.Mutex
21
+
22
+ stopped sync.Once
23
+ quit chan struct {}
24
+ wg sync.WaitGroup
15
25
}
16
26
17
27
// NewGoroutineManager constructs and returns a new instance of
18
28
// GoroutineManager.
19
- func NewGoroutineManager (ctx context.Context ) * GoroutineManager {
20
- ctx , cancel := context .WithCancel (ctx )
21
-
29
+ func NewGoroutineManager () * GoroutineManager {
22
30
return & GoroutineManager {
23
- ctx : ctx ,
24
- cancel : cancel ,
31
+ cancelFns : make (map [uint32 ]context.CancelFunc ),
32
+ quit : make (chan struct {}),
33
+ }
34
+ }
35
+
36
+ // addCancelFn adds a context cancel function to the manager and returns an id
37
+ // that can can be used to cancel the context later on when the goroutine is
38
+ // done.
39
+ func (g * GoroutineManager ) addCancelFn (cancel context.CancelFunc ) uint32 {
40
+ g .mu .Lock ()
41
+ defer g .mu .Unlock ()
42
+
43
+ id := g .id .Add (1 )
44
+ g .cancelFns [id ] = cancel
45
+
46
+ return id
47
+ }
48
+
49
+ // cancel cancels the context associated with the passed id.
50
+ func (g * GoroutineManager ) cancel (id uint32 ) {
51
+ g .mu .Lock ()
52
+ defer g .mu .Unlock ()
53
+
54
+ g .cancelUnsafe (id )
55
+ }
56
+
57
+ // cancelUnsafe cancels the context associated with the passed id without
58
+ // acquiring the mutex.
59
+ func (g * GoroutineManager ) cancelUnsafe (id uint32 ) {
60
+ fn , ok := g .cancelFns [id ]
61
+ if ! ok {
62
+ return
25
63
}
64
+
65
+ fn ()
66
+
67
+ delete (g .cancelFns , id )
26
68
}
27
69
28
70
// Go tries to start a new goroutine and returns a boolean indicating its
29
- // success. It fails iff the goroutine manager is stopping or its context passed
30
- // to NewGoroutineManager has expired.
31
- func (g * GoroutineManager ) Go (f func (ctx context.Context )) bool {
32
- // Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race
33
- // condition, since it is not clear should Wait() block or not. This
71
+ // success. It returns true if the goroutine was successfully created and false
72
+ // otherwise. A goroutine will fail to be created iff the goroutine manager is
73
+ // stopping or the passed context has already expired. The passed call-back
74
+ // function must exit if the passed context expires.
75
+ func (g * GoroutineManager ) Go (ctx context.Context ,
76
+ f func (ctx context.Context )) bool {
77
+
78
+ // Derive a cancellable context from the passed context and store its
79
+ // cancel function in the manager. The context will be cancelled when
80
+ // either the parent context is cancelled or the quit channel is closed
81
+ // which will call the stored cancel function.
82
+ ctx , cancel := context .WithCancel (ctx )
83
+ id := g .addCancelFn (cancel )
84
+
85
+ // Calling wg.Add(1) and wg.Wait() when the wg's counter is 0 is a race
86
+ // condition, since it is not clear if Wait() should block or not. This
34
87
// kind of race condition is detected by Go runtime and results in a
35
- // crash if running with `-race`. To prevent this, whole Go method is
36
- // protected with a mutex. The call to wg.Wait() inside Stop() can still
37
- // run in parallel with Go, but in that case g.ctx is in expired state,
38
- // because cancel() was called in Stop, so Go returns before wg.Add(1)
39
- // call.
88
+ // crash if running with `-race`. To prevent this, we protect the calls
89
+ // to wg.Add(1) and wg.Wait() with a mutex. If we block here because
90
+ // Stop is running first, then Stop will close the quit channel which
91
+ // will cause the context to be cancelled, and we will exit before
92
+ // calling wg.Add(1). If we grab the mutex here before Stop does, then
93
+ // Stop will block until after we call wg.Add(1).
40
94
g .mu .Lock ()
41
95
defer g .mu .Unlock ()
42
96
43
- if g .ctx .Err () != nil {
97
+ // Before continuing to start the goroutine, we need to check if the
98
+ // context has already expired. This could be the case if the parent
99
+ // context has already expired or if Stop has been called.
100
+ if ctx .Err () != nil {
101
+ g .cancelUnsafe (id )
102
+
103
+ return false
104
+ }
105
+
106
+ // Ensure that the goroutine is not started if the manager has stopped.
107
+ select {
108
+ case <- g .quit :
109
+ g .cancelUnsafe (id )
110
+
44
111
return false
112
+ default :
45
113
}
46
114
47
115
g .wg .Add (1 )
48
116
go func () {
49
- defer g .wg .Done ()
50
- f (g .ctx )
117
+ defer func () {
118
+ g .cancel (id )
119
+ g .wg .Done ()
120
+ }()
121
+
122
+ f (ctx )
51
123
}()
52
124
53
125
return true
@@ -56,20 +128,30 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool {
56
128
// Stop prevents new goroutines from being added and waits for all running
57
129
// goroutines to finish.
58
130
func (g * GoroutineManager ) Stop () {
59
- g .mu .Lock ()
60
- g .cancel ()
61
- g .mu .Unlock ()
62
-
63
- // Wait for all goroutines to finish. Note that this wg.Wait() call is
64
- // safe, since it can't run in parallel with wg.Add(1) call in Go, since
65
- // we just cancelled the context and even if Go call starts running here
66
- // after acquiring the mutex, it would see that the context has expired
67
- // and return false instead of calling wg.Add(1).
68
- g .wg .Wait ()
131
+ g .stopped .Do (func () {
132
+ // Closing the quit channel will prevent any new goroutines from
133
+ // starting.
134
+ g .mu .Lock ()
135
+ close (g .quit )
136
+ for _ , cancel := range g .cancelFns {
137
+ cancel ()
138
+ }
139
+ g .mu .Unlock ()
140
+
141
+ // Wait for all goroutines to finish. Note that this wg.Wait()
142
+ // call is safe, since it can't run in parallel with wg.Add(1)
143
+ // call in Go, since we just cancelled the context and even if
144
+ // Go call starts running here after acquiring the mutex, it
145
+ // would see that the context has expired and return false
146
+ // instead of calling wg.Add(1).
147
+ g .wg .Wait ()
148
+ })
69
149
}
70
150
71
- // Done returns a channel which is closed when either the context passed to
72
- // NewGoroutineManager expires or when Stop is called.
151
+ // Done returns a channel which is closed once Stop has been called and the
152
+ // quit channel closed. Note that the channel closing indicates that shutdown
153
+ // of the GoroutineManager has started but not necessarily that the Stop method
154
+ // has finished.
73
155
func (g * GoroutineManager ) Done () <- chan struct {} {
74
- return g .ctx . Done ()
156
+ return g .quit
75
157
}
0 commit comments