Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions packages/libs/restate-sdk/src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ export interface Request {
* than cleanup any external resources that might be shared across attempts (e.g. database connections).
*/
readonly attemptCompletedSignal: AbortSignal;

/**
* Signal that is aborted when the invocation is cancelled by the Restate runtime.
* Unlike {@link attemptCompletedSignal}, this signal specifically indicates cancellation
* and can be used to abort in-flight operations (e.g., fetch calls, database queries).
*
* The signal's reason will be a {@link CancelledError}.
*/
readonly cancellationSignal: AbortSignal;
}

/* eslint-disable @typescript-eslint/no-explicit-any */
Expand Down
23 changes: 22 additions & 1 deletion packages/libs/restate-sdk/src/context_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
WasmHeader,
} from "./endpoint/handlers/vm/sdk_shared_core_wasm_bindings.js";
import {
CancelledError,
ensureError,
INTERNAL_ERROR_CODE,
logError,
Expand Down Expand Up @@ -68,6 +69,7 @@ import type {
import { CompletablePromise } from "./utils/completable_promise.js";
import type { AsyncResultValue, InternalRestatePromise } from "./promises.js";
import {
CancellationWatcherPromise,
extractContext,
InvocationPendingPromise,
pendingPromise,
Expand Down Expand Up @@ -109,7 +111,10 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
outputWriter: WritableStreamDefaultWriter<Uint8Array>,
readonly journalValueCodec: JournalValueCodec,
defaultSerde?: Serde<any>,
private readonly asTerminalError?: (error: any) => TerminalError | undefined
private readonly asTerminalError?: (
error: any
) => TerminalError | undefined,
cancellationController?: AbortController
) {
this.rand = new RandImpl(input.random_seed, () => {
// TODO reimplement this check with async context
Expand All @@ -133,6 +138,22 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
this.promiseExecutorErrorCallback.bind(this)
);
this.defaultSerde = defaultSerde ?? serde.json;

if (cancellationController) {
const cancelWatcher = new CancellationWatcherPromise(this, () => {
if (!cancellationController.signal.aborted) {
cancellationController.abort(new CancelledError());
}
});
invocationEndPromise.promise.then(
() => cancelWatcher.stop(),
() => cancelWatcher.stop()
);
void cancelWatcher.then(
() => {},
() => {}
);
}
}

cancel(invocationId: InvocationId): void {
Expand Down
10 changes: 9 additions & 1 deletion packages/libs/restate-sdk/src/endpoint/handlers/generic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*/

import {
CancelledError,
ensureError,
logError,
RestateError,
Expand Down Expand Up @@ -369,6 +370,11 @@ export class GenericHandler implements RestateHandler {
// Get input
const input = coreVm.sys_input();

const cancellationController = new AbortController();
if (coreVm.is_completed(vm.cancel_handle())) {
cancellationController.abort(new CancelledError());
}

const invocationRequest: Request = {
id: input.invocation_id,
headers: input.headers.reduce((headers, { key, value }) => {
Expand All @@ -387,6 +393,7 @@ export class GenericHandler implements RestateHandler {
body: input.input,
extraArgs,
attemptCompletedSignal: abortSignal,
cancellationSignal: cancellationController.signal,
};

// Prepare logger
Expand Down Expand Up @@ -442,7 +449,8 @@ export class GenericHandler implements RestateHandler {
outputWriter,
journalValueCodec,
service.options?.serde,
service.options?.asTerminalError
service.options?.asTerminalError,
cancellationController
);

journalValueCodec
Expand Down
51 changes: 51 additions & 0 deletions packages/libs/restate-sdk/src/promises.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type {
InvocationPromise,
} from "./context.js";
import type * as vm from "./endpoint/handlers/vm/sdk_shared_core_wasm_bindings.js";
import { cancel_handle } from "./endpoint/handlers/vm/sdk_shared_core_wasm_bindings.js";
import {
CancelledError,
RestateError,
Expand Down Expand Up @@ -361,6 +362,56 @@ export class RestateMappedPromise<T, U> extends AbstractRestatePromise<U> {
readonly [Symbol.toStringTag] = "RestateMappedPromise";
}

export class CancellationWatcherPromise extends AbstractRestatePromise<void> {
private completed = false;
private readonly completablePromise = new CompletablePromise<void>();

constructor(
ctx: ContextImpl,
private readonly onCancellation: () => void
) {
super(ctx);
}

uncompletedLeaves(): number[] {
return this.completed ? [] : [cancel_handle()];
}

tryComplete(): Promise<void> {
if (this.completed) return Promise.resolve();
if (this[RESTATE_CTX_SYMBOL].coreVm.is_completed(cancel_handle())) {
this.completed = true;
this.onCancellation();
this.completablePromise.resolve();
}
return Promise.resolve();
}

override tryCancel() {
if (this.completed) return;
this.completed = true;
this.onCancellation();
this.completablePromise.resolve();
}

/**
* Stop watching without triggering the cancellation callback.
* Used to cleanly shut down the watcher when the invocation ends
* (for any reason), preventing interaction with a closed VM.
*/
stop() {
if (this.completed) return;
this.completed = true;
this.completablePromise.resolve();
}

publicPromise(): Promise<void> {
return this.completablePromise.promise;
}

readonly [Symbol.toStringTag] = "CancellationWatcherPromise";
}

/**
* Promises executor, gluing VM with I/O and Promises given to user space.
*/
Expand Down
177 changes: 177 additions & 0 deletions packages/libs/restate-sdk/test/cancellation_signal.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
*
* This file is part of the Restate SDK for Node.js/TypeScript,
* which is released under the MIT license.
*
* You can find a copy of the license in file LICENSE in the root
* directory of this repository or package, or at
* https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
*/

import { describe, expect, it, vi } from "vitest";
import {
CancellationWatcherPromise,
RESTATE_CTX_SYMBOL,
} from "../src/promises.js";
import { cancel_handle } from "../src/endpoint/handlers/vm/sdk_shared_core_wasm_bindings.js";

function createMockCtx(isCompletedFn: (handle: number) => boolean) {
return {
coreVm: {
is_completed: isCompletedFn,
},
promisesExecutor: {
doProgress: () => Promise.resolve(),
},
} as any;
}

describe("CancellationWatcherPromise", () => {
it("tryComplete fires callback when is_completed returns true", async () => {
const callback = vi.fn();
let completed = false;
const ctx = createMockCtx(() => completed);

const watcher = new CancellationWatcherPromise(ctx, callback);

await watcher.tryComplete();
expect(callback).not.toHaveBeenCalled();

completed = true;
await watcher.tryComplete();
expect(callback).toHaveBeenCalledOnce();
});

it("tryCancel fires callback and resolves publicPromise", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.tryCancel();
expect(callback).toHaveBeenCalledOnce();

await expect(watcher.publicPromise()).resolves.toBeUndefined();
});

it("uncompletedLeaves returns [cancel_handle()] before completion and [] after", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => true);

const watcher = new CancellationWatcherPromise(ctx, callback);

expect(watcher.uncompletedLeaves()).toEqual([cancel_handle()]);

await watcher.tryComplete();

expect(watcher.uncompletedLeaves()).toEqual([]);
});

it("callback is only fired once even if tryComplete is called multiple times", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => true);

const watcher = new CancellationWatcherPromise(ctx, callback);

await watcher.tryComplete();
await watcher.tryComplete();
await watcher.tryComplete();

expect(callback).toHaveBeenCalledOnce();
});

it("callback is only fired once even if tryCancel is called multiple times", () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.tryCancel();
watcher.tryCancel();
watcher.tryCancel();

expect(callback).toHaveBeenCalledOnce();
});

it("tryComplete after tryCancel does not fire callback again", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => true);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.tryCancel();
expect(callback).toHaveBeenCalledOnce();

await watcher.tryComplete();
expect(callback).toHaveBeenCalledOnce();
});

it("publicPromise resolves after tryComplete detects completion", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => true);

const watcher = new CancellationWatcherPromise(ctx, callback);

await watcher.tryComplete();

await expect(watcher.publicPromise()).resolves.toBeUndefined();
});

it("has correct RESTATE_CTX_SYMBOL reference", () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

expect(watcher[RESTATE_CTX_SYMBOL]).toBe(ctx);
});

it("stop resolves publicPromise without firing callback", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.stop();
expect(callback).not.toHaveBeenCalled();

await expect(watcher.publicPromise()).resolves.toBeUndefined();
expect(watcher.uncompletedLeaves()).toEqual([]);
});

it("stop after tryCancel does not fire callback again", () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.tryCancel();
expect(callback).toHaveBeenCalledOnce();

watcher.stop();
expect(callback).toHaveBeenCalledOnce();
});

it("tryCancel after stop does not fire callback", () => {
const callback = vi.fn();
const ctx = createMockCtx(() => false);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.stop();
watcher.tryCancel();
expect(callback).not.toHaveBeenCalled();
});

it("tryComplete after stop does not fire callback", async () => {
const callback = vi.fn();
const ctx = createMockCtx(() => true);

const watcher = new CancellationWatcherPromise(ctx, callback);

watcher.stop();
await watcher.tryComplete();
expect(callback).not.toHaveBeenCalled();
});
});
1 change: 1 addition & 0 deletions packages/tests/restate-e2e-services/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import "./event_handler.js";
import "./list.js";
import "./map.js";
import "./cancel_test.js";
import "./cancel_signal_test.js";
import "./non_determinism.js";
import "./failing.js";
import "./side_effect.js";
Expand Down
Loading