@@ -33,8 +33,7 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
3333
3434 t .Run ("noop" , func (t * testing.T ) {
3535 ctx := context .Background ()
36- cb := func (ctx context.Context , item int ) {}
37- subject := NewWorkerPool (cb )
36+ subject := NewWorkerPool (noop )
3837 subject .Start (ctx , 10 )
3938 subject .Stop (ctx )
4039 })
@@ -67,19 +66,24 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
6766 })
6867
6968 t .Run ("noop multiple stops" , func (t * testing.T ) {
69+ const testSize = 10
7070 ctx := context .Background ()
71- cb := func (ctx context.Context , item int ) {}
72- subject := NewWorkerPool (cb )
71+ subject := NewWorkerPool (noop )
7372 subject .Start (ctx , 10 )
74- subject .Stop (ctx )
75- subject .Stop (ctx )
76- subject .Stop (ctx )
73+ var wg sync.WaitGroup
74+ wg .Add (testSize )
75+ for range testSize {
76+ go func () {
77+ defer wg .Done ()
78+ subject .Stop (ctx )
79+ }()
80+ }
81+ wg .Wait ()
7782 })
7883
7984 t .Run ("buffered channel" , func (t * testing.T ) {
8085 ctx := context .Background ()
81- cb := func (ctx context.Context , item int ) {}
82- subject := NewWorkerPool (cb , WithChannelBufferSize (10 ))
86+ subject := NewWorkerPool (noop , WithChannelBufferSize (10 ))
8387 err := subject .Submit (ctx , 1 )
8488 require .NoError (t , err )
8589 err = subject .Submit (ctx , 2 )
@@ -91,8 +95,7 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
9195 t .Run ("submit fails because of ctx cancellation" , func (t * testing.T ) {
9296 ctx , cancel := context .WithCancel (context .Background ())
9397
94- cb := func (ctx context.Context , item int ) {}
95- subject := NewWorkerPool (cb )
98+ subject := NewWorkerPool (noop )
9699
97100 // don't start workers to block the submit.
98101
@@ -107,8 +110,7 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
107110
108111 t .Run ("submit fails because pool closes" , func (t * testing.T ) {
109112 ctx := context .Background ()
110- cb := func (ctx context.Context , item int ) {}
111- subject := NewWorkerPool (cb )
113+ subject := NewWorkerPool (noop )
112114
113115 // don't start workers to block the submit.
114116 subject .Start (ctx , 0 ) // noop start
@@ -124,8 +126,7 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
124126
125127 t .Run ("send to closed pool" , func (t * testing.T ) {
126128 ctx := context .Background ()
127- cb := func (_ context.Context , _ int ) {}
128- subject := NewWorkerPool (cb )
129+ subject := NewWorkerPool (noop )
129130 subject .Start (ctx , 10 )
130131 subject .Stop (ctx )
131132 err := subject .Submit (ctx , 1 )
@@ -196,11 +197,56 @@ func TestWorkerPoolLifeCycle(t *testing.T) {
196197
197198 t .Run ("noop with logs" , func (t * testing.T ) {
198199 ctx := context .Background ()
199- cb := func (ctx context.Context , item int ) {}
200- subject := NewWorkerPool (cb , WithLogger (slog .New (slog .NewTextHandler (os .Stdout , & slog.HandlerOptions {Level : slog .LevelDebug }))))
200+ subject := NewWorkerPool (noop , WithLogger (slog .New (slog .NewTextHandler (os .Stdout , & slog.HandlerOptions {Level : slog .LevelDebug }))))
201201 subject .Start (ctx , 10 )
202202 subject .Stop (ctx )
203203 })
204+
205+ t .Run ("concurrent submits and close" , func (t * testing.T ) {
206+ for _ , testSize := range []int {1 , 10 , 100 , 1000 } {
207+ t .Run (fmt .Sprintf ("testSize=%d" , testSize ), func (t * testing.T ) {
208+ for _ , bufferSize := range []int {0 , 1 , 10 , 100 } {
209+ t .Run (fmt .Sprintf ("buffer=%d" , bufferSize ), func (t * testing.T ) {
210+ testConcurrentSubmitsAndClose (t , testSize , bufferSize )
211+ })
212+ }
213+ })
214+ }
215+ })
216+ }
217+
218+ func testConcurrentSubmitsAndClose (t * testing.T , testSize , bufferSize int ) {
219+ t .Helper ()
220+ ctx := context .Background ()
221+ seenSum := atomic.Int64 {}
222+ subject := NewWorkerPool (
223+ func (ctx context.Context , item int ) {
224+ seenSum .Add (int64 (item ))
225+ },
226+ WithChannelBufferSize (bufferSize ),
227+ )
228+ subject .Start (ctx , 10 )
229+ sentSum := atomic.Int64 {}
230+ var wg sync.WaitGroup
231+ wg .Add (testSize + 1 )
232+ for i := range testSize {
233+ if i == testSize / 10 {
234+ go func () {
235+ defer wg .Done ()
236+ subject .Stop (ctx )
237+ }()
238+ }
239+
240+ go func () {
241+ defer wg .Done ()
242+ err := subject .Submit (ctx , i )
243+ if err == nil {
244+ sentSum .Add (int64 (i ))
245+ }
246+ }()
247+ }
248+ wg .Wait ()
249+ assert .Equal (t , sentSum .Load (), seenSum .Load ())
204250}
205251
206252func TestMultipleSenders (t * testing.T ) {
@@ -235,11 +281,16 @@ func TestMultipleSenders(t *testing.T) {
235281 assert .Equal (t , int64 (senders * perSender ), count .Load ())
236282}
237283
284+ func BenchmarkNew (b * testing.B ) {
285+ for range b .N {
286+ NewWorkerPool (noop )
287+ }
288+ }
289+
238290func BenchmarkSubmit (b * testing.B ) {
239291 ctx := context .Background ()
240- cb := func (_ context.Context , _ int ) {}
241292
242- subject := NewWorkerPool (cb , WithChannelBufferSize (b .N + 1 ))
293+ subject := NewWorkerPool (noop , WithChannelBufferSize (b .N + 1 ))
243294
244295 b .ResetTimer ()
245296 for i := range b .N {
@@ -250,11 +301,17 @@ func BenchmarkSubmit(b *testing.B) {
250301 subject .Stop (ctx )
251302}
252303
304+ func BenchmarkStop (b * testing.B ) {
305+ ctx := context .Background ()
306+ for range b .N {
307+ NewWorkerPool (noop ).Stop (ctx )
308+ }
309+ }
310+
253311func BenchmarkWork (b * testing.B ) {
254312 ctx := context .Background ()
255- cb := func (_ context.Context , i int ) {}
256313
257- subject := NewWorkerPool (cb , WithChannelBufferSize (b .N + 1 ))
314+ subject := NewWorkerPool (noop , WithChannelBufferSize (b .N + 1 ))
258315
259316 for i := range b .N {
260317 _ = subject .Submit (ctx , i )
@@ -303,9 +360,8 @@ func BenchmarkFullFlow(b *testing.B) {
303360 for idx , tc := range tests {
304361 b .Run (fmt .Sprintf ("%d_w%d_s%d_b%d" , idx , tc .workers , tc .senders , tc .channelBufferSize ), func (b * testing.B ) {
305362 ctx := context .Background ()
306- cb := func (_ context.Context , i int ) {}
307363
308- subject := NewWorkerPool (cb , WithChannelBufferSize (tc .channelBufferSize ))
364+ subject := NewWorkerPool (noop , WithChannelBufferSize (tc .channelBufferSize ))
309365
310366 start := make (chan struct {})
311367
@@ -325,6 +381,8 @@ func BenchmarkFullFlow(b *testing.B) {
325381 }
326382}
327383
384+ func noop (context.Context , int ) {}
385+
328386func mockSender (b * testing.B , ctx context.Context , wg * sync.WaitGroup , start chan struct {}, subject * WorkerPool [int ]) {
329387 b .Helper ()
330388 defer wg .Done ()
0 commit comments