From e255a29e62070d319a639f1b889db1a2828e57f6 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Tue, 19 May 2026 15:20:28 +0800 Subject: [PATCH] feat: support server-side streaming Recv timeout config --- internal/server/option.go | 1 + pkg/streaming/streaming.go | 16 +++++ server/option_stream.go | 11 +++ server/option_stream_test.go | 16 +++++ server/stream.go | 86 ++++++++++++++++++++--- server/stream_test.go | 129 +++++++++++++++++++++++++++++++++++ 6 files changed, 251 insertions(+), 8 deletions(-) diff --git a/internal/server/option.go b/internal/server/option.go index 470b19370b..bd572d96a7 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -77,6 +77,7 @@ type StreamOption struct { } type StreamOptions struct { + RecvTimeoutConfig streaming.TimeoutConfig StreamEventHandlers []rpcinfo.ServerStreamEventHandler StreamMiddlewares []sep.StreamMiddleware StreamMiddlewareBuilders []sep.StreamMiddlewareBuilder diff --git a/pkg/streaming/streaming.go b/pkg/streaming/streaming.go index 9f923e95aa..18799a0706 100644 --- a/pkg/streaming/streaming.go +++ b/pkg/streaming/streaming.go @@ -91,3 +91,19 @@ type Result struct { type GRPCStreamGetter interface { GetGRPCStream() Stream } + +// RecvTimeoutConfigSetter allows dynamically setting RecvTimeoutConfig on a server stream. +// Server handlers can use this to override the global recv timeout on a per-stream basis. +type RecvTimeoutConfigSetter interface { + SetRecvTimeoutConfig(cfg TimeoutConfig) +} + +// SetRecvTimeoutConfig sets the RecvTimeoutConfig on a server stream if it supports it. +// Returns true if the config was set successfully. +func SetRecvTimeoutConfig(stream ServerStream, cfg TimeoutConfig) bool { + if setter, ok := stream.(RecvTimeoutConfigSetter); ok { + setter.SetRecvTimeoutConfig(cfg) + return true + } + return false +} diff --git a/server/option_stream.go b/server/option_stream.go index e794567eac..e01e43bae3 100644 --- a/server/option_stream.go +++ b/server/option_stream.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/endpoint/sep" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" "github.com/cloudwego/kitex/pkg/utils" ) @@ -104,6 +105,16 @@ func WithStreamEventHandler(hdl rpcinfo.ServerStreamEventHandler) StreamOption { }} } +// WithStreamRecvTimeoutConfig sets the default recv timeout for all server-side streaming handlers. +// Handlers can dynamically override this per-stream using streaming.SetRecvTimeoutConfig(stream, cfg). +func WithStreamRecvTimeoutConfig(cfg streaming.TimeoutConfig) StreamOption { + return StreamOption{F: func(o *internal_server.StreamOptions, di *utils.Slice) { + di.Push(fmt.Sprintf("WithStreamRecvTimeoutConfig(%+v)", cfg)) + + o.RecvTimeoutConfig = cfg + }} +} + // Deprecated: Use WithStreamRecvMiddleware instead, this requires enabling the streamx feature. func WithRecvMiddleware(mw endpoint.RecvMiddleware) Option { mwb := func(ctx context.Context) endpoint.RecvMiddleware { diff --git a/server/option_stream_test.go b/server/option_stream_test.go index 39855fc659..ce9aef492d 100644 --- a/server/option_stream_test.go +++ b/server/option_stream_test.go @@ -18,9 +18,11 @@ package server import ( "testing" + "time" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" ) func Test_WithStreamEventHandler(t *testing.T) { @@ -37,3 +39,17 @@ func Test_WithStreamEventHandler(t *testing.T) { iSvr := svr.(*server) test.Assert(t, len(iSvr.opt.StreamOptions.StreamEventHandlers) == 2, iSvr.opt) } + +func Test_WithStreamRecvTimeoutConfig(t *testing.T) { + cfg := streaming.TimeoutConfig{ + Timeout: 3 * time.Second, + DisableCancelRemote: true, + } + svr, _ := NewTestServer( + WithStreamOptions( + WithStreamRecvTimeoutConfig(cfg), + )) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.StreamOptions.RecvTimeoutConfig.Timeout == 3*time.Second, iSvr.opt) + test.Assert(t, iSvr.opt.StreamOptions.RecvTimeoutConfig.DisableCancelRemote, iSvr.opt) +} diff --git a/server/stream.go b/server/stream.go index 97641f9bbc..9d0982ad30 100644 --- a/server/stream.go +++ b/server/stream.go @@ -18,9 +18,15 @@ package server import ( "context" + "fmt" + "runtime/debug" + "time" + + "github.com/bytedance/gopkg/util/gopool" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/endpoint/sep" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -34,9 +40,10 @@ func (s *server) wrapStreamMiddleware() endpoint.Middleware { recvEP := s.opt.StreamOptions.BuildRecvChain() grpcSendEP := s.opt.Streaming.BuildSendInvokeChain() grpcRecvEP := s.opt.Streaming.BuildRecvInvokeChain() + recvTmCfg := s.opt.StreamOptions.RecvTimeoutConfig return func(ctx context.Context, req, resp interface{}) (err error) { if st, ok := req.(*streaming.Args); ok { - nst := newStream(ctx, st.ServerStream, sendEP, recvEP, s.opt.TracerCtl, grpcSendEP, grpcRecvEP) + nst := newStream(ctx, st.ServerStream, sendEP, recvEP, s.opt.TracerCtl, grpcSendEP, grpcRecvEP, recvTmCfg) st.ServerStream = nst st.Stream = nst.GetGRPCStream() } @@ -47,6 +54,7 @@ func (s *server) wrapStreamMiddleware() endpoint.Middleware { func newStream(ctx context.Context, s streaming.ServerStream, sendEP sep.StreamSendEndpoint, recvEP sep.StreamRecvEndpoint, traceCtl *rpcinfo.TraceController, grpcSendEP endpoint.SendEndpoint, grpcRecvEP endpoint.RecvEndpoint, + recvTmCfg streaming.TimeoutConfig, ) *stream { ri := rpcinfo.GetRPCInfo(ctx) st := &stream{ @@ -56,10 +64,11 @@ func newStream(ctx context.Context, s streaming.ServerStream, sendEP sep.StreamS recv: recvEP, send: sendEP, traceCtl: traceCtl, + recvTmCfg: recvTmCfg, } if grpcStreamGetter, ok := s.(streaming.GRPCStreamGetter); ok { if grpcStream := grpcStreamGetter.GetGRPCStream(); grpcStream != nil { - st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, grpcRecvEP) + st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, grpcRecvEP, recvTmCfg) st.grpcStream.st = st } } @@ -73,11 +82,15 @@ type stream struct { ri rpcinfo.RPCInfo traceCtl *rpcinfo.TraceController - recv sep.StreamRecvEndpoint - send sep.StreamSendEndpoint + recv sep.StreamRecvEndpoint + send sep.StreamSendEndpoint + recvTmCfg streaming.TimeoutConfig } -var _ streaming.GRPCStreamGetter = (*stream)(nil) +var ( + _ streaming.GRPCStreamGetter = (*stream)(nil) + _ streaming.RecvTimeoutConfigSetter = (*stream)(nil) +) func (s *stream) GetGRPCStream() streaming.Stream { if s.grpcStream == nil { @@ -86,13 +99,30 @@ func (s *stream) GetGRPCStream() streaming.Stream { return s.grpcStream } +// SetRecvTimeoutConfig allows handlers to dynamically override the recv timeout for this stream. +func (s *stream) SetRecvTimeoutConfig(cfg streaming.TimeoutConfig) { + s.recvTmCfg = cfg + if s.grpcStream != nil { + s.grpcStream.recvTmCfg = cfg + } +} + // RecvMsg receives a message from the client. func (s *stream) RecvMsg(ctx context.Context, m interface{}) (err error) { - err = s.recv(ctx, s.ServerStream, m) + err = s.recvWithTimeout(ctx, m) s.handleStreamRecvEvent(err) return } +func (s *stream) recvWithTimeout(ctx context.Context, m interface{}) error { + if s.recvTmCfg.Timeout <= 0 { + return s.recv(ctx, s.ServerStream, m) + } + return callWithTimeout(s.recvTmCfg, func() error { + return s.recv(ctx, s.ServerStream, m) + }) +} + func (s *stream) handleStreamRecvEvent(err error) { s.traceCtl.HandleStreamRecvEvent(s.ctx, s.ri, rpcinfo.StreamRecvEvent{ Err: err, @@ -112,11 +142,14 @@ func (s *stream) handleStreamSendEvent(err error) { }) } -func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint) *grpcStream { +func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint, + recvTmCfg streaming.TimeoutConfig, +) *grpcStream { return &grpcStream{ Stream: st, sendEndpoint: sendEP, recvEndpoint: recvEP, + recvTmCfg: recvTmCfg, } } @@ -127,14 +160,24 @@ type grpcStream struct { sendEndpoint endpoint.SendEndpoint recvEndpoint endpoint.RecvEndpoint + recvTmCfg streaming.TimeoutConfig } func (s *grpcStream) RecvMsg(m interface{}) (err error) { - err = s.recvEndpoint(s.Stream, m) + err = s.recvWithTimeout(m) s.st.handleStreamRecvEvent(err) return } +func (s *grpcStream) recvWithTimeout(m interface{}) error { + if s.recvTmCfg.Timeout <= 0 { + return s.recvEndpoint(s.Stream, m) + } + return callWithTimeout(s.recvTmCfg, func() error { + return s.recvEndpoint(s.Stream, m) + }) +} + func (s *grpcStream) SendMsg(m interface{}) (err error) { err = s.sendEndpoint(s.Stream, m) s.st.handleStreamSendEvent(err) @@ -172,3 +215,30 @@ type gRPCCompatibleServerStream struct { func (gs gRPCCompatibleServerStream) GetGRPCStream() streaming.Stream { return gs.st } + +// callWithTimeout wraps a blocking call with a timeout. +// Unlike the client-side version, no cancel function is needed because +// ServerStream does not support CancelWithErr. When timeout occurs, +// the error is returned to the handler; the handler returns, and the +// framework closes the stream, which unblocks the stuck recv goroutine. +func callWithTimeout(tmCfg streaming.TimeoutConfig, call func() error) error { + timer := time.NewTimer(tmCfg.Timeout) + defer timer.Stop() + finishChan := make(chan error, 1) + gopool.Go(func() { + var callErr error + defer func() { + if r := recover(); r != nil { + callErr = fmt.Errorf("server stream Recv panic, panic=%v, stack=%s", r, debug.Stack()) + } + finishChan <- callErr + }() + callErr = call() + }) + select { + case <-timer.C: + return kerrors.ErrRPCTimeout.WithCause(fmt.Errorf("server stream Recv timeout, timeout config=%+v", tmCfg)) + case callErr := <-finishChan: + return callErr + } +} diff --git a/server/stream_test.go b/server/stream_test.go index 8c0e0ed3e0..741d803e1f 100644 --- a/server/stream_test.go +++ b/server/stream_test.go @@ -18,11 +18,14 @@ package server import ( "context" + "errors" "testing" + "time" internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" @@ -199,3 +202,129 @@ func Test_gRPCCompatibleServerStream_with_middleware(t *testing.T) { type wrappedStreamByMiddleware struct { streaming.Stream } + +// blockingServerStream is a mock ServerStream whose RecvMsg blocks until context is done. +type blockingServerStream struct { + streaming.ServerStream + recvCh chan struct{} // closed when recv should unblock +} + +func newBlockingServerStream() *blockingServerStream { + return &blockingServerStream{recvCh: make(chan struct{})} +} + +func (s *blockingServerStream) RecvMsg(ctx context.Context, m interface{}) error { + <-s.recvCh + return nil +} + +func (s *blockingServerStream) unblock() { + close(s.recvCh) +} + +// immediateServerStream is a mock ServerStream whose RecvMsg returns immediately. +type immediateServerStream struct { + streaming.ServerStream + err error +} + +func (s *immediateServerStream) RecvMsg(ctx context.Context, m interface{}) error { + return s.err +} + +func newTestStream(ss streaming.ServerStream, recvTmCfg streaming.TimeoutConfig) *stream { + recvEP := func(ctx context.Context, stream streaming.ServerStream, message interface{}) error { + return stream.RecvMsg(ctx, message) + } + sendEP := func(ctx context.Context, stream streaming.ServerStream, message interface{}) error { + return stream.SendMsg(ctx, message) + } + return &stream{ + ServerStream: ss, + ctx: context.Background(), + ri: rpcinfo.NewRPCInfo(nil, nil, nil, nil, nil), + recv: recvEP, + send: sendEP, + traceCtl: &rpcinfo.TraceController{}, + recvTmCfg: recvTmCfg, + } +} + +func Test_stream_RecvMsg_withTimeout(t *testing.T) { + ss := newBlockingServerStream() + defer ss.unblock() + st := newTestStream(ss, streaming.TimeoutConfig{Timeout: 50 * time.Millisecond}) + + start := time.Now() + err := st.RecvMsg(context.Background(), nil) + elapsed := time.Since(start) + + test.Assert(t, err != nil, "expected timeout error") + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout), err) + test.Assert(t, elapsed >= 50*time.Millisecond, elapsed) + test.Assert(t, elapsed < 500*time.Millisecond, elapsed) +} + +func Test_stream_RecvMsg_noTimeout(t *testing.T) { + expectedErr := errors.New("test recv error") + ss := &immediateServerStream{err: expectedErr} + st := newTestStream(ss, streaming.TimeoutConfig{}) // no timeout + + err := st.RecvMsg(context.Background(), nil) + test.Assert(t, err == expectedErr, err) +} + +func Test_stream_RecvMsg_noTimeout_success(t *testing.T) { + ss := &immediateServerStream{err: nil} + st := newTestStream(ss, streaming.TimeoutConfig{}) // no timeout + + err := st.RecvMsg(context.Background(), nil) + test.Assert(t, err == nil, err) +} + +func Test_stream_RecvMsg_withTimeout_noBlock(t *testing.T) { + ss := &immediateServerStream{err: nil} + st := newTestStream(ss, streaming.TimeoutConfig{Timeout: 1 * time.Second}) + + err := st.RecvMsg(context.Background(), nil) + test.Assert(t, err == nil, err) +} + +func Test_stream_SetRecvTimeoutConfig(t *testing.T) { + ss := newBlockingServerStream() + defer ss.unblock() + st := newTestStream(ss, streaming.TimeoutConfig{}) // initially no timeout + + // Set timeout dynamically + st.SetRecvTimeoutConfig(streaming.TimeoutConfig{Timeout: 50 * time.Millisecond}) + + start := time.Now() + err := st.RecvMsg(context.Background(), nil) + elapsed := time.Since(start) + + test.Assert(t, errors.Is(err, kerrors.ErrRPCTimeout), err) + test.Assert(t, elapsed >= 50*time.Millisecond, elapsed) +} + +func Test_streaming_SetRecvTimeoutConfig_helper(t *testing.T) { + ss := &immediateServerStream{} + st := newTestStream(ss, streaming.TimeoutConfig{}) + + // stream implements RecvTimeoutConfigSetter + ok := streaming.SetRecvTimeoutConfig(st, streaming.TimeoutConfig{Timeout: 1 * time.Second}) + test.Assert(t, ok, "expected SetRecvTimeoutConfig to return true") + test.Assert(t, st.recvTmCfg.Timeout == 1*time.Second, st.recvTmCfg) + + // plain ServerStream does not implement it + ok = streaming.SetRecvTimeoutConfig(ss, streaming.TimeoutConfig{Timeout: 1 * time.Second}) + test.Assert(t, !ok, "expected SetRecvTimeoutConfig to return false for plain stream") +} + +func Test_callWithTimeout_panic_recovery(t *testing.T) { + tmCfg := streaming.TimeoutConfig{Timeout: 1 * time.Second} + err := callWithTimeout(tmCfg, func() error { + panic("test panic") + }) + test.Assert(t, err != nil, "expected error from panic") + test.Assert(t, !errors.Is(err, kerrors.ErrRPCTimeout), "panic error should not be ErrRPCTimeout") +}