Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/transport/client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"golang.org/x/net/http2"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)

Expand All @@ -50,6 +51,7 @@ type ClientStream struct {
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
statsHandler stats.Handler
}

// Read reads an n byte message from the input stream.
Expand Down
25 changes: 13 additions & 12 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,18 +478,19 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}

func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream {
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) *ClientStream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &ClientStream{
Stream: Stream{
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
contentSubtype: callHdr.ContentSubtype,
},
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
statsHandler: handler,
}
s.Stream.buf.init()
s.Stream.wq.init(defaultWriteQuota, s.done)
Expand Down Expand Up @@ -744,7 +745,7 @@ func (e NewStreamError) Error() string {

// NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) {
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.Peer())

// ServerName field of the resolver returned address takes precedence over
Expand Down Expand Up @@ -781,7 +782,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
}
s := t.newStream(ctx, callHdr)
s := t.newStream(ctx, callHdr, handler)
cleanup := func(err error) {
if s.swapState(streamDone) == streamDone {
// If it was already done, return.
Expand Down Expand Up @@ -906,7 +907,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true}
}
}
if t.statsHandler != nil {
if s.statsHandler != nil {
header, ok := metadata.FromOutgoingContext(ctx)
if ok {
header.Set("user-agent", t.userAgent)
Expand All @@ -915,7 +916,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
}
// Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here.
t.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{
s.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{
Client: true,
FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr,
Expand Down Expand Up @@ -1591,16 +1592,16 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}
}

if t.statsHandler != nil {
if s.statsHandler != nil {
if !endStream {
t.statsHandler.HandleRPC(s.ctx, &stats.InHeader{
s.statsHandler.HandleRPC(s.ctx, &stats.InHeader{
Client: true,
WireLength: int(frame.Header().Length),
Header: metadata.MD(mdata).Copy(),
Compression: s.recvCompress,
})
} else {
t.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{
s.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{
Client: true,
WireLength: int(frame.Header().Length),
Trailer: metadata.MD(mdata).Copy(),
Expand Down
18 changes: 9 additions & 9 deletions internal/transport/keepalive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
stream, err := client.NewStream(ctx, &CallHdr{})
stream, err := client.NewStream(ctx, &CallHdr{}, nil)
if err != nil {
t.Fatalf("client.NewStream() failed: %v", err)
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func (s) TestMaxConnectionIdleBusyClient(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err := client.NewStream(ctx, &CallHdr{})
_, err := client.NewStream(ctx, &CallHdr{}, nil)
if err != nil {
t.Fatalf("client.NewStream() failed: %v", err)
}
Expand Down Expand Up @@ -150,7 +150,7 @@ func (s) TestMaxConnectionAge(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := client.NewStream(ctx, &CallHdr{}); err != nil {
if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil {
t.Fatalf("client.NewStream() failed: %v", err)
}

Expand Down Expand Up @@ -373,7 +373,7 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Create a stream, but send no data on it.
if _, err := client.NewStream(ctx, &CallHdr{}); err != nil {
if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil {
t.Fatalf("Stream creation failed: %v", err)
}

Expand Down Expand Up @@ -519,7 +519,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := client.NewStream(ctx, &CallHdr{}); err != nil {
if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil {
t.Fatalf("Stream creation failed: %v", err)
}

Expand Down Expand Up @@ -748,7 +748,7 @@ func (s) TestTCPUserTimeout(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
stream, err := client.NewStream(ctx, &CallHdr{})
stream, err := client.NewStream(ctx, &CallHdr{}, nil)
if err != nil {
t.Fatalf("client.NewStream() failed: %v", err)
}
Expand Down Expand Up @@ -815,7 +815,7 @@ func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials
func checkForHealthyStream(client *http2Client) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
stream, err := client.NewStream(ctx, &CallHdr{})
stream, err := client.NewStream(ctx, &CallHdr{}, nil)
stream.Close(err)
return err
}
Expand All @@ -824,7 +824,7 @@ func pollForStreamCreationError(client *http2Client) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for {
if _, err := client.NewStream(ctx, &CallHdr{}); err != nil {
if _, err := client.NewStream(ctx, &CallHdr{}, nil); err != nil {
break
}
time.Sleep(50 * time.Millisecond)
Expand All @@ -850,7 +850,7 @@ func waitForGoAwayTooManyPings(client *http2Client) error {
return fmt.Errorf("test timed out before getting GoAway with reason:GoAwayTooManyPings from server")
}

if _, err := client.NewStream(ctx, &CallHdr{}); err == nil {
if _, err := client.NewStream(ctx, &CallHdr{}, nil); err == nil {
return fmt.Errorf("stream creation succeeded after receiving a GoAway from the server")
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ type ClientTransport interface {
GracefulClose()

// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error)
NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error)

// Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor
Expand Down
Loading
Loading