Skip to content

Add support for calling input plane functions #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions modal-go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,13 @@ var defaultConfig config
// defaultProfile is resolved at package init from MODAL_PROFILE, ~/.modal.toml, etc.
var defaultProfile Profile

// client is the default Modal client that talks to the control plane.
var client pb.ModalClientClient

// clients is a map of server URL => client.
// The us-east client talks to the control plane; all other clients talk to input planes.
var clients = map[string]pb.ModalClientClient{}

func init() {
var err error
defaultConfig, _ = readConfigFile()
Expand All @@ -78,25 +83,39 @@ func init() {
panic(err) // fail fast – credentials are required to proceed
}

_, client, err = newClient(defaultProfile)
client, err = getOrCreateClient(defaultProfile.ServerURL)
if err != nil {
panic(err)
}
}

// newClient dials api.modal.com with auth/timeout/retry interceptors installed.
// getOrCreateClient returns a client for the given server URL, creating it if it doesn't exist.
func getOrCreateClient(serverURL string) (pb.ModalClientClient, error) {
if client, ok := clients[serverURL]; ok {
return client, nil
}

_, client, err := createClient(serverURL)
if err != nil {
return nil, err
}
clients[serverURL] = client
return client, nil
}

// createClient dials the given server URL with auth/timeout/retry interceptors installed.
// It returns (conn, stub). Close the conn when done.
func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error) {
func createClient(serverURL string) (*grpc.ClientConn, pb.ModalClientClient, error) {
var target string
var creds credentials.TransportCredentials
if strings.HasPrefix(profile.ServerURL, "https://") {
target = strings.TrimPrefix(profile.ServerURL, "https://")
if strings.HasPrefix(serverURL, "https://") {
target = strings.TrimPrefix(serverURL, "https://")
creds = credentials.NewTLS(&tls.Config{})
} else if strings.HasPrefix(profile.ServerURL, "http://") {
target = strings.TrimPrefix(profile.ServerURL, "http://")
} else if strings.HasPrefix(serverURL, "http://") {
target = strings.TrimPrefix(serverURL, "http://")
creds = insecure.NewCredentials()
} else {
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", profile.ServerURL)
return nil, nil, status.Errorf(codes.InvalidArgument, "invalid server URL: %s", serverURL)
}

conn, err := grpc.NewClient(
Expand Down
77 changes: 70 additions & 7 deletions modal-go/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (

pickle "github.com/kisielk/og-rek"
pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
Expand All @@ -29,21 +31,23 @@ func timeNow() float64 {

// Function references a deployed Modal Function.
type Function struct {
FunctionId string
MethodName *string // used for class methods
ctx context.Context
FunctionId string
MethodName *string // used for class methods
ctx context.Context
inputPlaneUrl *string
}

// FunctionLookup looks up an existing Function.
func FunctionLookup(ctx context.Context, appName string, name string, options LookupOptions) (*Function, error) {
ctx = clientContext(ctx)

var header, trailer metadata.MD
resp, err := client.FunctionGet(ctx, pb.FunctionGetRequest_builder{
AppName: appName,
ObjectTag: name,
Namespace: pb.DeploymentNamespace_DEPLOYMENT_NAMESPACE_WORKSPACE,
EnvironmentName: environmentName(options.Environment),
}.Build())
}.Build(), grpc.Header(&header), grpc.Trailer(&trailer))

if status, ok := status.FromError(err); ok && status.Code() == codes.NotFound {
return nil, NotFoundError{fmt.Sprintf("function '%s/%s' not found", appName, name)}
Expand All @@ -52,7 +56,22 @@ func FunctionLookup(ctx context.Context, appName string, name string, options Lo
return nil, err
}

return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx}, nil
// Attach x-modal-auth-token to all future requests.
authTokenArray := header.Get("x-modal-auth-token")
if len(authTokenArray) == 0 {
authTokenArray = trailer.Get("x-modal-auth-token")
}
if len(authTokenArray) > 0 {
authToken := authTokenArray[0]
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
}

var inputPlaneUrl *string
if resp.GetHandleMetadata().HasInputPlaneUrl() {
url := resp.GetHandleMetadata().GetInputPlaneUrl()
inputPlaneUrl = &url
}
return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx, inputPlaneUrl: inputPlaneUrl}, nil
}

// Serialize function inputs to the Python pickle format.
Expand Down Expand Up @@ -112,6 +131,49 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
}.Build()
functionInputs = append(functionInputs, functionInputItem)

if f.inputPlaneUrl != nil {
return f.remoteInputPlane(functionInputs)
}

return f.remoteControlPlane(functionInputs)
}

func (f *Function) remoteInputPlane(functionInputs []*pb.FunctionPutInputsItem) (any, error) {
if f.inputPlaneUrl == nil {
return nil, fmt.Errorf("input plane URL is not set")
}

client, err := getOrCreateClient(*f.inputPlaneUrl)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}

attemptStartResponse, err := client.AttemptStart(f.ctx, pb.AttemptStartRequest_builder{
FunctionId: f.FunctionId,
Input: functionInputs[0],
}.Build())
if err != nil {
return nil, fmt.Errorf("AttemptStart error: %v", err)
}

for {
response, err := client.AttemptAwait(f.ctx, pb.AttemptAwaitRequest_builder{
AttemptToken: attemptStartResponse.GetAttemptToken(),
RequestedAt: timeNow(),
TimeoutSecs: 55,
}.Build())
if err != nil {
return nil, fmt.Errorf("AttemptAwait failed: %v", err)
}

output := response.GetOutput()
if output != nil {
return processResult(f.ctx, output.GetResult(), output.GetDataFormat())
}
}
}

func (f *Function) remoteControlPlane(functionInputs []*pb.FunctionPutInputsItem) (any, error) {
functionMapResponse, err := client.FunctionMap(f.ctx, pb.FunctionMapRequest_builder{
FunctionId: f.FunctionId,
FunctionCallType: pb.FunctionCallType_FUNCTION_CALL_TYPE_UNARY,
Expand All @@ -135,8 +197,6 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
return nil, fmt.Errorf("FunctionGetOutputs failed: %v", err)
}

// Output serialization may fail if any of the output items can't be deserialized
// into a supported Go type. Users are expected to serialize outputs correctly.
outputs := response.GetOutputs()
if len(outputs) > 0 {
return processResult(f.ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
Expand All @@ -145,6 +205,9 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
}

// processResult processes the result from an invocation.
//
// Note that output serialization may fail if any of the output items can't be deserialized
// into a supported Go type. Users are expected to serialize outputs correctly.
func processResult(ctx context.Context, result *pb.GenericResult, dataFormat pb.DataFormat) (any, error) {
if result == nil {
return nil, RemoteError{"Received null result from invocation"}
Expand Down
15 changes: 15 additions & 0 deletions modal-go/test/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ func TestFunctionCallLargeInput(t *testing.T) {
g.Expect(result).Should(gomega.Equal(int64(len)))
}

func TestFunctionCallInputPlane(t *testing.T) {
t.Parallel()
g := gomega.NewWithT(t)

function, err := modal.FunctionLookup(
context.Background(),
"libmodal-test-support", "input_plane", modal.LookupOptions{},
)
g.Expect(err).ShouldNot(gomega.HaveOccurred())

result, err := function.Remote([]any{"hello"}, nil)
g.Expect(err).ShouldNot(gomega.HaveOccurred())
g.Expect(result).Should(gomega.Equal("output: hello"))
}

func TestFunctionNotFound(t *testing.T) {
t.Parallel()
g := gomega.NewWithT(t)
Expand Down
72 changes: 60 additions & 12 deletions modal-js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import {
import { ClientType, ModalClientDefinition } from "../proto/modal_proto/api";
import { profile, Profile } from "./config";

let modalAuthToken: string | undefined;
Copy link
Author

@thecodingwizard thecodingwizard May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the way we propagate auth tokens is a little different in Go vs. Typescript.

In Typescript (and Python), we propagate auth tokens via auth middleware / a grpc interceptor in client.ts.

In Go, we propagate auth tokens via context in function.go.

I'm open to changing Go's implementation to also propagate auth tokens via grpc interceptors in client.go for consistency if we think that's better, but it might be a bit trickier since Go is multithreaded? Regardless, we can always change our minds later.


/** gRPC client middleware to add auth token to request. */
function authMiddleware(profile: Profile): ClientMiddleware {
return async function* authMiddleware<Request, Response>(
Expand All @@ -27,6 +29,31 @@ function authMiddleware(profile: Profile): ClientMiddleware {
options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version
options.metadata.set("x-modal-token-id", profile.tokenId);
options.metadata.set("x-modal-token-secret", profile.tokenSecret);
if (modalAuthToken) {
options.metadata.set("x-modal-auth-token", modalAuthToken);
}

const prevOnHeader = options.onHeader;
options.onHeader = (header) => {
const token = header.get("x-modal-auth-token");
if (token) {
modalAuthToken = token;
}
if (prevOnHeader) {
prevOnHeader(header);
}
};
const prevOnTrailer = options.onTrailer;
options.onTrailer = (trailer) => {
const token = trailer.get("x-modal-auth-token");
if (token) {
modalAuthToken = token;
}
if (prevOnTrailer) {
prevOnTrailer(trailer);
}
};

return yield* call.next(call.request, options);
};
}
Expand Down Expand Up @@ -189,15 +216,36 @@ const retryMiddleware: ClientMiddleware<RetryOptions> =
}
};

// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
const channel = createChannel(profile.serverUrl, undefined, {
"grpc.max_receive_message_length": 100 * 1024 * 1024,
"grpc.max_send_message_length": 100 * 1024 * 1024,
"grpc-node.flow_control_window": 64 * 1024 * 1024,
});

export const client = createClientFactory()
.use(authMiddleware(profile))
.use(retryMiddleware)
.use(timeoutMiddleware)
.create(ModalClientDefinition, channel);
/**
* Map of server URL => client.
* The us-east client talks to the control plane; all other clients talk to input planes.
*/
const clients: Record<string, ReturnType<typeof createClient>> = {};

/** Returns a client for the given server URL, creating it if it doesn't exist. */
export const getOrCreateClient = (serverURL: string): ReturnType<typeof createClient> => {
if (serverURL in clients) {
return clients[serverURL];
}

clients[serverURL] = createClient(serverURL);
return clients[serverURL];
};

const createClient = (serverURL: string) => {
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
const channel = createChannel(serverURL, undefined, {
"grpc.max_receive_message_length": 100 * 1024 * 1024,
"grpc.max_send_message_length": 100 * 1024 * 1024,
"grpc-node.flow_control_window": 64 * 1024 * 1024,
});

return createClientFactory()
.use(authMiddleware(profile))
.use(retryMiddleware)
.use(timeoutMiddleware)
.create(ModalClientDefinition, channel);
}

/** The default Modal client that talks to the control plane. */
export const client = getOrCreateClient(profile.serverUrl);
Loading
Loading