diff --git a/internal/transport/client_stream.go b/internal/transport/client_stream.go index 980452519ea7..6fb101ca4d92 100644 --- a/internal/transport/client_stream.go +++ b/internal/transport/client_stream.go @@ -24,6 +24,7 @@ import ( "golang.org/x/net/http2" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" ) @@ -50,6 +51,7 @@ type ClientStream struct { headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream + statsHandler stats.Handler } // Read reads an n byte message from the input stream. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 44e814a80b19..c943503f3590 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -478,7 +478,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts return t, nil } -func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream { +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) *ClientStream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &ClientStream{ Stream: Stream{ @@ -486,10 +486,11 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientSt sendCompress: callHdr.SendCompress, contentSubtype: callHdr.ContentSubtype, }, - ct: t, - done: make(chan struct{}), - headerChan: make(chan struct{}), - doneFunc: callHdr.DoneFunc, + ct: t, + done: make(chan struct{}), + headerChan: make(chan struct{}), + doneFunc: callHdr.DoneFunc, + statsHandler: handler, } s.Stream.buf.init() s.Stream.wq.init(defaultWriteQuota, s.done) @@ -744,7 +745,7 @@ func (e NewStreamError) Error() string { // NewStream creates a stream and registers it into the transport as "active" // streams. All non-nil errors returned will be *NewStreamError. -func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) { +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) { ctx = peer.NewContext(ctx, t.Peer()) // ServerName field of the resolver returned address takes precedence over @@ -781,7 +782,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS if err != nil { return nil, &NewStreamError{Err: err, AllowTransparentRetry: false} } - s := t.newStream(ctx, callHdr) + s := t.newStream(ctx, callHdr, handler) cleanup := func(err error) { if s.swapState(streamDone) == streamDone { // If it was already done, return. @@ -906,7 +907,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} } } - if t.statsHandler != nil { + if s.statsHandler != nil { header, ok := metadata.FromOutgoingContext(ctx) if ok { header.Set("user-agent", t.userAgent) @@ -915,7 +916,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS } // Note: The header fields are compressed with hpack after this call returns. // No WireLength field is set here. - t.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{ + s.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{ Client: true, FullMethod: callHdr.Method, RemoteAddr: t.remoteAddr, @@ -1591,16 +1592,16 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } } - if t.statsHandler != nil { + if s.statsHandler != nil { if !endStream { - t.statsHandler.HandleRPC(s.ctx, &stats.InHeader{ + s.statsHandler.HandleRPC(s.ctx, &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), Header: metadata.MD(mdata).Copy(), Compression: s.recvCompress, }) } else { - t.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{ + s.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), Trailer: metadata.MD(mdata).Copy(), diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index 324d6a57696f..8e8f8e5b86bc 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -69,7 +69,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) + stream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("client.NewStream() failed: %v", err) } @@ -111,7 +111,7 @@ func (s) TestMaxConnectionIdleBusyClient(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - _, err := client.NewStream(ctx, &CallHdr{}) + _, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("client.NewStream() failed: %v", err) } @@ -150,7 +150,7 @@ func (s) TestMaxConnectionAge(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := client.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil { t.Fatalf("client.NewStream() failed: %v", err) } @@ -373,7 +373,7 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() // Create a stream, but send no data on it. - if _, err := client.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil { t.Fatalf("Stream creation failed: %v", err) } @@ -519,7 +519,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := client.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil { t.Fatalf("Stream creation failed: %v", err) } @@ -748,7 +748,7 @@ func (s) TestTCPUserTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) + stream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("client.NewStream() failed: %v", err) } @@ -815,7 +815,7 @@ func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials func checkForHealthyStream(client *http2Client) error { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) + stream, err := client.NewStream(ctx, &CallHdr{}, nil) stream.Close(err) return err } @@ -824,7 +824,7 @@ func pollForStreamCreationError(client *http2Client) error { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for { - if _, err := client.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil { break } time.Sleep(50 * time.Millisecond) @@ -850,7 +850,7 @@ func waitForGoAwayTooManyPings(client *http2Client) error { return fmt.Errorf("test timed out before getting GoAway with reason:GoAwayTooManyPings from server") } - if _, err := client.NewStream(ctx, &CallHdr{}); err == nil { + if _, err := client.NewStream(ctx, &CallHdr{}, nil); err == nil { return fmt.Errorf("stream creation succeeded after receiving a GoAway from the server") } return nil diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 10b9155f0932..b86094da9433 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -617,7 +617,7 @@ type ClientTransport interface { GracefulClose() // NewStream creates a Stream for an RPC. - NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) + NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) // Error returns a channel that is closed when some I/O error // happens. Typically the caller should have a goroutine to monitor diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index d1ac1286dcfb..2d32216ba6fa 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -630,7 +630,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) + stream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Client failed to create RPC request: %v", err) } @@ -678,7 +678,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("ct.NewStream() = %v", err) } @@ -691,7 +691,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { } // The expected stream ID here is 3 since stream IDs are incremented by 2. - s, err = ct.NewStream(ctx, callHdr) + s, err = ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("ct.NewStream() = %v", err) } @@ -714,14 +714,14 @@ func (s) TestClientSendAndReceive(t *testing.T) { } ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() - s1, err1 := ct.NewStream(ctx, callHdr) + s1, err1 := ct.NewStream(ctx, callHdr, nil) if err1 != nil { t.Fatalf("failed to open stream: %v", err1) } if s1.id != 1 { t.Fatalf("wrong stream id: %d", s1.id) } - s2, err2 := ct.NewStream(ctx, callHdr) + s2, err2 := ct.NewStream(ctx, callHdr, nil) if err2 != nil { t.Fatalf("failed to open stream: %v", err2) } @@ -761,7 +761,7 @@ func performOneRPC(ct ClientTransport) { } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { return } @@ -807,7 +807,7 @@ func (s) TestLargeMessage(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } @@ -855,7 +855,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) return @@ -954,7 +954,7 @@ func (s) TestGracefulClose(t *testing.T) { // Create a stream that will exist for this whole test and confirm basic // functionality. - s, err := ct.NewStream(ctx, &CallHdr{}) + s, err := ct.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("NewStream(_, _) = _, %v, want _, ", err) } @@ -986,7 +986,7 @@ func (s) TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := ct.NewStream(ctx, &CallHdr{}) + _, err := ct.NewStream(ctx, &CallHdr{}, nil) if err != nil && err.(*NewStreamError).Err == ErrConnClosing && err.(*NewStreamError).AllowTransparentRetry { return } @@ -1014,7 +1014,7 @@ func (s) TestLargeMessageSuspension(t *testing.T) { // Set a long enough timeout for writing a large message out. ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("failed to open stream: %v", err) } @@ -1054,7 +1054,7 @@ func (s) TestMaxStreams(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("Failed to open stream: %v", err) } @@ -1075,7 +1075,7 @@ func (s) TestMaxStreams(t *testing.T) { // This is only to get rid of govet. All these context are based on a base // context which is canceled at the end of the test. defer cancel() - if str, err := ct.NewStream(ctx, callHdr); err == nil { + if str, err := ct.NewStream(ctx, callHdr, nil); err == nil { slist = append(slist, str) continue } else if err.Error() != expectedErr.Error() { @@ -1090,7 +1090,7 @@ func (s) TestMaxStreams(t *testing.T) { defer close(done) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - if _, err := ct.NewStream(ctx, callHdr); err != nil { + if _, err := ct.NewStream(ctx, callHdr, nil); err != nil { t.Errorf("Failed to open stream: %v", err) } }() @@ -1141,7 +1141,7 @@ func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { t.Fatalf("Failed to open stream: %v", err) } @@ -1211,7 +1211,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { server.mu.Unlock() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cstream1, err := client.NewStream(ctx, &CallHdr{}) + cstream1, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Client failed to create first stream. Err: %v", err) } @@ -1238,7 +1238,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { server.h.notify = notifyChan server.mu.Unlock() // Create another stream on client. - cstream2, err := client.NewStream(ctx, &CallHdr{}) + cstream2, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Client failed to create second stream. Err: %v", err) } @@ -1300,7 +1300,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() server.mu.Unlock() - cstream1, err := client.NewStream(ctx, &CallHdr{}) + cstream1, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Failed to create 1st stream. Err: %v", err) } @@ -1309,7 +1309,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Client failed to write data. Err: %v", err) } // Client should be able to create another stream and send data on it. - cstream2, err := client.NewStream(ctx, &CallHdr{}) + cstream2, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Failed to create 2nd stream. Err: %v", err) } @@ -1587,7 +1587,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { } defer ct.Close(fmt.Errorf("closed manually by test")) - str, err := ct.NewStream(connectCtx, &CallHdr{}) + str, err := ct.NewStream(connectCtx, &CallHdr{}, nil) if err != nil { t.Fatalf("Error while creating stream: %v", err) } @@ -1617,7 +1617,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { return } @@ -1647,7 +1647,7 @@ func (s) TestInvalidHeaderField(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, callHdr) + s, err := ct.NewStream(ctx, callHdr, nil) if err != nil { return } @@ -1667,7 +1667,7 @@ func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { defer ct.Close(fmt.Errorf("closed manually by test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) + s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}, nil) if err != nil { t.Fatalf("failed to create the stream") } @@ -1797,7 +1797,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) clientStreams := make([]*ClientStream, numStreams) for i := 0; i < numStreams; i++ { var err error - clientStreams[i], err = client.NewStream(ctx, &CallHdr{}) + clientStreams[i], err = client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Failed to create stream. Err: %v", err) } @@ -2348,7 +2348,7 @@ func (s) TestWriteHeaderConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cstream, err := client.NewStream(ctx, &CallHdr{}) + cstream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Client failed to create first stream. Err: %v", err) } @@ -2412,7 +2412,7 @@ func runPingPongTest(t *testing.T, msgSize int) { }) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - stream, err := client.NewStream(ctx, &CallHdr{}) + stream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Failed to create stream. Err: %v", err) } @@ -2486,7 +2486,7 @@ func (s) TestHeaderTblSize(t *testing.T) { defer server.stop() ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() - _, err := ct.NewStream(ctx, &CallHdr{}) + _, err := ct.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("failed to open stream: %v", err) } @@ -3005,7 +3005,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { if err != nil { t.Fatalf("Error while creating client transport: %v", err) } - _, err = ct.NewStream(ctx, &CallHdr{}) + _, err = ct.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("failed to open stream: %v", err) } @@ -3075,7 +3075,7 @@ func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) { t.Fatalf("Failed to create transport: %v", err) } - if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := ct.NewStream(ctx, &CallHdr{}, nil); err != nil { t.Fatalf("Failed to open stream: %v", err) } @@ -3167,7 +3167,7 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { t.Fatalf("failed to create transport: %v", connErr) } - if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { + if _, err := ct.NewStream(ctx, &CallHdr{}, nil); err != nil { t.Fatalf("Failed to open stream: %v", err) } @@ -3502,7 +3502,7 @@ func (s) TestDeleteStreamMetricsIncrementedOnlyOnce(t *testing.T) { t.Fatal("Server transport not found") } - clientStream, err := client.NewStream(ctx, &CallHdr{}) + clientStream, err := client.NewStream(ctx, &CallHdr{}, nil) if err != nil { t.Fatalf("Failed to create stream: %v", err) } diff --git a/stream.go b/stream.go index f92102fb4fbf..5a28e2336cc5 100644 --- a/stream.go +++ b/stream.go @@ -548,7 +548,7 @@ func (a *csAttempt) newStream() error { } } } - s, err := a.transport.NewStream(a.ctx, cs.callHdr) + s, err := a.transport.NewStream(a.ctx, cs.callHdr, a.statsHandler) if err != nil { nse, ok := err.(*transport.NewStreamError) if !ok { @@ -1354,7 +1354,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin transport: t, } - s, err := as.transport.NewStream(as.ctx, as.callHdr) + s, err := as.transport.NewStream(as.ctx, as.callHdr, nil) if err != nil { err = toRPCErr(err) return nil, err