Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions apiserver/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path"
"strings"
"sync/atomic"
"time"

assetfs "github.com/elazarl/go-bindata-assetfs"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/ray-project/kuberay/apiserver/pkg/manager"
"github.com/ray-project/kuberay/apiserver/pkg/server"
"github.com/ray-project/kuberay/apiserver/pkg/swagger"
"github.com/ray-project/kuberay/apiserver/pkg/util"
api "github.com/ray-project/kuberay/proto/go_client"
)

Expand All @@ -36,6 +38,7 @@ var (
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
logFile = flag.String("logFilePath", "", "Synchronize logs to local file")
localSwaggerPath = flag.String("localSwaggerPath", "", "Specify the root directory for `*.swagger.json` the swagger files.")
grpcTimeout = flag.Duration("grpc_timeout", util.GRPCServerDefaultTimeoutSeconds*time.Second, "gRPC server timeout duration")
healthy int32
)

Expand All @@ -54,7 +57,8 @@ func main() {
resourceManager := manager.NewResourceManager(&clientManager)

atomic.StoreInt32(&healthy, 1)
go startRPCServer(resourceManager)
klog.Infof("Setting gRPC server timeout to %v", *grpcTimeout)
go startRPCServer(resourceManager, *grpcTimeout)
startHttpProxy()
// See also https://gist.github.com/enricofoltran/10b4a980cd07cb02836f70a4ab3e72d7
quit := make(chan os.Signal, 1)
Expand All @@ -70,7 +74,7 @@ func main() {

type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

func startRPCServer(resourceManager *manager.ResourceManager) {
func startRPCServer(resourceManager *manager.ResourceManager, grpcTimeout time.Duration) {
klog.Infof("Starting gRPC server at port %s", *rpcPortFlag)

listener, err := net.Listen("tcp", *rpcPortFlag)
Expand All @@ -86,8 +90,13 @@ func startRPCServer(resourceManager *manager.ResourceManager) {

s := grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(grpc_prometheus.UnaryServerInterceptor, interceptor.APIServerInterceptor)),
grpc.MaxRecvMsgSize(math.MaxInt32))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
interceptor.TimeoutInterceptor(grpcTimeout),
grpc_prometheus.UnaryServerInterceptor,
interceptor.APIServerInterceptor,
)),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
api.RegisterClusterServiceServer(s, clusterServer)
api.RegisterComputeTemplateServiceServer(s, templateServer)
api.RegisterRayJobServiceServer(s, jobServer)
Expand Down
15 changes: 15 additions & 0 deletions apiserver/pkg/interceptor/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package interceptor

import (
"context"
"time"

"google.golang.org/grpc"
klog "k8s.io/klog/v2"
Expand All @@ -19,3 +20,17 @@ func APIServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary
klog.Infof("%v handler finished", info.FullMethod)
return
}

// TimeoutInterceptor implements UnaryServerInterceptor that sets the timeout for the request
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return handler(ctx, req)
}
}
87 changes: 82 additions & 5 deletions apiserver/pkg/interceptor/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
klog "k8s.io/klog/v2"
)

Expand All @@ -20,9 +23,24 @@ type mockHandler struct {
called bool
}

func (h *mockHandler) Handle(_ context.Context, _ interface{}) (interface{}, error) {
func (h *mockHandler) Handle(ctx context.Context, _ interface{}, delay time.Duration) (interface{}, error) {
h.called = true
return "test_response", h.returnErr

select {
case <-time.After(delay):
return "test_response", h.returnErr
case <-ctx.Done():
var grpcCode codes.Code
switch ctx.Err() {
case context.Canceled:
grpcCode = codes.Canceled
case context.DeadlineExceeded:
grpcCode = codes.DeadlineExceeded
default:
grpcCode = codes.Unknown
}
return nil, status.Error(grpcCode, ctx.Err().Error())
}
Comment on lines +33 to +47
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding this to mimic the grpc IO handler for testing

}

func TestAPIServerInterceptor(t *testing.T) {
Expand Down Expand Up @@ -61,7 +79,7 @@ func TestAPIServerInterceptor(t *testing.T) {
req,
info,
func(ctx context.Context, req interface{}) (interface{}, error) {
return tt.handler.Handle(ctx, req)
return tt.handler.Handle(ctx, req, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 0 /*delay*/

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks!

},
)

Expand Down Expand Up @@ -96,7 +114,7 @@ func TestAPIServerInterceptorContextPassing(t *testing.T) {
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
// Verify context value is passed through
assert.Equal(t, "test_value", receivedCtx.Value(testContextKey("test_key")))
return handler.Handle(receivedCtx, req)
return handler.Handle(receivedCtx, req, 0)
},
)
}
Expand Down Expand Up @@ -153,7 +171,7 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
"test_request",
info,
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
return handler.Handle(receivedCtx, req)
return handler.Handle(receivedCtx, req, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a comment besides constants

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks!

},
)

Expand Down Expand Up @@ -192,3 +210,62 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
})
}
}

func TestTimeoutInterceptor(t *testing.T) {
tests := []struct {
expectedError error
name string
timeout time.Duration
handlerDelay time.Duration
expectedCalled bool
}{
{
name: "handler completes before timeout",
timeout: 100 * time.Millisecond,
handlerDelay: 50 * time.Millisecond,
expectedError: nil,
expectedCalled: true,
},
{
name: "handler exceeds timeout",
timeout: 50 * time.Millisecond,
handlerDelay: 100 * time.Millisecond,
expectedError: status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
expectedCalled: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test context and request
ctx := context.Background()
req := "test_request"
handler := &mockHandler{}

// Create the interceptor with the specified timeout
interceptor := TimeoutInterceptor(tt.timeout)

// Call the interceptor
resp, err := interceptor(
ctx,
req,
&grpc.UnaryServerInfo{FullMethod: "TestTimeoutMethod"},
func(ctx context.Context, req interface{}) (interface{}, error) {
return handler.Handle(ctx, req, tt.handlerDelay)
},
)

// Verify response and error
if tt.expectedError == nil {
// Verify handler was called
assert.Equal(t, tt.expectedCalled, handler.called, "handler call status should match expected")

require.NoError(t, err)
assert.Equal(t, "test_response", resp, "response should match expected")
} else {
require.Error(t, err)
require.Equal(t, tt.expectedError, err, "A matching error is expected")
}
})
}
}
3 changes: 3 additions & 0 deletions apiserver/pkg/util/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ const (

// The component name for apiserver
ComponentName = "kuberay-apiserver"

// Timeout for apiserver gRPC server
GRPCServerDefaultTimeoutSeconds = 60
)
Loading