From b4055e230b015c6f38d299aaf60c33a23f5493ff Mon Sep 17 00:00:00 2001 From: Nathan Wang Date: Sat, 17 May 2025 00:51:41 -0400 Subject: [PATCH 1/3] Add Go support for input plane --- modal-go/client.go | 35 +++++++++--- modal-go/function.go | 77 ++++++++++++++++++++++++--- modal-go/test/function_test.go | 15 ++++++ test-support/libmodal_test_support.py | 5 ++ 4 files changed, 117 insertions(+), 15 deletions(-) diff --git a/modal-go/client.go b/modal-go/client.go index 104a87f..65cda7b 100644 --- a/modal-go/client.go +++ b/modal-go/client.go @@ -68,8 +68,13 @@ var defaultConfig config // defaultProfile is resolved at package init from MODAL_PROFILE, ~/.modal.toml, etc. var defaultProfile Profile +// client is the default Modal client that talks to the control plane. var client pb.ModalClientClient +// clients is a map of server URL => client. +// The us-east client talks to the control plane; all other clients talk to input planes. +var clients = map[string]pb.ModalClientClient{} + func init() { var err error defaultConfig, _ = readConfigFile() @@ -78,25 +83,39 @@ func init() { panic(err) // fail fast – credentials are required to proceed } - _, client, err = newClient(defaultProfile) + client, err = getOrCreateClient(defaultProfile.ServerURL) if err != nil { panic(err) } } -// newClient dials api.modal.com with auth/timeout/retry interceptors installed. +// getOrCreateClient returns a client for the given server URL, creating it if it doesn't exist. +func getOrCreateClient(serverURL string) (pb.ModalClientClient, error) { + if client, ok := clients[serverURL]; ok { + return client, nil + } + + _, client, err := createClient(serverURL) + if err != nil { + return nil, err + } + clients[serverURL] = client + return client, nil +} + +// createClient dials the given server URL with auth/timeout/retry interceptors installed. // It returns (conn, stub). Close the conn when done. -func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error) { +func createClient(serverURL string) (*grpc.ClientConn, pb.ModalClientClient, error) { var target string var creds credentials.TransportCredentials - if strings.HasPrefix(profile.ServerURL, "https://") { - target = strings.TrimPrefix(profile.ServerURL, "https://") + if strings.HasPrefix(serverURL, "https://") { + target = strings.TrimPrefix(serverURL, "https://") creds = credentials.NewTLS(&tls.Config{}) - } else if strings.HasPrefix(profile.ServerURL, "http://") { - target = strings.TrimPrefix(profile.ServerURL, "http://") + } else if strings.HasPrefix(serverURL, "http://") { + target = strings.TrimPrefix(serverURL, "http://") creds = insecure.NewCredentials() } else { - return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", profile.ServerURL) + return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", serverURL) } conn, err := grpc.NewClient( diff --git a/modal-go/function.go b/modal-go/function.go index 3cc9faf..4c4672f 100644 --- a/modal-go/function.go +++ b/modal-go/function.go @@ -15,7 +15,9 @@ import ( pickle "github.com/kisielk/og-rek" pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) @@ -29,21 +31,23 @@ func timeNow() float64 { // Function references a deployed Modal Function. type Function struct { - FunctionId string - MethodName *string // used for class methods - ctx context.Context + FunctionId string + MethodName *string // used for class methods + ctx context.Context + inputPlaneUrl *string } // FunctionLookup looks up an existing Function. func FunctionLookup(ctx context.Context, appName string, name string, options LookupOptions) (*Function, error) { ctx = clientContext(ctx) + var header, trailer metadata.MD resp, err := client.FunctionGet(ctx, pb.FunctionGetRequest_builder{ AppName: appName, ObjectTag: name, Namespace: pb.DeploymentNamespace_DEPLOYMENT_NAMESPACE_WORKSPACE, EnvironmentName: environmentName(options.Environment), - }.Build()) + }.Build(), grpc.Header(&header), grpc.Trailer(&trailer)) if status, ok := status.FromError(err); ok && status.Code() == codes.NotFound { return nil, NotFoundError{fmt.Sprintf("function '%s/%s' not found", appName, name)} @@ -52,7 +56,22 @@ func FunctionLookup(ctx context.Context, appName string, name string, options Lo return nil, err } - return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx}, nil + // Attach x-modal-auth-token to all future requests. + authTokenArray := header.Get("x-modal-auth-token") + if len(authTokenArray) == 0 { + authTokenArray = trailer.Get("x-modal-auth-token") + } + if len(authTokenArray) > 0 { + authToken := authTokenArray[0] + ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken) + } + + var inputPlaneUrl *string + if resp.GetHandleMetadata().HasInputPlaneUrl() { + url := resp.GetHandleMetadata().GetInputPlaneUrl() + inputPlaneUrl = &url + } + return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx, inputPlaneUrl: inputPlaneUrl}, nil } // Serialize function inputs to the Python pickle format. @@ -112,6 +131,49 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { }.Build() functionInputs = append(functionInputs, functionInputItem) + if f.inputPlaneUrl != nil { + return f.remoteInputPlane(functionInputs) + } + + return f.remoteControlPlane(functionInputs) +} + +func (f *Function) remoteInputPlane(functionInputs []*pb.FunctionPutInputsItem) (any, error) { + if f.inputPlaneUrl == nil { + return nil, fmt.Errorf("input plane URL is not set") + } + + client, err := getOrCreateClient(*f.inputPlaneUrl) + if err != nil { + return nil, fmt.Errorf("failed to get client: %w", err) + } + + attemptStartResponse, err := client.AttemptStart(f.ctx, pb.AttemptStartRequest_builder{ + FunctionId: f.FunctionId, + Input: functionInputs[0], + }.Build()) + if err != nil { + return nil, fmt.Errorf("AttemptStart error: %v", err) + } + + for { + response, err := client.AttemptAwait(f.ctx, pb.AttemptAwaitRequest_builder{ + AttemptToken: attemptStartResponse.GetAttemptToken(), + RequestedAt: timeNow(), + TimeoutSecs: 55, + }.Build()) + if err != nil { + return nil, fmt.Errorf("AttemptAwait failed: %v", err) + } + + output := response.GetOutput() + if output != nil { + return processResult(f.ctx, output.GetResult(), output.GetDataFormat()) + } + } +} + +func (f *Function) remoteControlPlane(functionInputs []*pb.FunctionPutInputsItem) (any, error) { functionMapResponse, err := client.FunctionMap(f.ctx, pb.FunctionMapRequest_builder{ FunctionId: f.FunctionId, FunctionCallType: pb.FunctionCallType_FUNCTION_CALL_TYPE_UNARY, @@ -135,8 +197,6 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { return nil, fmt.Errorf("FunctionGetOutputs failed: %v", err) } - // Output serialization may fail if any of the output items can't be deserialized - // into a supported Go type. Users are expected to serialize outputs correctly. outputs := response.GetOutputs() if len(outputs) > 0 { return processResult(f.ctx, outputs[0].GetResult(), outputs[0].GetDataFormat()) @@ -145,6 +205,9 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { } // processResult processes the result from an invocation. +// +// Note that output serialization may fail if any of the output items can't be deserialized +// into a supported Go type. Users are expected to serialize outputs correctly. func processResult(ctx context.Context, result *pb.GenericResult, dataFormat pb.DataFormat) (any, error) { if result == nil { return nil, RemoteError{"Received null result from invocation"} diff --git a/modal-go/test/function_test.go b/modal-go/test/function_test.go index 87b8e32..d5a5447 100644 --- a/modal-go/test/function_test.go +++ b/modal-go/test/function_test.go @@ -46,6 +46,21 @@ func TestFunctionCallLargeInput(t *testing.T) { g.Expect(result).Should(gomega.Equal(int64(len))) } +func TestFunctionCallInputPlane(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + function, err := modal.FunctionLookup( + context.Background(), + "libmodal-test-support", "input_plane", modal.LookupOptions{}, + ) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + result, err := function.Remote([]any{"hello"}, nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(result).Should(gomega.Equal("output: hello")) +} + func TestFunctionNotFound(t *testing.T) { t.Parallel() g := gomega.NewWithT(t) diff --git a/test-support/libmodal_test_support.py b/test-support/libmodal_test_support.py index c551cd4..14a63a7 100644 --- a/test-support/libmodal_test_support.py +++ b/test-support/libmodal_test_support.py @@ -14,6 +14,11 @@ def bytelength(buf: bytes) -> int: return len(buf) +@app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"}) +def input_plane(s: str) -> str: + return "output: " + s + + @app.cls(min_containers=1) class EchoCls: @modal.method() From 49e362aedafdaeed94e84e67da013e37aadd4bf2 Mon Sep 17 00:00:00 2001 From: Nathan Wang Date: Sun, 18 May 2025 00:22:34 -0400 Subject: [PATCH 2/3] Add JS support for input plane --- modal-js/src/client.ts | 72 ++++++++++++++++++++++++++++------ modal-js/src/function.ts | 67 ++++++++++++++++++++++++------- modal-js/test/function.test.ts | 9 +++++ 3 files changed, 122 insertions(+), 26 deletions(-) diff --git a/modal-js/src/client.ts b/modal-js/src/client.ts index 9acf72b..58293ce 100644 --- a/modal-js/src/client.ts +++ b/modal-js/src/client.ts @@ -13,6 +13,8 @@ import { import { ClientType, ModalClientDefinition } from "../proto/modal_proto/api"; import { profile, Profile } from "./config"; +let modalAuthToken: string | undefined; + /** gRPC client middleware to add auth token to request. */ function authMiddleware(profile: Profile): ClientMiddleware { return async function* authMiddleware( @@ -27,6 +29,31 @@ function authMiddleware(profile: Profile): ClientMiddleware { options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version options.metadata.set("x-modal-token-id", profile.tokenId); options.metadata.set("x-modal-token-secret", profile.tokenSecret); + if (modalAuthToken) { + options.metadata.set("x-modal-auth-token", modalAuthToken); + } + + const prevOnHeader = options.onHeader; + options.onHeader = (header) => { + const token = header.get("x-modal-auth-token"); + if (token) { + modalAuthToken = token; + } + if (prevOnHeader) { + prevOnHeader(header); + } + }; + const prevOnTrailer = options.onTrailer; + options.onTrailer = (trailer) => { + const token = trailer.get("x-modal-auth-token"); + if (token) { + modalAuthToken = token; + } + if (prevOnTrailer) { + prevOnTrailer(trailer); + } + }; + return yield* call.next(call.request, options); }; } @@ -189,15 +216,36 @@ const retryMiddleware: ClientMiddleware = } }; -// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py -const channel = createChannel(profile.serverUrl, undefined, { - "grpc.max_receive_message_length": 100 * 1024 * 1024, - "grpc.max_send_message_length": 100 * 1024 * 1024, - "grpc-node.flow_control_window": 64 * 1024 * 1024, -}); - -export const client = createClientFactory() - .use(authMiddleware(profile)) - .use(retryMiddleware) - .use(timeoutMiddleware) - .create(ModalClientDefinition, channel); +/** + * Map of server URL => client. + * The us-east client talks to the control plane; all other clients talk to input planes. + */ +const clients: Record> = {}; + +/** Returns a client for the given server URL, creating it if it doesn't exist. */ +export const getOrCreateClient = (serverURL: string): ReturnType => { + if (serverURL in clients) { + return clients[serverURL]; + } + + clients[serverURL] = createClient(serverURL); + return clients[serverURL]; +}; + +const createClient = (serverURL: string) => { + // Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py + const channel = createChannel(serverURL, undefined, { + "grpc.max_receive_message_length": 100 * 1024 * 1024, + "grpc.max_send_message_length": 100 * 1024 * 1024, + "grpc-node.flow_control_window": 64 * 1024 * 1024, + }); + + return createClientFactory() + .use(authMiddleware(profile)) + .use(retryMiddleware) + .use(timeoutMiddleware) + .create(ModalClientDefinition, channel); +} + +/** The default Modal client that talks to the control plane. */ +export const client = getOrCreateClient(profile.serverUrl); diff --git a/modal-js/src/function.ts b/modal-js/src/function.ts index 6a844da..94d7151 100644 --- a/modal-js/src/function.ts +++ b/modal-js/src/function.ts @@ -7,12 +7,13 @@ import { DeploymentNamespace, FunctionCallInvocationType, FunctionCallType, + FunctionPutInputsItem, GeneratorDone, GenericResult, GenericResult_GenericStatus, } from "../proto/modal_proto/api"; import { LookupOptions } from "./app"; -import { client } from "./client"; +import { client, getOrCreateClient } from "./client"; import { environmentName } from "./config"; import { InternalFailure, @@ -34,10 +35,12 @@ function timeNow() { export class Function_ { readonly functionId: string; readonly methodName: string | undefined; + private readonly inputPlaneUrl: string | undefined; - constructor(functionId: string, methodName?: string) { + constructor(functionId: string, methodName?: string, inputPlaneUrl?: string) { this.functionId = functionId; this.methodName = methodName; + this.inputPlaneUrl = inputPlaneUrl; } static async lookup( @@ -52,7 +55,7 @@ export class Function_ { namespace: DeploymentNamespace.DEPLOYMENT_NAMESPACE_WORKSPACE, environmentName: environmentName(options.environment), }); - return new Function_(resp.functionId); + return new Function_(resp.functionId, undefined, resp.handleMetadata?.inputPlaneUrl); } catch (err) { if (err instanceof ClientError && err.code === Status.NOT_FOUND) throw new NotFoundError(`Function '${appName}/${name}' not found`); @@ -73,22 +76,58 @@ export class Function_ { } // Single input sync invocation + const functionInputs = [ + { + idx: 0, + input: { + args: argsBlobId ? undefined : payload, + argsBlobId, + dataFormat: DataFormat.DATA_FORMAT_PICKLE, + methodName: this.methodName, + finalInput: false, // This field isn't specified in the Python client, so it defaults to false. + }, + }, + ]; + + if (this.inputPlaneUrl !== undefined) { + return this.remoteInputPlane(functionInputs); + } + return this.remoteControlPlane(functionInputs); + + } + + private async remoteInputPlane(functionInputs: FunctionPutInputsItem[]): Promise { + if (!this.inputPlaneUrl) { + throw new Error("Input plane URL is not set"); + } + const client = getOrCreateClient(this.inputPlaneUrl); + + const attemptStartResponse = await client.attemptStart({ + functionId: this.functionId, + input: functionInputs[0], + }); + + while (true) { + const response = await client.attemptAwait({ + attemptToken: attemptStartResponse.attemptToken, + requestedAt: timeNow(), + timeoutSecs: 55, + }); + + const output = response.output; + if (output) { + return await processResult(output.result, output.dataFormat); + } + } + } + + private async remoteControlPlane(functionInputs: FunctionPutInputsItem[]): Promise { const functionMapResponse = await client.functionMap({ functionId: this.functionId, functionCallType: FunctionCallType.FUNCTION_CALL_TYPE_UNARY, functionCallInvocationType: FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, - pipelinedInputs: [ - { - idx: 0, - input: { - args: argsBlobId ? undefined : payload, - argsBlobId, - dataFormat: DataFormat.DATA_FORMAT_PICKLE, - methodName: this.methodName, - }, - }, - ], + pipelinedInputs: functionInputs, }); while (true) { diff --git a/modal-js/test/function.test.ts b/modal-js/test/function.test.ts index 80eb3ce..0b25c69 100644 --- a/modal-js/test/function.test.ts +++ b/modal-js/test/function.test.ts @@ -27,6 +27,15 @@ test("FunctionCallLargeInput", async () => { expect(result).toBe(len); }); +test("FunctionCallInputPlane", async () => { + const function_ = await Function_.lookup( + "libmodal-test-support", + "input_plane", + ); + const result = await function_.remote(["hello"]); + expect(result).toBe("output: hello"); +}); + test("FunctionNotFound", async () => { const promise = Function_.lookup( "libmodal-test-support", From bb7b679de738cb915b104c415cbf784f487b5660 Mon Sep 17 00:00:00 2001 From: Nathan Wang Date: Mon, 19 May 2025 11:57:21 -0400 Subject: [PATCH 3/3] Do not deploy input plane function yet --- test-support/libmodal_test_support.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test-support/libmodal_test_support.py b/test-support/libmodal_test_support.py index 14a63a7..e65a6cf 100644 --- a/test-support/libmodal_test_support.py +++ b/test-support/libmodal_test_support.py @@ -14,9 +14,10 @@ def bytelength(buf: bytes) -> int: return len(buf) -@app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"}) -def input_plane(s: str) -> str: - return "output: " + s +# TODO(nathan): re-enable once input plane is enabled in prod +# @app.function(min_containers=1, experimental_options={"input_plane_region": "us-west"}) +# def input_plane(s: str) -> str: +# return "output: " + s @app.cls(min_containers=1)