Skip to content

Replace cast+multiply masking with ops.where to reduce memory#22392

Open
rstar327 wants to merge 3 commits intokeras-team:masterfrom
rstar327:fix-cast-multiply-masking-memory
Open

Replace cast+multiply masking with ops.where to reduce memory#22392
rstar327 wants to merge 3 commits intokeras-team:masterfrom
rstar327:fix-cast-multiply-masking-memory

Conversation

@rstar327
Copy link
Contributor

Summary

  • Replace tensor * ops.cast(bool_mask, tensor.dtype) with ops.where(bool_mask, tensor, 0) across 12 files
  • On the torch backend these cast+multiply patterns run eagerly, so every intermediate sticks around in memory. ops.where skips the float copy of the mask entirely.
  • Fixes the layer forward passes (attention, batch norm, masking, pooling, activations, constraints) and losses/metrics (sparse categorical crossentropy, loss masking, IoU, confusion metrics, R2)
  • Added del on batch norm intermediates for earlier garbage collection

Fixes #22386

@codecov-commenter
Copy link

codecov-commenter commented Mar 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.97%. Comparing base (3c4f2cc) to head (a2022ed).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #22392   +/-   ##
=======================================
  Coverage   82.97%   82.97%           
=======================================
  Files         596      596           
  Lines       66330    66336    +6     
  Branches    10334    10334           
=======================================
+ Hits        55038    55044    +6     
  Misses       8663     8663           
  Partials     2629     2629           
Flag Coverage Δ
keras 82.80% <100.00%> (+<0.01%) ⬆️
keras-jax 60.74% <100.00%> (+<0.01%) ⬆️
keras-numpy 54.96% <100.00%> (+<0.01%) ⬆️
keras-openvino 49.25% <91.42%> (+<0.01%) ⬆️
keras-tensorflow 61.97% <100.00%> (+<0.01%) ⬆️
keras-torch 60.79% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves memory efficiency across the Keras codebase by refactoring how boolean masks are applied to tensors. By transitioning from a cast-and-multiply pattern to using ops.where, the changes eliminate the creation of unnecessary intermediate float tensors, leading to reduced memory consumption, particularly beneficial for the Torch backend. This optimization impacts core functionalities in layers, losses, and metrics, ensuring more efficient resource utilization.

Highlights

  • Memory Optimization: Replaced tensor * ops.cast(bool_mask, tensor.dtype) with ops.where(bool_mask, tensor, 0) across 12 files to reduce memory footprint, especially on the Torch backend by avoiding intermediate float copies of masks.
  • Affected Components: Applied memory optimizations to various Keras components including activations, constraints, attention mechanisms, masking layers, batch normalization, pooling layers, and several loss and metric functions (sparse categorical crossentropy, loss masking, IoU, confusion metrics, R2).
  • Garbage Collection: Added explicit del statements for intermediate tensors in batch normalization calculations to facilitate earlier garbage collection.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • keras/src/activations/activations.py
    • Replaced a cast and multiply operation with ops.where for thresholding in static_call.
  • keras/src/constraints/constraints.py
    • Updated the NonNeg constraint to use ops.where instead of ops.multiply with a cast boolean mask.
  • keras/src/layers/attention/attention.py
    • Modified padding mask application to use ops.where for scores.
    • Changed attention output masking to use ops.where.
  • keras/src/layers/core/masking.py
    • Replaced the multiplication with a cast boolean mask with ops.where for setting masked outputs to zero.
  • keras/src/layers/normalization/batch_normalization.py
    • Refactored masked input calculations to use ops.where for memory efficiency.
    • Added del statements for intermediate tensors (masked_inputs, masked_input_sum, difference, squared_difference, weighted_distsq) to enable earlier garbage collection.
  • keras/src/layers/pooling/global_average_pooling1d.py
    • Replaced masked multiplication with ops.where for inputs and adjusted mask casting for sum calculation.
  • keras/src/losses/loss.py
    • Removed redundant ops.cast for mask and refactored apply_mask to use ops.where for creating float_mask.
  • keras/src/losses/losses.py
    • Updated sparse_categorical_crossentropy to use ops.where for applying valid_mask to y_true and y_pred.
  • keras/src/metrics/confusion_metrics.py
    • Replaced a multiplication with a cast boolean mask with ops.where in _find_max_under_constraint.
  • keras/src/metrics/iou_metrics.py
    • Modified update_state to use ops.where for applying valid_mask to y_true, y_pred, and sample_weight.
  • keras/src/metrics/metrics_utils.py
    • Updated weighted_assign_add to use ops.where for conditional weight application.
  • keras/src/metrics/regression_metrics.py
    • Simplified weighted_y_true calculation and total_mse update by optimizing sample_weight casting.
Activity
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces the cast and multiply pattern with ops.where for masking across various files to improve memory efficiency, particularly for the Torch backend. The changes are generally correct and well-implemented. I've identified a minor issue in keras/src/constraints/constraints.py where the change could lead to unintended data type promotion, and I've provided a suggestion to address it. The addition of del statements for intermediate tensors in batch_normalization.py is also a good optimization for memory management.

def __call__(self, w):
w = backend.convert_to_tensor(w)
return ops.multiply(w, ops.greater_equal(w, 0.0))
return ops.where(ops.greater_equal(w, 0.0), w, 0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using 0.0 here could cause unintended dtype promotion if the input tensor w is of an integer type. The original implementation preserved the dtype. Using 0 will be correctly cast to either float or int depending on the dtype of w, preserving the original behavior and making the constraint more robust.

An even more idiomatic way to implement this constraint would be ops.maximum(w, 0).

Suggested change
return ops.where(ops.greater_equal(w, 0.0), w, 0.0)
return ops.where(ops.greater_equal(w, 0.0), w, 0)

Copy link
Contributor

@MarcosAsh MarcosAsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice idea, a few things that need fixing before this can land. Thanks for the work on this pr!

OpenVINO failures (all 10)

Every failure is the same:

Argument 0 must have boolean element type (element type: f32)

OpenVINO's Select op requires the condition to be strictly boolean. In several places the mask argument can arrive as float or int after round-tripping through backend masking logic see inline comments.

mask_broadcasted, ops.shape(inputs)
)
weighted_inputs = broadcasted_mask * inputs
mask_broadcasted = ops.cast(mask_broadcasted, "bool")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one of the OpenVINO failures. mask from the layer masking system can arrive as float, so mask_broadcasted inherits that dtype. OpenVINO's Select op (which backs ops.where) requires argument 0 to be strictly boolean.

Needs a mask = ops.cast(mask, "bool") guard at the top of this block, before the expand/broadcast.

mask, 2 if self.data_format == "channels_last" else 1
)
inputs *= mask
mask_expanded = ops.cast(mask_expanded, "bool")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same OpenVINO issue. mask comes in as i32 here, not boolean. Needs mask = ops.cast(mask, "bool") before the expand_dims.

mask *= ops.divide_no_nan(total, valid)
valid = ops.sum(ops.cast(mask, dtype=dtype)) # May be 0!
mask_weight = ops.divide_no_nan(total, valid)
float_mask = ops.where(ops.cast(mask, "bool"), mask_weight, 0)
Copy link
Contributor

@MarcosAsh MarcosAsh Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same OpenVINO issue mask can be float after squeeze_or_expand_to_same_rank. Needs a bool cast before being used as the ops.where condition.

mask_weight = ops.divide_no_nan(total, valid)
float_mask = ops.where(ops.cast(mask, "bool"), mask_weight, 0)
else:
float_mask = ops.cast(mask, dtype=dtype)
Copy link
Contributor

@MarcosAsh MarcosAsh Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This else branch still does ops.cast(mask, dtype=dtype) and then sample_weight *= float_mask below that's the original cast+multiply pattern this PR is trying to remove. Should use ops.where here too.

@@ -567,7 +567,11 @@ def update_confusion_matrix_variables(
def weighted_assign_add(label, pred, weights, var):
label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When weights is not None, this cast allocates a float tensor that gets immediately thrown away on the next line. And ops.logical_and(label, pred) is recomputed inside the ops.where below. Could keep the boolean and branch instead:

label_and_pred = ops.logical_and(label, pred)
if weights is not None:
    result = ops.where(label_and_pred, ops.cast(weights, dtype=var.dtype), 0)
else:
    result = ops.cast(label_and_pred, dtype=var.dtype)
var.assign(var + ops.sum(result, 1))

@MarcosAsh
Copy link
Contributor

Sorry just saw you fixed those changes the fixes I flagged on loss.py and metrics_utils.py still exist none the less.

@rstar327 rstar327 force-pushed the fix-cast-multiply-masking-memory branch from 905bf00 to a2022ed Compare March 10, 2026 17:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Torch backend cast+multiply masking pattern causes unnecessary memory allocations

4 participants