Skip to content

Commit c76188c

Browse files
committed
Fix forwardingStreamHandler errgroup race
Commit 5b5db75, "Add generic gRPC stream forwarding (#306)", introduced a flakiness in the grpc_test due to calling `errgroup.Go(func() error { return err })` when an error had occurred in a go routine that was asynchrounous to the errgroup.Wait() call. This could lead to `errgroup.Go()` being called just when `errgroup.Wait()` was about to return, which is not allowed.
1 parent 6869e65 commit c76188c

File tree

2 files changed

+48
-45
lines changed

2 files changed

+48
-45
lines changed

pkg/grpc/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ go_library(
7575
"@org_golang_google_protobuf//proto",
7676
"@org_golang_google_protobuf//reflect/protoreflect",
7777
"@org_golang_google_protobuf//types/known/emptypb",
78-
"@org_golang_x_sync//errgroup",
7978
"@org_golang_x_sync//semaphore",
8079
] + select({
8180
"@rules_go//go/platform:android": [
Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package grpc
22

33
import (
4+
"context"
45
"io"
56

6-
"golang.org/x/sync/errgroup"
77
"google.golang.org/grpc"
88
"google.golang.org/protobuf/types/known/emptypb"
99
)
@@ -38,55 +38,59 @@ func (s *forwardingStreamHandler) HandleStream(srv any, incomingStream grpc.Serv
3838
ServerStreams: true,
3939
ClientStreams: true,
4040
}
41-
group, groupCtx := errgroup.WithContext(incomingStream.Context())
42-
group.Go(func() error {
43-
// groupCtx is guaranteed to be canceled before returning from this method, so outgoingStream will not leak resources.
44-
outgoingStream, err := s.backend.NewStream(groupCtx, &desc, method)
45-
if err != nil {
46-
return err
47-
}
48-
// Avoid group.Go because incomingStream.RecvMsg might block returning
49-
// an error from the outgoingStream and getting the context for
50-
// incomingStream canceled.
51-
go func() {
52-
for {
53-
msg := &emptypb.Empty{}
54-
if err := incomingStream.RecvMsg(msg); err != nil {
55-
if err == io.EOF {
56-
// Let's continue to receive on outgoingStream, so don't
57-
// cancel grouptCtx.
58-
outgoingStream.CloseSend()
59-
return
60-
}
61-
// Cancel groupCtx immediately.
62-
group.Go(func() error { return err })
63-
return
64-
}
65-
if err := outgoingStream.SendMsg(msg); err != nil {
66-
if err == io.EOF {
67-
// The error will be returned by outgoingStream.RecvMsg(),
68-
// no need to cancel groupCtx now.
69-
return
70-
}
71-
// Cancel groupCtx immediately.
72-
group.Go(func() error { return err })
73-
return
74-
}
75-
}
76-
}()
41+
ctx, cancel := context.WithCancelCause(incomingStream.Context())
42+
defer cancel(nil)
43+
44+
// ctx is guaranteed to be canceled when returning from this method, so
45+
// outgoingStream will not leak resources.
46+
outgoingStream, err := s.backend.NewStream(ctx, &desc, method)
47+
if err != nil {
48+
return err
49+
}
7750

51+
// The only way to cancel a blocking incomingStream.RecvMsg is to return
52+
// from this method. Therefore, an error from outgoingStream.RecvMsg
53+
// needs to be returned without waiting for incomingStream.RecvMsg, so
54+
// it cannot be run inside e.g. errgroup.Go.
55+
go func() {
7856
for {
7957
msg := &emptypb.Empty{}
80-
if err := outgoingStream.RecvMsg(msg); err != nil {
58+
if err := incomingStream.RecvMsg(msg); err != nil {
8159
if err == io.EOF {
82-
return nil
60+
// Let's continue to receive on outgoingStream, so don't
61+
// cancel grouptCtx.
62+
outgoingStream.CloseSend()
63+
return
8364
}
84-
return err
65+
// Cancel ctx immediately.
66+
cancel(err)
67+
return
8568
}
86-
if err := incomingStream.SendMsg(msg); err != nil {
87-
return err
69+
if err := outgoingStream.SendMsg(msg); err != nil {
70+
if err == io.EOF {
71+
// The error will be returned by outgoingStream.RecvMsg(),
72+
// no need to cancel ctx now.
73+
return
74+
}
75+
// Cancel ctx immediately.
76+
cancel(err)
77+
return
78+
}
79+
}
80+
}()
81+
82+
for {
83+
msg := &emptypb.Empty{}
84+
if err := outgoingStream.RecvMsg(msg); err != nil {
85+
if err != io.EOF {
86+
cancel(err)
8887
}
88+
break
8989
}
90-
})
91-
return group.Wait()
90+
if err := incomingStream.SendMsg(msg); err != nil {
91+
cancel(err)
92+
break
93+
}
94+
}
95+
return context.Cause(ctx)
9296
}

0 commit comments

Comments
 (0)