Skip to content

Commit 6995ef2

Browse files
authored
internal/transport: Wait for server goroutines to exit during shutdown in test (#8306)
1 parent aaabd60 commit 6995ef2

File tree

1 file changed

+91
-33
lines changed

1 file changed

+91
-33
lines changed

internal/transport/transport_test.go

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -320,21 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream)
320320
}
321321

322322
type server struct {
323-
lis net.Listener
324-
port string
325-
startedErr chan error // error (or nil) with server start value
326-
mu sync.Mutex
327-
conns map[ServerTransport]net.Conn
328-
h *testStreamHandler
329-
ready chan struct{}
330-
channelz *channelz.Server
323+
lis net.Listener
324+
port string
325+
startedErr chan error // error (or nil) with server start value
326+
mu sync.Mutex
327+
conns map[ServerTransport]net.Conn
328+
h *testStreamHandler
329+
ready chan struct{}
330+
channelz *channelz.Server
331+
servingTasksDone chan struct{}
331332
}
332333

333334
func newTestServer() *server {
334335
return &server{
335-
startedErr: make(chan error, 1),
336-
ready: make(chan struct{}),
337-
channelz: channelz.RegisterServer("test server"),
336+
startedErr: make(chan error, 1),
337+
ready: make(chan struct{}),
338+
servingTasksDone: make(chan struct{}),
339+
channelz: channelz.RegisterServer("test server"),
338340
}
339341
}
340342

@@ -358,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
358360
s.port = p
359361
s.conns = make(map[ServerTransport]net.Conn)
360362
s.startedErr <- nil
363+
wg := sync.WaitGroup{}
364+
defer func() {
365+
wg.Wait()
366+
close(s.servingTasksDone)
367+
}()
368+
361369
for {
362370
conn, err := s.lis.Accept()
363371
if err != nil {
@@ -383,40 +391,89 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
383391
s.mu.Unlock()
384392
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
385393
defer cancel()
394+
wg.Add(1)
386395
switch ht {
387396
case notifyCall:
388-
go transport.HandleStreams(ctx, h.handleStreamAndNotify)
397+
go func() {
398+
transport.HandleStreams(ctx, h.handleStreamAndNotify)
399+
wg.Done()
400+
}()
389401
case suspended:
390-
go transport.HandleStreams(ctx, func(*ServerStream) {})
402+
go func() {
403+
transport.HandleStreams(ctx, func(*ServerStream) {})
404+
wg.Done()
405+
}()
391406
case misbehaved:
392-
go transport.HandleStreams(ctx, func(s *ServerStream) {
393-
go h.handleStreamMisbehave(t, s)
394-
})
407+
go func() {
408+
transport.HandleStreams(ctx, func(s *ServerStream) {
409+
wg.Add(1)
410+
go func() {
411+
h.handleStreamMisbehave(t, s)
412+
wg.Done()
413+
}()
414+
})
415+
wg.Done()
416+
}()
395417
case encodingRequiredStatus:
396-
go transport.HandleStreams(ctx, func(s *ServerStream) {
397-
go h.handleStreamEncodingRequiredStatus(s)
398-
})
418+
go func() {
419+
transport.HandleStreams(ctx, func(s *ServerStream) {
420+
wg.Add(1)
421+
go func() {
422+
h.handleStreamEncodingRequiredStatus(s)
423+
wg.Done()
424+
}()
425+
})
426+
wg.Done()
427+
}()
399428
case invalidHeaderField:
400-
go transport.HandleStreams(ctx, func(s *ServerStream) {
401-
go h.handleStreamInvalidHeaderField(s)
402-
})
429+
go func() {
430+
transport.HandleStreams(ctx, func(s *ServerStream) {
431+
wg.Add(1)
432+
go func() {
433+
h.handleStreamInvalidHeaderField(s)
434+
wg.Done()
435+
}()
436+
})
437+
wg.Done()
438+
}()
403439
case delayRead:
404440
h.notify = make(chan struct{})
405441
h.getNotified = make(chan struct{})
406442
s.mu.Lock()
407443
close(s.ready)
408444
s.mu.Unlock()
409-
go transport.HandleStreams(ctx, func(s *ServerStream) {
410-
go h.handleStreamDelayRead(t, s)
411-
})
445+
go func() {
446+
transport.HandleStreams(ctx, func(s *ServerStream) {
447+
wg.Add(1)
448+
go func() {
449+
h.handleStreamDelayRead(t, s)
450+
wg.Done()
451+
}()
452+
})
453+
wg.Done()
454+
}()
412455
case pingpong:
413-
go transport.HandleStreams(ctx, func(s *ServerStream) {
414-
go h.handleStreamPingPong(t, s)
415-
})
456+
go func() {
457+
transport.HandleStreams(ctx, func(s *ServerStream) {
458+
wg.Add(1)
459+
go func() {
460+
h.handleStreamPingPong(t, s)
461+
wg.Done()
462+
}()
463+
})
464+
wg.Done()
465+
}()
416466
default:
417-
go transport.HandleStreams(ctx, func(s *ServerStream) {
418-
go h.handleStream(t, s)
419-
})
467+
go func() {
468+
transport.HandleStreams(ctx, func(s *ServerStream) {
469+
wg.Add(1)
470+
go func() {
471+
h.handleStream(t, s)
472+
wg.Done()
473+
}()
474+
})
475+
wg.Done()
476+
}()
420477
}
421478
}
422479
}
@@ -440,6 +497,7 @@ func (s *server) stop() {
440497
}
441498
s.conns = nil
442499
s.mu.Unlock()
500+
<-s.servingTasksDone
443501
}
444502

445503
func (s *server) addr() string {
@@ -2254,11 +2312,11 @@ func (s) TestPingPong1B(t *testing.T) {
22542312
runPingPongTest(t, 1)
22552313
}
22562314

2257-
func TestPingPong1KB(t *testing.T) {
2315+
func (s) TestPingPong1KB(t *testing.T) {
22582316
runPingPongTest(t, 1024)
22592317
}
22602318

2261-
func TestPingPong64KB(t *testing.T) {
2319+
func (s) TestPingPong64KB(t *testing.T) {
22622320
runPingPongTest(t, 65536)
22632321
}
22642322

0 commit comments

Comments
 (0)