diff --git a/README.md b/README.md index 82840f9..7be0e2e 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ import "github.com/modal-labs/libmodal/modal-go" Examples: - [Call a deployed function](./modal-go/examples/function-call/main.go) +- [Spawn a deployed function](./modal-go/examples/function-spawn/main.go) - [Call a deployed cls](./modal-go/examples/cls-call/main.go) - [Create a sandbox](./modal-go/examples/sandbox/main.go) - [Execute sandbox commands](./modal-go/examples/sandbox-exec/main.go) diff --git a/modal-client b/modal-client index ed035f6..2bd3ea0 160000 --- a/modal-client +++ b/modal-client @@ -1 +1 @@ -Subproject commit ed035f67f92a8f4e1e2b82e9727b2f8477ed0248 +Subproject commit 2bd3ea0798bf98f63744ea700141aac2e9f7bed1 diff --git a/modal-go/cls.go b/modal-go/cls.go index 7106af2..a3108f5 100644 --- a/modal-go/cls.go +++ b/modal-go/cls.go @@ -195,7 +195,7 @@ func encodeParameter(paramSpec *pb.ClassParameterSpec, value any) (*pb.ClassPara paramValue.SetBytesValue(bytesValue) default: - return nil, fmt.Errorf("unsupported parameter type: %v", paramType) + return nil, fmt.Errorf("unsupported parameter type: %w", paramType) } return paramValue, nil diff --git a/modal-go/examples/cls-call/main.go b/modal-go/examples/cls-call/main.go index e147596..0b02a00 100644 --- a/modal-go/examples/cls-call/main.go +++ b/modal-go/examples/cls-call/main.go @@ -4,6 +4,7 @@ package main import ( "context" + "fmt" "log" "github.com/modal-labs/libmodal/modal-go" @@ -18,30 +19,30 @@ func main() { "libmodal-test-support", "EchoCls", modal.LookupOptions{}, ) if err != nil { - log.Fatalf("Failed to lookup Cls: %v", err) + fmt.Errorf("Failed to lookup Cls: %w", err) } instance, err := cls.Instance(nil) if err != nil { - log.Fatalf("Failed to create Cls instance: %v", err) + fmt.Errorf("Failed to create Cls instance: %w", err) } function, err := instance.Method("echo_string") if err != nil { - log.Fatalf("Failed to access Cls method: %v", err) + fmt.Errorf("Failed to access Cls method: %w", err) } // Call the Cls function with args. result, err := function.Remote([]any{"Hello world!"}, nil) if err != nil { - log.Fatalf("Failed to call Cls method: %v", err) + fmt.Errorf("Failed to call Cls method: %w", err) } - log.Printf("%v\n", result) + log.Println("Response:", result) // Call the Cls function with kwargs. result, err = function.Remote(nil, map[string]any{"s": "Hello world!"}) if err != nil { - log.Fatalf("Failed to call Cls method: %v", err) + fmt.Errorf("Failed to call Cls method: %w", err) } - log.Printf("%v\n", result) + log.Println("Response:", result) } diff --git a/modal-go/examples/function-call/main.go b/modal-go/examples/function-call/main.go index acf314e..59145de 100644 --- a/modal-go/examples/function-call/main.go +++ b/modal-go/examples/function-call/main.go @@ -15,18 +15,18 @@ func main() { echo, err := modal.FunctionLookup(ctx, "libmodal-test-support", "echo_string", modal.LookupOptions{}) if err != nil { - log.Fatalf("Failed to lookup function: %v", err) + fmt.Errorf("Failed to lookup function: %w", err) } ret, err := echo.Remote([]any{"Hello world!"}, nil) if err != nil { - log.Fatalf("Failed to call function: %v", err) + fmt.Errorf("Failed to call function: %w", err) } - fmt.Printf("%s\n", ret) + log.Println("Response:", ret) ret, err = echo.Remote(nil, map[string]any{"s": "Hello world!"}) if err != nil { - log.Fatalf("Failed to call function with kwargs: %v", err) + fmt.Errorf("Failed to call function with kwargs: %w", err) } - log.Printf("%s\n", ret) + log.Println("Response:", ret) } diff --git a/modal-go/examples/function-spawn/main.go b/modal-go/examples/function-spawn/main.go new file mode 100644 index 0000000..9c79f86 --- /dev/null +++ b/modal-go/examples/function-spawn/main.go @@ -0,0 +1,32 @@ +// This example spawns a function defined in `libmodal_test_support.py`, and +// later gets its outputs. + +package main + +import ( + "context" + "fmt" + "log" + + "github.com/modal-labs/libmodal/modal-go" +) + +func main() { + ctx := context.Background() + + echo, err := modal.FunctionLookup(ctx, "libmodal-test-support", "echo_string", modal.LookupOptions{}) + if err != nil { + fmt.Errorf("Failed to lookup function: %w", err) + } + + fc, err := echo.Spawn(nil, map[string]any{"s": "Hello world!"}) + if err != nil { + fmt.Errorf("Failed to spawn function: %w", err) + } + + ret, err := fc.Get(modal.FunctionCallGetOptions{}) + if err != nil { + fmt.Errorf("Failed to get function results: %w", err) + } + log.Println("Response:", ret) +} diff --git a/modal-go/examples/sandbox-exec/main.go b/modal-go/examples/sandbox-exec/main.go index e46e42b..753c3e2 100644 --- a/modal-go/examples/sandbox-exec/main.go +++ b/modal-go/examples/sandbox-exec/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "io" "log" @@ -13,17 +14,17 @@ func main() { app, err := modal.AppLookup(ctx, "libmodal-example", modal.LookupOptions{CreateIfMissing: true}) if err != nil { - log.Fatalf("Failed to lookup or create app: %v", err) + fmt.Errorf("Failed to lookup or create app: %w", err) } image, err := app.ImageFromRegistry("python:3.13-slim") if err != nil { - log.Fatalf("Failed to create image from registry: %v", err) + fmt.Errorf("Failed to create image from registry: %w", err) } sb, err := app.CreateSandbox(image, modal.SandboxOptions{}) if err != nil { - log.Fatalf("Failed to create sandbox: %v", err) + fmt.Errorf("Failed to create sandbox: %w", err) } log.Println("Started sandbox:", sb.SandboxId) defer sb.Terminate() @@ -47,22 +48,22 @@ for i in range(50000): }, ) if err != nil { - log.Fatalf("Failed to execute command in sandbox: %v", err) + fmt.Errorf("Failed to execute command in sandbox: %w", err) } contentStdout, err := io.ReadAll(p.Stdout) if err != nil { - log.Fatalf("Failed to read stdout: %v", err) + fmt.Errorf("Failed to read stdout: %w", err) } contentStderr, err := io.ReadAll(p.Stderr) if err != nil { - log.Fatalf("Failed to read stderr: %v", err) + fmt.Errorf("Failed to read stderr: %w", err) } log.Printf("Got %d bytes stdout and %d bytes stderr\n", len(contentStdout), len(contentStderr)) returnCode, err := p.Wait() if err != nil { - log.Fatalf("Failed to wait for process completion: %v", err) + fmt.Errorf("Failed to wait for process completion: %w", err) } log.Println("Return code:", returnCode) diff --git a/modal-go/examples/sandbox/main.go b/modal-go/examples/sandbox/main.go index 7877038..8420b0a 100644 --- a/modal-go/examples/sandbox/main.go +++ b/modal-go/examples/sandbox/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "io" "log" @@ -13,34 +14,34 @@ func main() { app, err := modal.AppLookup(ctx, "libmodal-example", modal.LookupOptions{CreateIfMissing: true}) if err != nil { - log.Fatalf("Failed to lookup or create app: %v", err) + fmt.Errorf("Failed to lookup or create app: %w", err) } image, err := app.ImageFromRegistry("alpine:3.21") if err != nil { - log.Fatalf("Failed to create image from registry: %v", err) + fmt.Errorf("Failed to create image from registry: %w", err) } sb, err := app.CreateSandbox(image, modal.SandboxOptions{ Command: []string{"cat"}, }) if err != nil { - log.Fatalf("Failed to create sandbox: %v", err) + fmt.Errorf("Failed to create sandbox: %w", err) } log.Printf("sandbox: %s\n", sb.SandboxId) _, err = sb.Stdin.Write([]byte("this is input that should be mirrored by cat")) if err != nil { - log.Fatalf("Failed to write to sandbox stdin: %v", err) + fmt.Errorf("Failed to write to sandbox stdin: %w", err) } err = sb.Stdin.Close() if err != nil { - log.Fatalf("Failed to close sandbox stdin: %v", err) + fmt.Errorf("Failed to close sandbox stdin: %w", err) } output, err := io.ReadAll(sb.Stdout) if err != nil { - log.Fatalf("Failed to read from sandbox stdout: %v", err) + fmt.Errorf("Failed to read from sandbox stdout: %w", err) } log.Printf("output: %s\n", string(output)) diff --git a/modal-go/function.go b/modal-go/function.go index 3cc9faf..5ce0e4e 100644 --- a/modal-go/function.go +++ b/modal-go/function.go @@ -21,9 +21,12 @@ import ( ) // From: modal/_utils/blob_utils.py -const maxObjectSizeBytes = 2 * 1024 * 1024 // 2 MiB +const maxObjectSizeBytes int = 2 * 1024 * 1024 // 2 MiB -func timeNow() float64 { +// From: modal-client/modal/_utils/function_utils.py +const OutputsTimeout time.Duration = time.Second * 55 + +func timeNowSeconds() float64 { return float64(time.Now().UnixNano()) / 1e9 } @@ -81,8 +84,8 @@ func pickleDeserialize(buffer []byte) (any, error) { return result, nil } -// Execute a single input into a remote Function. -func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { +// Serializes inputs, make a function call and return its ID +func (f *Function) execFunctionCall(args []any, kwargs map[string]any, invocationType pb.FunctionCallInvocationType) (*string, error) { payload, err := pickleSerialize(args, kwargs) if err != nil { return nil, err @@ -115,33 +118,83 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { functionMapResponse, err := client.FunctionMap(f.ctx, pb.FunctionMapRequest_builder{ FunctionId: f.FunctionId, FunctionCallType: pb.FunctionCallType_FUNCTION_CALL_TYPE_UNARY, - FunctionCallInvocationType: pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC, + FunctionCallInvocationType: invocationType, PipelinedInputs: functionInputs, }.Build()) if err != nil { - return nil, fmt.Errorf("FunctionMap error: %v", err) + return nil, fmt.Errorf("FunctionMap error: %w", err) + } + + functionCallId := functionMapResponse.GetFunctionCallId() + return &functionCallId, nil +} + +// Remote executes a single input on a remote Function. +func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) { + invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC + functionCallId, err := f.execFunctionCall(args, kwargs, invocationType) + if err != nil { + return nil, err } + return pollFunctionOutput(f.ctx, *functionCallId, OutputsTimeout) +} + +// Poll for ouputs for a given FunctionCall ID +func pollFunctionOutput(ctx context.Context, functionCallId string, timeout time.Duration) (any, error) { + startTime := time.Now() + + // Calculate initial backend timeout + pollTimeout := minTimeout(OutputsTimeout, timeout) for { - response, err := client.FunctionGetOutputs(f.ctx, pb.FunctionGetOutputsRequest_builder{ - FunctionCallId: functionMapResponse.GetFunctionCallId(), + // Context might have been cancelled. Check before next poll operation. + if err := ctx.Err(); err != nil { + return nil, err + } + + response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{ + FunctionCallId: functionCallId, MaxValues: 1, - Timeout: 55, + Timeout: float32(pollTimeout.Seconds()), LastEntryId: "0-0", ClearOnSuccess: true, - RequestedAt: timeNow(), + RequestedAt: timeNowSeconds(), }.Build()) if err != nil { - return nil, fmt.Errorf("FunctionGetOutputs failed: %v", err) + return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err) } // Output serialization may fail if any of the output items can't be deserialized // into a supported Go type. Users are expected to serialize outputs correctly. outputs := response.GetOutputs() if len(outputs) > 0 { - return processResult(f.ctx, outputs[0].GetResult(), outputs[0].GetDataFormat()) + return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat()) } + + remainingTime := timeout - time.Since(startTime) + if remainingTime <= 0 { + m := fmt.Sprintf("Timeout exceeded: %.1fs", timeout.Seconds()) + return nil, FunctionTimeoutError{m} + } + + // Add a small delay before next poll to avoid overloading backend. + time.Sleep(50 * time.Millisecond) + pollTimeout = minTimeout(OutputsTimeout, remainingTime) + } +} + +// Spawn starts running a single input on a remote function. +func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) { + invocationType := pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_ASYNC + functionCallId, err := f.execFunctionCall(args, kwargs, invocationType) + if err != nil { + return nil, err + } + functionCall := FunctionCall{ + FunctionCallId: *functionCallId, + ctx: f.ctx, } + return &functionCall, nil } // processResult processes the result from an invocation. diff --git a/modal-go/function_call.go b/modal-go/function_call.go new file mode 100644 index 0000000..543f3f2 --- /dev/null +++ b/modal-go/function_call.go @@ -0,0 +1,71 @@ +package modal + +import ( + "context" + "fmt" + "time" + + pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" +) + +// FunctionCall references a Modal Function Call. Function Calls are +// Function invocations with a given input. They can be consumed +// asynchronously (see Get()) or cancelled (see Cancel()). +type FunctionCall struct { + FunctionCallId string + ctx context.Context +} + +// FunctionCallFromId looks up a FunctionCall. +func FunctionCallFromId(ctx context.Context, functionCallId string) (*FunctionCall, error) { + ctx = clientContext(ctx) + functionCall := FunctionCall{ + FunctionCallId: functionCallId, + ctx: ctx, + } + return &functionCall, nil +} + +// FunctionCallGetOptions are options for getting outputs from Function Calls. +type FunctionCallGetOptions struct { + Timeout time.Duration +} + +// Get waits for the output of a FunctionCall. +// If timeout > 0, the operation will be cancelled after the specified duration. +func (fc *FunctionCall) Get(options FunctionCallGetOptions) (any, error) { + ctx := fc.ctx + + // Use default if not specified. + timeoutSeconds := options.Timeout + if options.Timeout == 0 { + timeoutSeconds = OutputsTimeout + } + return pollFunctionOutput(ctx, fc.FunctionCallId, timeoutSeconds) +} + +// Helper function to find the minimum of two float32 values +func minTimeout(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b +} + +// FunctionCallCancelOptions are options for cancelling Function Calls. +type FunctionCallCancelOptions struct { + TerminateContainers bool +} + +// Cancel cancels a FunctionCall. +func (fc *FunctionCall) Cancel(options FunctionCallCancelOptions) error { + _, err := client.FunctionCallCancel(fc.ctx, pb.FunctionCallCancelRequest_builder{ + FunctionCallId: fc.FunctionCallId, + TerminateContainers: options.TerminateContainers, + }.Build()) + if err != nil { + return fmt.Errorf("FunctionCallCancel failed: %w", err) + } + + return nil +} diff --git a/modal-go/test/function_call_test.go b/modal-go/test/function_call_test.go new file mode 100644 index 0000000..5eed5f7 --- /dev/null +++ b/modal-go/test/function_call_test.go @@ -0,0 +1,65 @@ +package test + +import ( + "context" + "testing" + "time" + + "github.com/modal-labs/libmodal/modal-go" + "github.com/onsi/gomega" +) + +func TestFunctionSpawn(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + function, err := modal.FunctionLookup( + context.Background(), + "libmodal-test-support", "echo_string", modal.LookupOptions{}, + ) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Call function using spawn. + functionCall, err := function.Spawn(nil, map[string]any{"s": "hello"}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Get outputs. + result, err := functionCall.Get(modal.FunctionCallGetOptions{}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(result).Should(gomega.Equal("output: hello")) + + // Create FunctionCall instance and get output again. + functionCall, err = modal.FunctionCallFromId(context.Background(), functionCall.FunctionCallId) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + result, err = functionCall.Get(modal.FunctionCallGetOptions{}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(result).Should(gomega.Equal("output: hello")) + + // Looking function that takes a long time to complete. + functionSleep, err := modal.FunctionLookup( + context.Background(), + "libmodal-test-support", "sleep", modal.LookupOptions{}, + ) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + functionCall, err = functionSleep.Spawn(nil, map[string]any{"t": 5}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Cancel function call. + err = functionCall.Cancel(modal.FunctionCallCancelOptions{}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Attempting to get outputs for a cancelled function call + // is expected to return an error. + _, err = functionCall.Get(modal.FunctionCallGetOptions{}) + g.Expect(err).Should(gomega.HaveOccurred()) + + // Spawn function with long running input. + functionCall, err = functionSleep.Spawn(nil, map[string]any{"t": 5}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Get is now expected to timeout. + _, err = functionCall.Get(modal.FunctionCallGetOptions{Timeout: 1 * time.Second}) + g.Expect(err).Should(gomega.HaveOccurred()) +} diff --git a/modal-js/examples/function-spawn.ts b/modal-js/examples/function-spawn.ts new file mode 100644 index 0000000..64eedea --- /dev/null +++ b/modal-js/examples/function-spawn.ts @@ -0,0 +1,10 @@ +// This example calls a function defined in `libmodal_test_support.py`. + +import { Function_ } from "modal"; + +const echo = await Function_.lookup("libmodal-test-support", "echo_string"); + +// Spawn the function with kwargs. +const functionCall = await echo.spawn([], { s: "Hello world!" }); +const ret = await functionCall.get(); +console.log(ret); diff --git a/modal-js/package.json b/modal-js/package.json index 4249b85..ea5703b 100644 --- a/modal-js/package.json +++ b/modal-js/package.json @@ -26,7 +26,7 @@ "format": "prettier --write .", "format:check": "prettier --check .", "prepare": "scripts/gen-proto.sh", - "test": "vitest" + "test": "vitest --reporter=verbose" }, "dependencies": { "long": "^5.3.1", diff --git a/modal-js/src/function.ts b/modal-js/src/function.ts index 6a844da..30754db 100644 --- a/modal-js/src/function.ts +++ b/modal-js/src/function.ts @@ -13,6 +13,7 @@ import { } from "../proto/modal_proto/api"; import { LookupOptions } from "./app"; import { client } from "./client"; +import { FunctionCall } from "./function_call"; import { environmentName } from "./config"; import { InternalFailure, @@ -26,7 +27,10 @@ import { ClientError, Status } from "nice-grpc"; // From: modal/_utils/blob_utils.py const maxObjectSizeBytes = 2 * 1024 * 1024; // 2 MiB -function timeNow() { +// From: modal-client/modal/_utils/function_utils.py +export const outputsTimeout = 55 * 1000; + +function timeNowSeconds() { return Date.now() / 1e3; } @@ -65,6 +69,32 @@ export class Function_ { args: any[] = [], kwargs: Record = {}, ): Promise { + const functionCallId = await this.execFunctionCall( + args, + kwargs, + FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, + ); + return await pollFunctionOutput(functionCallId, outputsTimeout); + } + + // Spawn a single input into a remote function. + async spawn( + args: any[] = [], + kwargs: Record = {}, + ): Promise { + const functionCallId = await this.execFunctionCall( + args, + kwargs, + FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, + ); + return new FunctionCall(functionCallId); + } + + async execFunctionCall( + args: any[] = [], + kwargs: Record = {}, + invocationType: FunctionCallInvocationType = FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, + ): Promise { const payload = dumps([args, kwargs]); let argsBlobId: string | undefined = undefined; @@ -76,8 +106,7 @@ export class Function_ { const functionMapResponse = await client.functionMap({ functionId: this.functionId, functionCallType: FunctionCallType.FUNCTION_CALL_TYPE_UNARY, - functionCallInvocationType: - FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, + functionCallInvocationType: invocationType, pipelinedInputs: [ { idx: 0, @@ -91,20 +120,50 @@ export class Function_ { ], }); - while (true) { + return functionMapResponse.functionCallId; + } +} + +export async function pollFunctionOutput( + functionCallId: string, + timeout: number, // in milliseconds +): Promise { + const startTime = Date.now(); + + // Calculate initial backend timeout + let pollTimeout = Math.min(outputsTimeout, timeout); + + while (true) { + try { const response = await client.functionGetOutputs({ - functionCallId: functionMapResponse.functionCallId, + functionCallId: functionCallId, maxValues: 1, - timeout: 55, + timeout: pollTimeout / 1000, // Backend needs seconds lastEntryId: "0-0", clearOnSuccess: true, - requestedAt: timeNow(), + requestedAt: timeNowSeconds(), }); const outputs = response.outputs; if (outputs.length > 0) { return await processResult(outputs[0].result, outputs[0].dataFormat); } + + const remainingTime = timeout - (Date.now() - startTime); + if (remainingTime <= 0) { + const message = `Timeout exceeded: ${(timeout / 1000).toFixed(1)}s`; + throw new FunctionTimeoutError(message); + } + + // Add a small delay before next poll to avoid overloading backend + await new Promise((resolve) => setTimeout(resolve, 50)); + + pollTimeout = Math.min(outputsTimeout, remainingTime); + } catch (error) { + if (error instanceof FunctionTimeoutError) { + throw error; + } + throw new Error(`FunctionGetOutputs failed: ${error}`); } } } diff --git a/modal-js/src/function_call.ts b/modal-js/src/function_call.ts new file mode 100644 index 0000000..f04f1c8 --- /dev/null +++ b/modal-js/src/function_call.ts @@ -0,0 +1,45 @@ +// Manage existing Function Calls (look-ups, polling for output, cancellation). + +import { client } from "./client"; +import { pollFunctionOutput, outputsTimeout } from "./function"; + +export type FunctionCallGetOptions = { + timeout?: number; // in milliseconds +}; + +export type FunctionCallCancelOptions = { + terminateContainers?: boolean; +}; + +/** Represents a Modal FunctionCall, Function Calls are +Function invocations with a given input. They can be consumed +asynchronously (see get()) or cancelled (see cancel()). +*/ +export class FunctionCall { + readonly functionCallId: string; + + constructor(functionCallId: string) { + this.functionCallId = functionCallId; + } + + // Get output for a FunctionCall ID. + async get(options: FunctionCallGetOptions = {}): Promise { + const timeout = options.timeout || outputsTimeout; + return await pollFunctionOutput(this.functionCallId, timeout); + } + + // Cancel ongoing FunctionCall. + async cancel(options: FunctionCallCancelOptions = {}) { + await client.functionCallCancel({ + functionCallId: this.functionCallId, + terminateContainers: options.terminateContainers, + }); + } +} + +// functionCallFromId looks up a FunctionCall. +export async function functionCallFromId( + functionCallId: string, +): Promise { + return new FunctionCall(functionCallId); +} diff --git a/modal-js/test/function_call.test.ts b/modal-js/test/function_call.test.ts new file mode 100644 index 0000000..ec2063f --- /dev/null +++ b/modal-js/test/function_call.test.ts @@ -0,0 +1,35 @@ +import { Function_, FunctionTimeoutError } from "modal"; +import { expect, test } from "vitest"; + +test("FunctionSpawn", async () => { + const function_ = await Function_.lookup( + "libmodal-test-support", + "echo_string", + ); + + // Spawn function with kwargs. + var functionCall = await function_.spawn([], { s: "hello" }); + expect(functionCall.functionCallId).toBeDefined(); + + // Get results after spawn. + var resultKwargs = await functionCall.get(); + expect(resultKwargs).toBe("output: hello"); + + // Try the same again; same results should still be available. + resultKwargs = await functionCall.get(); + expect(resultKwargs).toBe("output: hello"); + + // Lookup function that takes a long time to complete. + const functionSleep_ = await Function_.lookup( + "libmodal-test-support", + "sleep", + ); + + // Spawn with long running input. + functionCall = await functionSleep_.spawn([], { t: 5 }); + expect(functionCall.functionCallId).toBeDefined(); + + // Getting outputs with timeout raises error. + const promise = functionCall.get({ timeout: 1000 }); // 1000ms + await expect(promise).rejects.toThrowError(FunctionTimeoutError); +}); diff --git a/test-support/libmodal_test_support.py b/test-support/libmodal_test_support.py index c551cd4..6dd64e4 100644 --- a/test-support/libmodal_test_support.py +++ b/test-support/libmodal_test_support.py @@ -1,4 +1,5 @@ import modal +import time app = modal.App("libmodal-test-support") @@ -8,6 +9,9 @@ def echo_string(s: str) -> str: return "output: " + s +@app.function(min_containers=1) +def sleep(t: int) -> None: + time.sleep(t) @app.function(min_containers=1) def bytelength(buf: bytes) -> int: