Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type StreamOption struct {
}

type StreamOptions struct {
RecvTimeoutConfig streaming.TimeoutConfig
StreamEventHandlers []rpcinfo.ServerStreamEventHandler
StreamMiddlewares []sep.StreamMiddleware
StreamMiddlewareBuilders []sep.StreamMiddlewareBuilder
Expand Down
16 changes: 16 additions & 0 deletions pkg/streaming/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,19 @@ type Result struct {
type GRPCStreamGetter interface {
GetGRPCStream() Stream
}

// RecvTimeoutConfigSetter allows dynamically setting RecvTimeoutConfig on a server stream.
// Server handlers can use this to override the global recv timeout on a per-stream basis.
type RecvTimeoutConfigSetter interface {
SetRecvTimeoutConfig(cfg TimeoutConfig)
}

// SetRecvTimeoutConfig sets the RecvTimeoutConfig on a server stream if it supports it.
// Returns true if the config was set successfully.
func SetRecvTimeoutConfig(stream ServerStream, cfg TimeoutConfig) bool {
if setter, ok := stream.(RecvTimeoutConfigSetter); ok {
setter.SetRecvTimeoutConfig(cfg)
return true
}
return false
}
11 changes: 11 additions & 0 deletions server/option_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/endpoint/sep"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/streaming"
"github.com/cloudwego/kitex/pkg/utils"
)

Expand Down Expand Up @@ -104,6 +105,16 @@ func WithStreamEventHandler(hdl rpcinfo.ServerStreamEventHandler) StreamOption {
}}
}

// WithStreamRecvTimeoutConfig sets the default recv timeout for all server-side streaming handlers.
// Handlers can dynamically override this per-stream using streaming.SetRecvTimeoutConfig(stream, cfg).
func WithStreamRecvTimeoutConfig(cfg streaming.TimeoutConfig) StreamOption {
return StreamOption{F: func(o *internal_server.StreamOptions, di *utils.Slice) {
di.Push(fmt.Sprintf("WithStreamRecvTimeoutConfig(%+v)", cfg))

o.RecvTimeoutConfig = cfg
}}
}

// Deprecated: Use WithStreamRecvMiddleware instead, this requires enabling the streamx feature.
func WithRecvMiddleware(mw endpoint.RecvMiddleware) Option {
mwb := func(ctx context.Context) endpoint.RecvMiddleware {
Expand Down
16 changes: 16 additions & 0 deletions server/option_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ package server

import (
"testing"
"time"

"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/streaming"
)

func Test_WithStreamEventHandler(t *testing.T) {
Expand All @@ -37,3 +39,17 @@ func Test_WithStreamEventHandler(t *testing.T) {
iSvr := svr.(*server)
test.Assert(t, len(iSvr.opt.StreamOptions.StreamEventHandlers) == 2, iSvr.opt)
}

func Test_WithStreamRecvTimeoutConfig(t *testing.T) {
cfg := streaming.TimeoutConfig{
Timeout: 3 * time.Second,
DisableCancelRemote: true,
}
svr, _ := NewTestServer(
WithStreamOptions(
WithStreamRecvTimeoutConfig(cfg),
))
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.StreamOptions.RecvTimeoutConfig.Timeout == 3*time.Second, iSvr.opt)
test.Assert(t, iSvr.opt.StreamOptions.RecvTimeoutConfig.DisableCancelRemote, iSvr.opt)
}
86 changes: 78 additions & 8 deletions server/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ package server

import (
"context"
"fmt"
"runtime/debug"
"time"

"github.com/bytedance/gopkg/util/gopool"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/endpoint/sep"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/streaming"
)
Expand All @@ -34,9 +40,10 @@ func (s *server) wrapStreamMiddleware() endpoint.Middleware {
recvEP := s.opt.StreamOptions.BuildRecvChain()
grpcSendEP := s.opt.Streaming.BuildSendInvokeChain()
grpcRecvEP := s.opt.Streaming.BuildRecvInvokeChain()
recvTmCfg := s.opt.StreamOptions.RecvTimeoutConfig
return func(ctx context.Context, req, resp interface{}) (err error) {
if st, ok := req.(*streaming.Args); ok {
nst := newStream(ctx, st.ServerStream, sendEP, recvEP, s.opt.TracerCtl, grpcSendEP, grpcRecvEP)
nst := newStream(ctx, st.ServerStream, sendEP, recvEP, s.opt.TracerCtl, grpcSendEP, grpcRecvEP, recvTmCfg)
st.ServerStream = nst
st.Stream = nst.GetGRPCStream()
}
Expand All @@ -47,6 +54,7 @@ func (s *server) wrapStreamMiddleware() endpoint.Middleware {

func newStream(ctx context.Context, s streaming.ServerStream, sendEP sep.StreamSendEndpoint, recvEP sep.StreamRecvEndpoint,
traceCtl *rpcinfo.TraceController, grpcSendEP endpoint.SendEndpoint, grpcRecvEP endpoint.RecvEndpoint,
recvTmCfg streaming.TimeoutConfig,
) *stream {
ri := rpcinfo.GetRPCInfo(ctx)
st := &stream{
Expand All @@ -56,10 +64,11 @@ func newStream(ctx context.Context, s streaming.ServerStream, sendEP sep.StreamS
recv: recvEP,
send: sendEP,
traceCtl: traceCtl,
recvTmCfg: recvTmCfg,
}
if grpcStreamGetter, ok := s.(streaming.GRPCStreamGetter); ok {
if grpcStream := grpcStreamGetter.GetGRPCStream(); grpcStream != nil {
st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, grpcRecvEP)
st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, grpcRecvEP, recvTmCfg)
st.grpcStream.st = st
}
}
Expand All @@ -73,11 +82,15 @@ type stream struct {
ri rpcinfo.RPCInfo
traceCtl *rpcinfo.TraceController

recv sep.StreamRecvEndpoint
send sep.StreamSendEndpoint
recv sep.StreamRecvEndpoint
send sep.StreamSendEndpoint
recvTmCfg streaming.TimeoutConfig
}

var _ streaming.GRPCStreamGetter = (*stream)(nil)
var (
_ streaming.GRPCStreamGetter = (*stream)(nil)
_ streaming.RecvTimeoutConfigSetter = (*stream)(nil)
)

func (s *stream) GetGRPCStream() streaming.Stream {
if s.grpcStream == nil {
Expand All @@ -86,13 +99,30 @@ func (s *stream) GetGRPCStream() streaming.Stream {
return s.grpcStream
}

// SetRecvTimeoutConfig allows handlers to dynamically override the recv timeout for this stream.
func (s *stream) SetRecvTimeoutConfig(cfg streaming.TimeoutConfig) {
s.recvTmCfg = cfg
if s.grpcStream != nil {
s.grpcStream.recvTmCfg = cfg
}
}

// RecvMsg receives a message from the client.
func (s *stream) RecvMsg(ctx context.Context, m interface{}) (err error) {
err = s.recv(ctx, s.ServerStream, m)
err = s.recvWithTimeout(ctx, m)
s.handleStreamRecvEvent(err)
return
}

func (s *stream) recvWithTimeout(ctx context.Context, m interface{}) error {
if s.recvTmCfg.Timeout <= 0 {
return s.recv(ctx, s.ServerStream, m)
}
return callWithTimeout(s.recvTmCfg, func() error {
return s.recv(ctx, s.ServerStream, m)
})
}

func (s *stream) handleStreamRecvEvent(err error) {
s.traceCtl.HandleStreamRecvEvent(s.ctx, s.ri, rpcinfo.StreamRecvEvent{
Err: err,
Expand All @@ -112,11 +142,14 @@ func (s *stream) handleStreamSendEvent(err error) {
})
}

func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint) *grpcStream {
func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint,
recvTmCfg streaming.TimeoutConfig,
) *grpcStream {
return &grpcStream{
Stream: st,
sendEndpoint: sendEP,
recvEndpoint: recvEP,
recvTmCfg: recvTmCfg,
}
}

Expand All @@ -127,14 +160,24 @@ type grpcStream struct {

sendEndpoint endpoint.SendEndpoint
recvEndpoint endpoint.RecvEndpoint
recvTmCfg streaming.TimeoutConfig
}

func (s *grpcStream) RecvMsg(m interface{}) (err error) {
err = s.recvEndpoint(s.Stream, m)
err = s.recvWithTimeout(m)
s.st.handleStreamRecvEvent(err)
return
}

func (s *grpcStream) recvWithTimeout(m interface{}) error {
if s.recvTmCfg.Timeout <= 0 {
return s.recvEndpoint(s.Stream, m)
}
return callWithTimeout(s.recvTmCfg, func() error {
return s.recvEndpoint(s.Stream, m)
})
}

func (s *grpcStream) SendMsg(m interface{}) (err error) {
err = s.sendEndpoint(s.Stream, m)
s.st.handleStreamSendEvent(err)
Expand Down Expand Up @@ -172,3 +215,30 @@ type gRPCCompatibleServerStream struct {
func (gs gRPCCompatibleServerStream) GetGRPCStream() streaming.Stream {
return gs.st
}

// callWithTimeout wraps a blocking call with a timeout.
// Unlike the client-side version, no cancel function is needed because
// ServerStream does not support CancelWithErr. When timeout occurs,
// the error is returned to the handler; the handler returns, and the
// framework closes the stream, which unblocks the stuck recv goroutine.
func callWithTimeout(tmCfg streaming.TimeoutConfig, call func() error) error {
timer := time.NewTimer(tmCfg.Timeout)
defer timer.Stop()
finishChan := make(chan error, 1)
gopool.Go(func() {
var callErr error
defer func() {
if r := recover(); r != nil {
callErr = fmt.Errorf("server stream Recv panic, panic=%v, stack=%s", r, debug.Stack())
}
finishChan <- callErr
}()
callErr = call()
})
select {
case <-timer.C:
return kerrors.ErrRPCTimeout.WithCause(fmt.Errorf("server stream Recv timeout, timeout config=%+v", tmCfg))
case callErr := <-finishChan:
return callErr
}
}
Loading
Loading