Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
459 changes: 459 additions & 0 deletions apps/api/src/runtime/conditional-execution-handler.test.ts

Large diffs are not rendered by default.

122 changes: 122 additions & 0 deletions apps/api/src/runtime/conditional-execution-handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import type { CloudflareNodeRegistry } from "../nodes/cloudflare-node-registry";
import type { NodeInputMapper } from "./node-input-mapper";
import type { RuntimeState } from "./runtime";

/**
* Handles conditional logic in workflow execution.
* Determines which nodes should be skipped based on inactive outputs and missing inputs.
*/
export class ConditionalExecutionHandler {
constructor(
private nodeRegistry: CloudflareNodeRegistry,
private inputMapper: NodeInputMapper
) {}

/**
* Marks nodes connected to inactive outputs as skipped.
* This is crucial for conditional logic where only one branch should execute.
*/
markInactiveOutputNodesAsSkipped(
runtimeState: RuntimeState,
nodeIdentifier: string,
nodeOutputs: Record<string, unknown>
): RuntimeState {
const node = runtimeState.workflow.nodes.find(
(n) => n.id === nodeIdentifier
);
if (!node) return runtimeState;

// Find outputs that were NOT produced
const inactiveOutputs = node.outputs
.map((output) => output.name)
.filter((outputName) => !(outputName in nodeOutputs));

if (inactiveOutputs.length === 0) return runtimeState;

// Find all edges from this node's inactive outputs
const inactiveEdges = runtimeState.workflow.edges.filter(
(edge) =>
edge.source === nodeIdentifier &&
inactiveOutputs.includes(edge.sourceOutput)
);

// Process each target node of inactive edges
for (const edge of inactiveEdges) {
this.markNodeAsSkippedIfNoValidInputs(runtimeState, edge.target);
}

return runtimeState;
}

/**
* Marks a node as skipped if it cannot execute due to missing required inputs.
* This is smarter than recursively skipping all dependents.
*/
private markNodeAsSkippedIfNoValidInputs(
runtimeState: RuntimeState,
nodeId: string
): void {
if (
runtimeState.skippedNodes.has(nodeId) ||
runtimeState.executedNodes.has(nodeId)
) {
return; // Already processed
}

const node = runtimeState.workflow.nodes.find((n) => n.id === nodeId);
if (!node) return;

// Check if this node has all required inputs satisfied
const allRequiredInputsSatisfied = this.nodeHasAllRequiredInputsSatisfied(
runtimeState,
nodeId
);

// Only skip if the node cannot execute (missing required inputs)
if (!allRequiredInputsSatisfied) {
runtimeState.skippedNodes.add(nodeId);

// Recursively check dependents of this skipped node
const outgoingEdges = runtimeState.workflow.edges.filter(
(edge) => edge.source === nodeId
);

for (const edge of outgoingEdges) {
this.markNodeAsSkippedIfNoValidInputs(runtimeState, edge.target);
}
}
}

/**
* Checks if a node has all required inputs satisfied.
* A node can execute if all its required inputs are available.
*/
private nodeHasAllRequiredInputsSatisfied(
runtimeState: RuntimeState,
nodeId: string
): boolean {
const node = runtimeState.workflow.nodes.find((n) => n.id === nodeId);
if (!node) return false;

// Get the node type definition to check for required inputs
const executable = this.nodeRegistry.createExecutableNode(node);
if (!executable) return false;

const nodeTypeDefinition = (executable.constructor as any).nodeType;
if (!nodeTypeDefinition) return false;

const inputValues = this.inputMapper.collectNodeInputs(
runtimeState,
nodeId
);

// Check each required input based on the node type definition (not workflow node definition)
for (const input of nodeTypeDefinition.inputs) {
if (input.required && inputValues[input.name] === undefined) {
return false; // Found a required input that's missing
}
}

return true; // All required inputs are satisfied
}
}
205 changes: 205 additions & 0 deletions apps/api/src/runtime/credit-manager.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import type { Node } from "@dafthunk/types";
import { describe, expect, it, vi } from "vitest";

import type { Bindings } from "../context";
import type { CloudflareNodeRegistry } from "../nodes/cloudflare-node-registry";
import { CreditManager } from "./credit-manager";

// Mock the credits utility
vi.mock("../utils/credits", () => ({
getOrganizationComputeUsage: vi.fn(),
}));

import { getOrganizationComputeUsage } from "../utils/credits";

describe("CreditManager", () => {
const createMockEnv = (cloudflareEnv?: string): Bindings => {
return {
CLOUDFLARE_ENV: cloudflareEnv,
KV: {} as any,
} as Bindings;
};

const createMockRegistry = (
nodeTypes: Record<string, { computeCost?: number }>
): CloudflareNodeRegistry => {
return {
getNodeType: vi.fn((type: string) => nodeTypes[type] || {}),
} as any;
};

describe("hasEnoughComputeCredits", () => {
it("should always return true in development mode", async () => {
const env = createMockEnv("development");
const registry = createMockRegistry({});
const manager = new CreditManager(env, registry);

const result = await manager.hasEnoughComputeCredits(
"org-123",
100, // computeCredits
200 // computeCost (exceeds credits)
);

expect(result).toBe(true);
expect(getOrganizationComputeUsage).not.toHaveBeenCalled();
});

it("should return true when credits are sufficient", async () => {
const env = createMockEnv("production");
const registry = createMockRegistry({});
const manager = new CreditManager(env, registry);

vi.mocked(getOrganizationComputeUsage).mockResolvedValue(50); // current usage

const result = await manager.hasEnoughComputeCredits(
"org-123",
100, // total credits
30 // additional cost needed
);

expect(result).toBe(true); // 50 + 30 = 80 <= 100
});

it("should return false when credits are insufficient", async () => {
const env = createMockEnv("production");
const registry = createMockRegistry({});
const manager = new CreditManager(env, registry);

vi.mocked(getOrganizationComputeUsage).mockResolvedValue(80); // current usage

const result = await manager.hasEnoughComputeCredits(
"org-123",
100, // total credits
30 // additional cost needed
);

expect(result).toBe(false); // 80 + 30 = 110 > 100
});

it("should return true when exactly at credit limit", async () => {
const env = createMockEnv("production");
const registry = createMockRegistry({});
const manager = new CreditManager(env, registry);

vi.mocked(getOrganizationComputeUsage).mockResolvedValue(70); // current usage

const result = await manager.hasEnoughComputeCredits(
"org-123",
100, // total credits
30 // additional cost needed
);

expect(result).toBe(true); // 70 + 30 = 100 == 100
});

it("should handle zero current usage", async () => {
const env = createMockEnv("production");
const registry = createMockRegistry({});
const manager = new CreditManager(env, registry);

vi.mocked(getOrganizationComputeUsage).mockResolvedValue(0);

const result = await manager.hasEnoughComputeCredits("org-123", 100, 50);

expect(result).toBe(true); // 0 + 50 = 50 <= 100
});
});

describe("getNodesComputeCost", () => {
it("should calculate total cost for multiple nodes", () => {
const registry = createMockRegistry({
text: { computeCost: 1 },
ai: { computeCost: 10 },
image: { computeCost: 5 },
});
const manager = new CreditManager({} as Bindings, registry);

const nodes: Node[] = [
{ id: "A", type: "text", inputs: [], outputs: [] },
{ id: "B", type: "ai", inputs: [], outputs: [] },
{ id: "C", type: "image", inputs: [], outputs: [] },
] as unknown as Node[];

const result = manager.getNodesComputeCost(nodes);

expect(result).toBe(16); // 1 + 10 + 5
});

it("should use default cost of 1 when computeCost not specified", () => {
const registry = createMockRegistry({
text: {}, // no computeCost specified
unknown: {}, // no computeCost specified
});
const manager = new CreditManager({} as Bindings, registry);

const nodes: Node[] = [
{ id: "A", type: "text", inputs: [], outputs: [] },
{ id: "B", type: "unknown", inputs: [], outputs: [] },
] as unknown as Node[];

const result = manager.getNodesComputeCost(nodes);

expect(result).toBe(2); // 1 + 1 (defaults)
});

it("should handle empty node list", () => {
const registry = createMockRegistry({});
const manager = new CreditManager({} as Bindings, registry);

const result = manager.getNodesComputeCost([]);

expect(result).toBe(0);
});

it("should handle nodes with zero cost", () => {
const registry = createMockRegistry({
free: { computeCost: 0 },
});
const manager = new CreditManager({} as Bindings, registry);

const nodes: Node[] = [
{ id: "A", type: "free", inputs: [], outputs: [] },
{ id: "B", type: "free", inputs: [], outputs: [] },
] as unknown as Node[];

const result = manager.getNodesComputeCost(nodes);

expect(result).toBe(0);
});

it("should handle single node", () => {
const registry = createMockRegistry({
expensive: { computeCost: 100 },
});
const manager = new CreditManager({} as Bindings, registry);

const nodes: Node[] = [
{ id: "A", type: "expensive", inputs: [], outputs: [] },
] as unknown as Node[];

const result = manager.getNodesComputeCost(nodes);

expect(result).toBe(100);
});

it("should sum costs correctly for many nodes", () => {
const registry = createMockRegistry({
type1: { computeCost: 3 },
type2: { computeCost: 7 },
});
const manager = new CreditManager({} as Bindings, registry);

const nodes: Node[] = Array.from({ length: 10 }, (_, i) => ({
id: `node-${i}`,
type: i % 2 === 0 ? "type1" : "type2",
inputs: [],
outputs: [],
})) as unknown as Node[];

const result = manager.getNodesComputeCost(nodes);

// 5 nodes of type1 (3 each) + 5 nodes of type2 (7 each) = 15 + 35 = 50
expect(result).toBe(50);
});
});
});
47 changes: 47 additions & 0 deletions apps/api/src/runtime/credit-manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { Node } from "@dafthunk/types";

import type { Bindings } from "../context";
import type { CloudflareNodeRegistry } from "../nodes/cloudflare-node-registry";
import { getOrganizationComputeUsage } from "../utils/credits";

/**
* Manages compute credits for workflow execution.
* Handles credit checks and cost calculations.
*/
export class CreditManager {
constructor(
private env: Bindings,
private nodeRegistry: CloudflareNodeRegistry
) {}

/**
* Checks if the organization has enough compute credits to execute a workflow.
* Credit limits are not enforced in development mode.
*/
async hasEnoughComputeCredits(
organizationId: string,
computeCredits: number,
computeCost: number
): Promise<boolean> {
// Skip credit limit enforcement in development mode
if (this.env.CLOUDFLARE_ENV === "development") {
return true;
}

const computeUsage = await getOrganizationComputeUsage(
this.env.KV,
organizationId
);
return computeUsage + computeCost <= computeCredits;
}

/**
* Returns the compute cost of a list of nodes.
*/
getNodesComputeCost(nodes: Node[]): number {
return nodes.reduce((acc, node) => {
const nodeType = this.nodeRegistry.getNodeType(node.type);
return acc + (nodeType.computeCost ?? 1);
}, 0);
}
}
Loading