Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ func (kc *kClient) initMiddlewares(ctx context.Context) (mw middleware) {

// 3. stream middlewares
kc.opt.Streaming.InitMiddlewares(ctx)
kc.opt.StreamOptions.EventHandler = kc.opt.TracerCtl.GetStreamEventHandler()
mw.smws = mw.mws
if kc.opt.XDSEnabled && kc.opt.XDSRouterMiddleware != nil && kc.opt.Proxy == nil {
// integrate xds if enabled
Expand Down
26 changes: 26 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,29 @@ func Test_initRPCInfoWithStreamClientCallOption(t *testing.T) {
_, ri, _ = cli.initRPCInfo(ctx, mtd, 0, nil, true)
test.Assert(t, ri.Config().StreamRecvTimeout() == callOptTimeout)
}

func Test_WithStreamEventHandler(t *testing.T) {
svcInfo := mocks.ServiceInfo()
testService := "testService"

cliIntf, err := NewClient(svcInfo, WithDestService(testService))
test.Assert(t, err == nil, err)
cli := cliIntf.(*kcFinalizerClient)
test.Assert(t, len(cli.opt.StreamOptions.StreamEventHandlers) == 0, cli.opt)

cliIntf, err = NewClient(svcInfo,
WithDestService(testService),
WithStreamOptions(
WithStreamEventHandler(rpcinfo.StreamEventHandler{}),
WithStreamEventHandler(rpcinfo.StreamEventHandler{
HandleStreamStartEvent: func(ctx context.Context, ri rpcinfo.RPCInfo, evt rpcinfo.StreamStartEvent) {},
HandleStreamRecvEvent: func(ctx context.Context, ri rpcinfo.RPCInfo, evt rpcinfo.StreamRecvEvent) {},
HandleStreamSendEvent: func(ctx context.Context, ri rpcinfo.RPCInfo, evt rpcinfo.StreamSendEvent) {},
HandleStreamRecvHeaderEvent: func(ctx context.Context, ri rpcinfo.RPCInfo, evt rpcinfo.StreamRecvHeaderEvent) {},
HandleStreamFinishEvent: func(ctx context.Context, ri rpcinfo.RPCInfo, evt rpcinfo.StreamFinishEvent) {},
}),
))
test.Assert(t, err == nil, err)
cli = cliIntf.(*kcFinalizerClient)
test.Assert(t, len(cli.opt.StreamOptions.StreamEventHandlers) == 2, cli.opt)
}
10 changes: 10 additions & 0 deletions client/option_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/cloudwego/kitex/internal/client"
"github.com/cloudwego/kitex/pkg/endpoint/cep"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/utils"
)

Expand Down Expand Up @@ -102,3 +103,12 @@ func WithStreamSendMiddlewareBuilder(mwb cep.StreamSendMiddlewareBuilder) Stream
o.StreamSendMiddlewareBuilders = append(o.StreamSendMiddlewareBuilders, mwb)
}}
}

// WithStreamEventHandler add StreamEventHandler for detailed streaming event tracing
func WithStreamEventHandler(hdl rpcinfo.StreamEventHandler) StreamOption {
return StreamOption{F: func(o *client.StreamOptions, di *utils.Slice) {
di.Push(fmt.Sprintf("WithStreamEventHandler(%+v)", hdl))

o.StreamEventHandlers = append(o.StreamEventHandlers, hdl)
}}
}
3 changes: 3 additions & 0 deletions client/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,17 @@ func TestRetryOption(t *testing.T) {
func TestTransportProtocolOption(t *testing.T) {
opts := []client.Option{WithTransportProtocol(transport.GRPC)}
options := client.NewOptions(opts)
test.Assert(t, options.GRPCConnectOpts.TraceController != nil)
test.Assert(t, options.RemoteOpt.ConnPool != nil)
test.Assert(t, options.RemoteOpt.CliHandlerFactory != nil)
opts = []client.Option{WithTransportProtocol(transport.GRPCStreaming)}
options = client.NewOptions(opts)
test.Assert(t, options.GRPCConnectOpts.TraceController != nil)
test.Assert(t, options.RemoteOpt.GRPCStreamingConnPool != nil)
test.Assert(t, options.RemoteOpt.GRPCStreamingCliHandlerFactory != nil)
opts = []client.Option{WithTransportProtocol(transport.TTHeaderStreaming)}
options = client.NewOptions(opts)
test.Assert(t, len(options.TTHeaderStreamingOptions.TransportOptions) == 1)
test.Assert(t, options.RemoteOpt.TTHeaderStreamingCliHandlerFactory != nil)
}

Expand Down
46 changes: 23 additions & 23 deletions client/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"io"
"sync/atomic"

internal_stream "github.com/cloudwego/kitex/internal/stream"
"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/endpoint/cep"
"github.com/cloudwego/kitex/pkg/kerrors"
Expand All @@ -31,7 +30,6 @@ import (
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/stats"
"github.com/cloudwego/kitex/pkg/streaming"
)

Expand Down Expand Up @@ -184,7 +182,7 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {
}

clientStream := newStream(ctx, st, scm, kc, ri, ri.Invocation().MethodInfo().StreamingMode(),
sendEP, recvEP, kc.opt.StreamOptions.EventHandler, grpcSendEP, grpcRecvEP)
sendEP, recvEP, grpcSendEP, grpcRecvEP)
rresp := resp.(*streaming.Result)
rresp.ClientStream = clientStream
rresp.Stream = clientStream.GetGRPCStream()
Expand All @@ -194,12 +192,11 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {

type stream struct {
streaming.ClientStream
grpcStream *grpcStream
ctx context.Context
scm *remotecli.StreamConnManager
kc *kClient
ri rpcinfo.RPCInfo
eventHandler internal_stream.StreamEventHandler
grpcStream *grpcStream
ctx context.Context
scm *remotecli.StreamConnManager
kc *kClient
ri rpcinfo.RPCInfo

recv cep.StreamRecvEndpoint
send cep.StreamSendEndpoint
Expand All @@ -215,7 +212,7 @@ var (
)

func newStream(ctx context.Context, s streaming.ClientStream, scm *remotecli.StreamConnManager, kc *kClient, ri rpcinfo.RPCInfo, mode serviceinfo.StreamingMode,
sendEP cep.StreamSendEndpoint, recvEP cep.StreamRecvEndpoint, eventHandler internal_stream.StreamEventHandler, grpcSendEP endpoint.SendEndpoint, grpcRecvEP endpoint.RecvEndpoint,
sendEP cep.StreamSendEndpoint, recvEP cep.StreamRecvEndpoint, grpcSendEP endpoint.SendEndpoint, grpcRecvEP endpoint.RecvEndpoint,
) *stream {
st := &stream{
ClientStream: s,
Expand All @@ -226,7 +223,6 @@ func newStream(ctx context.Context, s streaming.ClientStream, scm *remotecli.Str
streamingMode: mode,
recv: recvEP,
send: sendEP,
eventHandler: eventHandler,
}
if grpcStreamGetter, ok := s.(streaming.GRPCStreamGetter); ok {
if grpcStream := grpcStreamGetter.GetGRPCStream(); grpcStream != nil {
Expand Down Expand Up @@ -267,15 +263,19 @@ func (s *stream) RecvMsg(ctx context.Context, m interface{}) (err error) {
// And it should be returned to the calling business code for error handling
err = s.ri.Invocation().BizStatusErr()
}
if s.eventHandler != nil {
s.eventHandler(s.ctx, stats.StreamRecv, err)
}
s.handleStreamRecvEvent(err)
if err != nil || s.streamingMode == serviceinfo.StreamingClient {
s.DoFinish(err)
}
return
}

func (s *stream) handleStreamRecvEvent(err error) {
s.kc.opt.TracerCtl.HandleStreamRecvEvent(s.ctx, s.ri, rpcinfo.StreamRecvEvent{
Err: err,
})
}

// SendMsg sends a message to the server.
// If an error is returned, stream.DoFinish() will be called to record the end of stream
func (s *stream) SendMsg(ctx context.Context, m interface{}) (err error) {
Expand All @@ -287,15 +287,19 @@ func (s *stream) SendMsg(ctx context.Context, m interface{}) (err error) {
}
}
err = s.send(ctx, s.ClientStream, m)
if s.eventHandler != nil {
s.eventHandler(s.ctx, stats.StreamSend, err)
}
s.handleStreamSendEvent(err)
if err != nil {
s.DoFinish(err)
}
return
}

func (s *stream) handleStreamSendEvent(err error) {
s.kc.opt.TracerCtl.HandleStreamSendEvent(s.ctx, s.ri, rpcinfo.StreamSendEvent{
Err: err,
})
}

// DoFinish implements the streaming.WithDoFinish interface, and it records the end of stream
// It will release the connection.
func (s *stream) DoFinish(err error) {
Expand Down Expand Up @@ -353,9 +357,7 @@ func (s *grpcStream) RecvMsg(m interface{}) (err error) {
// And it should be returned to the calling business code for error handling
err = s.st.ri.Invocation().BizStatusErr()
}
if s.st.eventHandler != nil {
s.st.eventHandler(s.st.ctx, stats.StreamRecv, err)
}
s.st.handleStreamRecvEvent(err)
if err != nil || s.st.streamingMode == serviceinfo.StreamingClient {
s.st.DoFinish(err)
}
Expand All @@ -364,9 +366,7 @@ func (s *grpcStream) RecvMsg(m interface{}) (err error) {

func (s *grpcStream) SendMsg(m interface{}) (err error) {
err = s.sendEndpoint(s.Stream, m)
if s.st.eventHandler != nil {
s.st.eventHandler(s.st.ctx, stats.StreamSend, err)
}
s.st.handleStreamSendEvent(err)
if err != nil {
s.st.DoFinish(err)
}
Expand Down
45 changes: 26 additions & 19 deletions client/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ func TestStreaming(t *testing.T) {
func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return stream.RecvMsg(ctx, message)
},
nil,
func(stream streaming.Stream, message interface{}) (err error) {
return stream.SendMsg(message)
},
Expand Down Expand Up @@ -263,7 +262,6 @@ func Test_newStream(t *testing.T) {
}, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return recvErr
},
nil,
func(stream streaming.Stream, message interface{}) (err error) {
return sendErr
}, func(stream streaming.Stream, message interface{}) (err error) {
Expand Down Expand Up @@ -306,7 +304,7 @@ func Test_stream_Header(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0)
scr := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scr, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(ctx, st, scr, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)
md, err := s.Header()

test.Assert(t, err == nil)
Expand Down Expand Up @@ -339,7 +337,7 @@ func Test_stream_Header(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1)
scr := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scr, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(ctx, st, scr, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)
md, err := s.Header()

test.Assert(t, err == headerErr)
Expand All @@ -357,9 +355,12 @@ func Test_stream_RecvMsg(t *testing.T) {
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0)
scm := remotecli.NewStreamConnManager(cr)
mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, nil)
s := newStream(ctx, &mockStream{}, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
kc := &kClient{
opt: client.NewOptions(nil),
}
s := newStream(ctx, &mockStream{}, scm, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return nil
}, nil, nil,
}, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand Down Expand Up @@ -394,7 +395,7 @@ func Test_stream_RecvMsg(t *testing.T) {
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingClient, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return nil
}, nil, nil,
}, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand Down Expand Up @@ -430,7 +431,7 @@ func Test_stream_RecvMsg(t *testing.T) {
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return recvErr
}, nil, nil,
}, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return recvErr
},
Expand All @@ -449,9 +450,12 @@ func Test_stream_SendMsg(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0)
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, &mockStream{}, scm, &kClient{}, nil, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
kc := &kClient{
opt: client.NewOptions(nil),
}
s := newStream(ctx, &mockStream{}, scm, kc, nil, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return nil
}, nil, nil,
}, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand Down Expand Up @@ -490,7 +494,7 @@ func Test_stream_SendMsg(t *testing.T) {

s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
return sendErr
}, nil, nil,
}, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return sendErr
},
Expand All @@ -515,7 +519,7 @@ func Test_stream_Close(t *testing.T) {
called = true
return nil
},
}, scm, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
}, scm, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)

err := s.CloseSend(context.Background())

Expand Down Expand Up @@ -543,7 +547,7 @@ func Test_stream_DoFinish(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1)
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)

finishCalled := false
err := errors.New("any err")
Expand Down Expand Up @@ -573,7 +577,7 @@ func Test_stream_DoFinish(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1)
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)

finishCalled := false
err := errors.New("any err")
Expand Down Expand Up @@ -603,7 +607,7 @@ func Test_stream_DoFinish(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1)
scm := remotecli.NewStreamConnManager(cr)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)

finishCalled := false
var err error
Expand Down Expand Up @@ -633,7 +637,7 @@ func Test_stream_DoFinish(t *testing.T) {
cr := mock_remote.NewMockConnReleaser(ctrl)
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1)
scm := remotecli.NewStreamConnManager(cr)
s := newStream(st.ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil)
s := newStream(st.ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil)

finishCalled := false
expectedErr := errors.New("error")
Expand Down Expand Up @@ -676,7 +680,10 @@ func TestContextFallback(t *testing.T) {
return nil
},
}
st := newStream(context.Background(), mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, sendEndpoint, recvEndpoint, nil, nil, nil)
kc := &kClient{
opt: client.NewOptions(nil),
}
st := newStream(context.Background(), mockSt, nil, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, sendEndpoint, recvEndpoint, nil, nil)
err := st.RecvMsg(context.Background(), nil)
test.Assert(t, err == nil)
err = st.SendMsg(context.Background(), nil)
Expand All @@ -694,15 +701,15 @@ func TestContextFallback(t *testing.T) {
return nil
},
}
st = newStream(context.Background(), mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
st = newStream(context.Background(), mockSt, nil, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
ri := rpcinfo.GetRPCInfo(ctx)
test.Assert(t, ri == mockRPCInfo)
return sendEndpoint(ctx, stream, message)
}, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) {
ri := rpcinfo.GetRPCInfo(ctx)
test.Assert(t, ri == mockRPCInfo)
return recvEndpoint(ctx, stream, message)
}, nil, nil, nil)
}, nil, nil)
err = st.RecvMsg(context.Background(), nil)
test.Assert(t, err == nil)
err = st.SendMsg(context.Background(), nil)
Expand Down
Loading