
Introduce a scheduling barrier or control dependency operator to StableHLO #2923

@janpfeifer

Description


The problem: Temporary Memory Explosion in XLA CPU (PJRT)

When executing large-scale models (e.g., Transformers like Gemma 3) on the XLA CPU backend using bf16 parameters, the compiler's current scheduling policy leads to a massive explosion in "Temporary" memory usage.

Because weights are provided as entry parameters, the XLA scheduler identifies them as "ready" at the start of execution ($t=0$). On the CPU backend (specifically when using oneDNN), weights often undergo a Layout Conversion to optimize for hardware-specific instructions like AVX-512 or AMX.

In an unrolled 48-layer transformer graph, the compiler hoists the layout conversions for all 48 layers to the very beginning of the program. Instead of converting and using weights sequentially (reusing only one buffer), the compiler allocates a concurrent "Temporary" buffer for every single converted weight (48 buffers!).

This results in a memory peak of ~38GB of extra temporary memory (on top of the ~22GB of original input weights) for a model that should logically fit in a much smaller execution footprint.
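To make the arithmetic concrete, here is a back-of-the-envelope sketch in Python. The per-layer size is a hypothetical placeholder (picked only to land in the same ballpark as the numbers above, not a real Gemma 3 measurement); the point is just that the hoisted schedule's peak grows linearly with layer count, while a sequential schedule's peak is constant.

```python
# Back-of-the-envelope comparison of peak temporary memory under the two
# schedules. Sizes are hypothetical placeholders, not real Gemma 3 numbers.
NUM_LAYERS = 48
CONVERTED_BYTES_PER_LAYER = 0.8 * 2**30  # hypothetical converted-weight size

# Hoisted schedule: all layout conversions run at t=0, so all converted
# buffers are live simultaneously.
peak_hoisted = NUM_LAYERS * CONVERTED_BYTES_PER_LAYER

# Sequential schedule: each conversion runs just before its layer consumes
# it, so at most one converted buffer is live at a time (buffer reuse).
peak_sequential = 1 * CONVERTED_BYTES_PER_LAYER

print(f"hoisted:    {peak_hoisted / 2**30:.1f} GiB")    # 38.4 GiB
print(f"sequential: {peak_sequential / 2**30:.1f} GiB")  # 0.8 GiB
```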

Current Workarounds and Their Limitations

Currently, I have to "lie" to the compiler to force sequentiality. Two common hacks are used for this, both with downsides:

  • Arithmetic data dependency: when doing a dot_general, we want the weights of a layer (the rhs operand of the dot_general) to become available only when the lhs operand is. So we do something like: dependent_weight = weight + (lhs[0] * const(0)). It works because the multiplication by zero is not optimized away (at least in the CPU PJRT). But it incurs a copy of the weights, is brittle, and is likely non-portable.

  • stablehlo.while loops: wrapping layers in a loop and leveraging the implicit scheduling barrier of while. The downside is that this doesn't work if the parameters of the layers are passed as individual tensors. This can be overcome by concatenating them and taking slices inside the while loop, but only if they all have the same shape. And if XLA ever implements loop unrolling, the implicit scheduling barrier would disappear and memory usage would suddenly explode again.

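For reference, the first hack can be sketched as follows. This is a numpy mock purely to show the dataflow shape the real StableHLO graph would have (numpy itself has no scheduler, so the trick is a no-op here); all names are illustrative.

```python
import numpy as np

def dependent(weight, lhs):
    """Create an artificial data dependency of `weight` on `lhs`.

    Adding a single scalar element of lhs, multiplied by zero, makes the
    result depend on lhs without changing its value. In a real dataflow
    graph, a scheduler then cannot materialize the (converted) weight
    before lhs is available; numerically the result equals `weight`.
    """
    return weight + lhs[0, 0] * 0.0

lhs = np.ones((4, 8), dtype=np.float32)
weight = np.full((8, 16), 2.0, dtype=np.float32)
out = np.dot(lhs, dependent(weight, lhs))  # rhs now "depends" on lhs
```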
Proposal: a formal scheduling barrier (or control dependency) operator

We propose the addition of a dedicated operator—such as stablehlo.scheduling_barrier (syntactically aligned with the existing stablehlo.optimization_barrier), stablehlo.control_dependency, or stablehlo.identity_with_dependency—to the StableHLO specification.

It would look like:

%result = "stablehlo.scheduling_barrier"(%operand, %triggers...) 
    : (tensor<...>, tensor<...>) -> tensor<...>
  • %result acts as a pure identity function, returning %operand unchanged.
  • Scheduling fence: the compiler is strictly forbidden from scheduling the production of %result (including any backend-specific layout conversions or buffer assignments for %operand) until every %triggers operand is live and available.
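To illustrate the intended effect on a schedule, here is a toy list-scheduler model in Python. It is entirely hypothetical (not XLA's actual scheduler): each layer needs one layout-converted weight; without barriers every conversion is ready at t=0 and a greedy scheduler hoists them all, while barriers gate conversion i on the output of layer i-1, capping the number of live converted buffers at one.

```python
def peak_live_converted(num_layers, use_barrier):
    """Toy greedy schedule: max number of simultaneously-live converted buffers."""
    live = peak = 0
    if not use_barrier:
        # All conversions are ready at t=0; a greedy scheduler runs them all
        # before any layer executes, so every converted buffer is live at once.
        for _ in range(num_layers):
            live += 1
            peak = max(peak, live)
        for _ in range(num_layers):
            live -= 1  # each layer then consumes (frees) its converted weight
    else:
        # scheduling_barrier gates conversion i on the output of layer i-1,
        # so conversion and consumption alternate.
        for _ in range(num_layers):
            live += 1              # convert weight i
            peak = max(peak, live)
            live -= 1              # layer i consumes and frees it
    return peak

print(peak_live_converted(48, use_barrier=False))  # 48 concurrent buffers
print(peak_live_converted(48, use_barrier=True))   # 1 concurrent buffer
```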

Other considerations

  • stablehlo.optimization_barrier: seems to only impact optimization (e.g., preventing $x + 0 \to x$), and not the scheduling of operations, hence it doesn't solve the issue (I tried) -- it is "transparent" to the layout assignment pass.
  • stablehlo.after_all: it introduces the dependency but operates only on tokens, which cannot be used as an input to a dot_general operation.

Note: This issue is particularly critical for CPU backends where weights are large and RAM is a finite bottleneck.

Note 2: Apologies in advance if something like this already exists and I'm not aware of it, or if I'm missing something obvious.
