
Torch backend cast+multiply masking pattern causes unnecessary memory allocations #22386

@MarcosAsh

Description


Summary

Following up from #22379, where Softmax hit OOM on the torch backend because of how masking was done. The problem was multiply(tensor, cast(mask, dtype)). On jax and tf this gets fused away during compilation, but torch runs eagerly so every intermediate sticks around in memory. The fix in #22379 was simple: use where instead, which skips the float copy of the mask entirely.
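A minimal sketch of the two patterns, using numpy as a stand-in for `keras.ops` (the semantics carry over; the memory difference only matters under torch's eager execution, where the float copy of the mask is actually materialized):

```python
import numpy as np

x = np.random.rand(4, 8).astype("float32")
mask = np.random.rand(4, 8) > 0.5  # boolean mask

# Old pattern: materializes a float copy of the mask, then the product.
masked_old = x * mask.astype(x.dtype)

# New pattern: selects directly from the boolean mask, no float copy.
masked_new = np.where(mask, x, 0.0)

assert np.array_equal(masked_old, masked_new)
```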

I went through the codebase and found the same pattern in a bunch of other places. Listing them here so they can be picked off over time.

Layer forward passes (big tensors, most likely to cause OOM)

  • layers/attention/attention.py:181 -- casts padding_mask to float and multiplies with attention scores, which can get very large
  • layers/attention/attention.py:240 -- same thing with q_mask on the attention output
  • layers/normalization/batch_normalization.py:442-464 -- this one is probably the worst after Softmax. Casts the mask, broadcasts it to input shape, and multiplies twice (once for mean, once for variance). Keeps 4+ full-size intermediates alive at the same time
  • layers/core/masking.py:65 -- inputs * backend.cast(boolean_mask, dtype=inputs.dtype)
  • layers/pooling/global_average_pooling1d.py:74-78 -- casts mask to float then multiplies with inputs
  • activations/activations.py:93-95 -- relu with custom threshold, casts a bool comparison to float and multiplies
  • constraints/constraints.py:127 -- NonNeg multiplies weights by ops.greater_equal(w, 0.0), implicit bool-to-float cast
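To make the relu case above concrete, here is a hedged sketch of both forms (numpy standing in for the backend ops; the function names are hypothetical, not the actual ones in activations.py):

```python
import numpy as np

def relu_threshold_old(x, threshold):
    # Current pattern: boolean comparison cast to float, then multiplied.
    # Allocates a full-size float mask before the product.
    return x * np.greater(x, threshold).astype(x.dtype)

def relu_threshold_new(x, threshold):
    # Proposed pattern: select with where, no float mask allocation.
    return np.where(x > threshold, x, 0.0)

x = np.array([-1.0, 0.5, 2.0], dtype="float32")
assert np.allclose(relu_threshold_old(x, 1.0), relu_threshold_new(x, 1.0))
```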

Losses and metrics (smaller tensors, less urgent)

  • losses/losses.py:2349-2353 -- sparse categorical crossentropy ignore_class, two cast+multiply calls
  • losses/loss.py:200 -- apply_mask casts the mask then multiplies with sample_weight
  • metrics/iou_metrics.py:131-135 -- casts the same valid_mask three separate times
  • metrics/confusion_metrics.py:670 -- cast+multiply on feasibility mask
  • metrics/metrics_utils.py:568-570 -- confusion matrix weighted update
  • metrics/regression_metrics.py:511,519 -- R2 metric, casts sample_weight twice
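For the ignore_class case, the same idea also lets the boolean mask be built once and reused, instead of being cast repeatedly. A simplified sketch (numpy in place of the backend ops; the function is hypothetical and flattens the real loss code down to its masking step):

```python
import numpy as np

def sparse_ce_ignore(y_true, per_sample_losses, ignore_class):
    # Build the boolean validity mask once.
    valid = y_true != ignore_class
    # Zero out ignored entries with where -- no float copy of the mask.
    masked = np.where(valid, per_sample_losses, 0.0)
    # Average over valid entries only.
    return masked.sum() / np.count_nonzero(valid)
```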

The fix

For most of these, it's just replacing:
tensor * ops.cast(bool_mask, tensor.dtype)
with:
ops.where(bool_mask, tensor, 0)

For places where a lot of intermediates pile up (like batch norm), you can also free things early with del.
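A rough sketch of what that could look like for the batch-norm-style masked moments, reduced to flat arrays for illustration (numpy stands in for the backend ops; the function name and structure are hypothetical, not the actual batch_normalization.py code):

```python
import numpy as np

def masked_moments(x, mask):
    # where instead of cast+multiply: no float copy of the mask.
    masked = np.where(mask, x, 0.0)
    count = np.count_nonzero(mask)
    mean = masked.sum() / count
    # Second intermediate for the variance term.
    centered = np.where(mask, x - mean, 0.0)
    del masked  # free the first intermediate before the next allocation
    var = (centered ** 2).sum() / count
    del centered
    return mean, var
```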

Note: the memory profiling work in #22307 (phase 3) would catch exactly this kind of thing, so it doesn't happen again in the future.
