Skip to content

Commit f45155b

Browse files
authored
[Feature] Add timeout for apiserver grpc server (#3427)
1 parent 2f2c1a2 commit f45155b

File tree

4 files changed

+119
-9
lines changed

4 files changed

+119
-9
lines changed

apiserver/cmd/main.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"path"
1212
"strings"
1313
"sync/atomic"
14+
"time"
1415

1516
assetfs "github.com/elazarl/go-bindata-assetfs"
1617
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@@ -27,6 +28,7 @@ import (
2728
"github.com/ray-project/kuberay/apiserver/pkg/manager"
2829
"github.com/ray-project/kuberay/apiserver/pkg/server"
2930
"github.com/ray-project/kuberay/apiserver/pkg/swagger"
31+
"github.com/ray-project/kuberay/apiserver/pkg/util"
3032
api "github.com/ray-project/kuberay/proto/go_client"
3133
)
3234

@@ -36,6 +38,7 @@ var (
3638
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
3739
logFile = flag.String("logFilePath", "", "Synchronize logs to local file")
3840
localSwaggerPath = flag.String("localSwaggerPath", "", "Specify the root directory for `*.swagger.json` the swagger files.")
41+
grpcTimeout = flag.Duration("grpc_timeout", util.GRPCServerDefaultTimeout, "gRPC server timeout duration")
3942
healthy int32
4043
)
4144

@@ -54,7 +57,8 @@ func main() {
5457
resourceManager := manager.NewResourceManager(&clientManager)
5558

5659
atomic.StoreInt32(&healthy, 1)
57-
go startRPCServer(resourceManager)
60+
klog.Infof("Setting gRPC server timeout to %v", *grpcTimeout)
61+
go startRPCServer(resourceManager, *grpcTimeout)
5862
startHttpProxy()
5963
// See also https://gist.github.com/enricofoltran/10b4a980cd07cb02836f70a4ab3e72d7
6064
quit := make(chan os.Signal, 1)
@@ -70,7 +74,7 @@ func main() {
7074

7175
type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error
7276

73-
func startRPCServer(resourceManager *manager.ResourceManager) {
77+
func startRPCServer(resourceManager *manager.ResourceManager, grpcTimeout time.Duration) {
7478
klog.Infof("Starting gRPC server at port %s", *rpcPortFlag)
7579

7680
listener, err := net.Listen("tcp", *rpcPortFlag)
@@ -86,8 +90,13 @@ func startRPCServer(resourceManager *manager.ResourceManager) {
8690

8791
s := grpc.NewServer(
8892
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
89-
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(grpc_prometheus.UnaryServerInterceptor, interceptor.APIServerInterceptor)),
90-
grpc.MaxRecvMsgSize(math.MaxInt32))
93+
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
94+
interceptor.TimeoutInterceptor(grpcTimeout),
95+
grpc_prometheus.UnaryServerInterceptor,
96+
interceptor.APIServerInterceptor,
97+
)),
98+
grpc.MaxRecvMsgSize(math.MaxInt32),
99+
)
91100
api.RegisterClusterServiceServer(s, clusterServer)
92101
api.RegisterComputeTemplateServiceServer(s, templateServer)
93102
api.RegisterRayJobServiceServer(s, jobServer)

apiserver/pkg/interceptor/interceptor.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package interceptor
22

33
import (
44
"context"
5+
"time"
56

67
"google.golang.org/grpc"
78
klog "k8s.io/klog/v2"
@@ -19,3 +20,17 @@ func APIServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary
1920
klog.Infof("%v handler finished", info.FullMethod)
2021
return
2122
}
23+
24+
// TimeoutInterceptor implements UnaryServerInterceptor that sets the timeout for the request
25+
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
26+
return func(
27+
ctx context.Context,
28+
req interface{},
29+
_ *grpc.UnaryServerInfo,
30+
handler grpc.UnaryHandler,
31+
) (interface{}, error) {
32+
ctx, cancel := context.WithTimeout(ctx, timeout)
33+
defer cancel()
34+
return handler(ctx, req)
35+
}
36+
}

apiserver/pkg/interceptor/interceptor_test.go

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ import (
77
"io"
88
"os"
99
"testing"
10+
"time"
1011

1112
"github.com/stretchr/testify/assert"
1213
"github.com/stretchr/testify/require"
1314
"google.golang.org/grpc"
15+
"google.golang.org/grpc/codes"
16+
"google.golang.org/grpc/status"
1417
klog "k8s.io/klog/v2"
1518
)
1619

@@ -20,9 +23,28 @@ type mockHandler struct {
2023
called bool
2124
}
2225

23-
func (h *mockHandler) Handle(_ context.Context, _ interface{}) (interface{}, error) {
26+
// Handle simulates the behavior of a gRPC handler with an optional delay.
27+
// If the delay completes before the context expires, it returns "test_response" along with predefined error.
28+
// If the context is canceled or the deadline is exceeded before the delay completes,
29+
// it returns a corresponding gRPC status error instead.
30+
func (h *mockHandler) Handle(ctx context.Context, _ interface{}, delay time.Duration) (interface{}, error) {
2431
h.called = true
25-
return "test_response", h.returnErr
32+
33+
select {
34+
case <-time.After(delay):
35+
return "test_response", h.returnErr
36+
case <-ctx.Done():
37+
var grpcCode codes.Code
38+
switch ctx.Err() {
39+
case context.Canceled:
40+
grpcCode = codes.Canceled
41+
case context.DeadlineExceeded:
42+
grpcCode = codes.DeadlineExceeded
43+
default:
44+
grpcCode = codes.Unknown
45+
}
46+
return nil, status.Error(grpcCode, ctx.Err().Error())
47+
}
2648
}
2749

2850
func TestAPIServerInterceptor(t *testing.T) {
@@ -61,7 +83,7 @@ func TestAPIServerInterceptor(t *testing.T) {
6183
req,
6284
info,
6385
func(ctx context.Context, req interface{}) (interface{}, error) {
64-
return tt.handler.Handle(ctx, req)
86+
return tt.handler.Handle(ctx, req, 0 /*delay*/)
6587
},
6688
)
6789

@@ -96,7 +118,7 @@ func TestAPIServerInterceptorContextPassing(t *testing.T) {
96118
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
97119
// Verify context value is passed through
98120
assert.Equal(t, "test_value", receivedCtx.Value(testContextKey("test_key")))
99-
return handler.Handle(receivedCtx, req)
121+
return handler.Handle(receivedCtx, req, 0 /*delay*/)
100122
},
101123
)
102124
}
@@ -153,7 +175,7 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
153175
"test_request",
154176
info,
155177
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
156-
return handler.Handle(receivedCtx, req)
178+
return handler.Handle(receivedCtx, req, 0 /*delay*/)
157179
},
158180
)
159181

@@ -192,3 +214,62 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
192214
})
193215
}
194216
}
217+
218+
func TestTimeoutInterceptor(t *testing.T) {
219+
tests := []struct {
220+
expectedError error
221+
name string
222+
timeout time.Duration
223+
handlerDelay time.Duration
224+
expectedCalled bool
225+
}{
226+
{
227+
name: "handler completes before timeout",
228+
timeout: 100 * time.Millisecond,
229+
handlerDelay: 50 * time.Millisecond,
230+
expectedError: nil,
231+
expectedCalled: true,
232+
},
233+
{
234+
name: "handler exceeds timeout",
235+
timeout: 50 * time.Millisecond,
236+
handlerDelay: 100 * time.Millisecond,
237+
expectedError: status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
238+
expectedCalled: true,
239+
},
240+
}
241+
242+
for _, tt := range tests {
243+
t.Run(tt.name, func(t *testing.T) {
244+
// Create test context and request
245+
ctx := context.Background()
246+
req := "test_request"
247+
handler := &mockHandler{}
248+
249+
// Create the interceptor with the specified timeout
250+
interceptor := TimeoutInterceptor(tt.timeout)
251+
252+
// Call the interceptor
253+
resp, err := interceptor(
254+
ctx,
255+
req,
256+
&grpc.UnaryServerInfo{FullMethod: "TestTimeoutMethod"},
257+
func(ctx context.Context, req interface{}) (interface{}, error) {
258+
return handler.Handle(ctx, req, tt.handlerDelay)
259+
},
260+
)
261+
262+
// Verify response and error
263+
if tt.expectedError == nil {
264+
// Verify handler was called
265+
assert.Equal(t, tt.expectedCalled, handler.called, "handler call status should match expected")
266+
267+
require.NoError(t, err)
268+
assert.Equal(t, "test_response", resp, "response should match expected")
269+
} else {
270+
require.Error(t, err)
271+
require.Equal(t, tt.expectedError, err, "A matching error is expected")
272+
}
273+
})
274+
}
275+
}

apiserver/pkg/util/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package util
22

3+
import "time"
4+
35
// ClientOptions contains configuration needed to create a Kubernetes client
46
type ClientOptions struct {
57
QPS float32
@@ -30,4 +32,7 @@ const (
3032

3133
// The component name for apiserver
3234
ComponentName = "kuberay-apiserver"
35+
36+
// Timeout for apiserver gRPC server
37+
GRPCServerDefaultTimeout = 60 * time.Second
3338
)

0 commit comments

Comments
 (0)