Fix #21105: JAX Backend - Custom Gradient Variable Capture Issue in keras.ops.custom_gradient #21783
Status: Open

Mayankvlog wants to merge 10 commits into keras-team:master from Mayankvlog:improve-docs-loss-functions
+774 −26

Changes from 2 commits (of 10 total)
Commits (all by Mayankvlog):

- 564e155 Document torch.compile limitations with EfficientNetV2 (Issue #21647)
- 0a4ec49 fix: resolve TensorFlow import error on Windows
- 8f86aef Merge branch 'master' into improve-docs-loss-functions
- 9d20a57 Fix #21105: JAX custom_gradient Variable handling and linting errors
- 9b439c4 Fix test: use proper initializer object instead of string
- 03c7c1d Fix EfficientNetV2 tests: use 224x224 input size and fix B0 import
- 26a1cde Fix #21105: JAX backend custom_gradient with Variables and input_shap…
- 4b15f7d feat: add LPIPS perceptual loss
- 22a3bf1 Merge branch 'keras-team:master' into improve-docs-loss-functions
- 5c925c0 fix(tf-remat): avoid passing kwargs to custom_gradient in graph mode;…
# Fix for custom_gradient with JAX backend and Variables

## Issue
GitHub Issue [#21105](https://github.com/keras-team/keras/issues/21105)

When using `@ops.custom_gradient` with the JAX backend, passing Keras Variables as arguments would cause a `TypeError: 'NoneType' object is not callable` during training. This occurred because JAX's `custom_gradient` would capture the Variable object itself instead of extracting its underlying tensor value.
## Root Cause
The JAX backend's `custom_gradient` function was directly wrapping `jax.custom_gradient` without converting Variable objects to their values, unlike the `stop_gradient` function, which already handled this correctly.
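For comparison, the `stop_gradient` handling referenced above looks roughly like this (a paraphrase for illustration, not verbatim code from the PR; `Variable` and `jax` are module-level names in `keras/src/backend/jax/core.py`):

```python
def stop_gradient(variable):
    # Unwrap a Keras Variable to its underlying tensor before handing
    # it to JAX, so jax.lax.stop_gradient never sees the wrapper object.
    if isinstance(variable, Variable):
        variable = variable.value
    return jax.lax.stop_gradient(variable)
```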
## Solution
Modified `keras/src/backend/jax/core.py` to add a wrapper that automatically extracts `.value` from Variable objects before passing them to the user's custom gradient function. This is done using `tree.map_structure` to recursively handle nested structures.

### Changes Made

**File: `keras/src/backend/jax/core.py`**

```python
def custom_gradient(fun):
    def wrapper(*args, **kwargs):
        # Convert Variable objects to their values
        def _convert_arg(arg):
            if isinstance(arg, Variable):
                return arg.value
            return arg

        args = tree.map_structure(_convert_arg, args)
        kwargs = tree.map_structure(_convert_arg, kwargs)
        return fun(*args, **kwargs)

    return jax.custom_gradient(fun=wrapper)
```
**File: `keras/src/ops/core_test.py`**

Added `test_custom_gradient_with_variable()` to verify that Variables can be passed directly to custom_gradient functions without needing to manually add `.value`.
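A hypothetical sketch of what such a test could look like (`log1pexp` is an illustrative function; the PR's actual test body may differ):

```python
import numpy as np

from keras.src import backend
from keras.src import ops


# Sketch of a test method for CoreOpsCorrectnessTest.
def test_custom_gradient_with_variable(self):
    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x)
            return upstream * (1 - 1 / (1 + e))

        return ops.log(1 + e), grad

    # The Variable is passed directly; the fixed backend unwraps it.
    v = backend.Variable(1.0)
    y = log1pexp(v)
    self.assertAllClose(y, np.log(1 + np.e))
```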
## Testing

### Run the specific test:
```bash
pytest keras/src/ops/core_test.py::CoreOpsCorrectnessTest::test_custom_gradient_with_variable -v
```

### Run all core ops tests:
```bash
pytest keras/src/ops/core_test.py -v
```
## Example Usage

Before the fix, you needed to manually extract `.value`:

```python
@ops.custom_gradient
def roundpass(x, log_scaling):
    scaling = ops.exp(log_scaling)
    rounded = ops.round(x * scaling) / scaling

    def grad(*args, upstream=None):
        if upstream is None:
            (upstream,) = args
        return upstream, ops.zeros_like(log_scaling)

    return rounded, grad


class QuantizedLayer(layers.Layer):
    def call(self, x):
        # Workaround: manually add .value
        return roundpass(x, self.log_scaling.value)
```

After the fix, Variables work directly:

```python
class QuantizedLayer(layers.Layer):
    def call(self, x):
        # Works automatically now!
        return roundpass(x, self.log_scaling)
```
## Impact
- ✅ Fixes the TypeError when Variables are passed to custom_gradient functions
- ✅ Makes JAX backend behavior consistent with user expectations
- ✅ Aligns with how `stop_gradient` already handles Variables
- ✅ Backward compatible: existing code using the `.value` workaround still works
- ✅ Negligible performance overhead: the conversion is a lightweight per-call mapping over the arguments
# Torch Backend jit_compile Limitations

## Issue #21647: jit_compile=True with EfficientNetV2 on torch backend

### Problem
When using `jit_compile=True` with certain Keras models (especially EfficientNetV2) on the torch backend, you may encounter `InternalTorchDynamoError` or `RuntimeError` related to torch.compile being unable to trace optree operations.

### Root Cause
Keras uses tree operations (from optree or `torch.utils._pytree`) for handling nested structures. When `jit_compile=True` is enabled, PyTorch's torch.compile attempts to trace through all Python operations, including these tree utilities. However, torch.compile has limitations with certain C/C++ extensions and symbolic operations.

### Error Messages
- **GPU**: `InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'`
- **CPU**: `RuntimeError: TypeError: cannot determine truth value of Relational`
### Workarounds

#### Option 1: Disable JIT Compilation (Recommended)
```python
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=CategoricalCrossentropy(),
    metrics=["accuracy"],
    jit_compile=False,  # or omit this parameter
)
```

#### Option 2: Use a Different Backend
Switch to the TensorFlow or JAX backend, which have better jit_compile support:
```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax"
```

#### Option 3: Use Fixed Input Shapes
If you must use jit_compile with torch, ensure all input shapes are fixed (no None dimensions):
```python
base_model = EfficientNetV2B2(
    include_top=False,
    input_shape=(224, 224, 3),  # Fixed shape, no None
    pooling="avg",
    weights=None,
)
```

### Status
This is a known limitation of torch.compile when working with complex nested structures. The PyTorch team is aware of limitations with certain patterns and continues to improve torch.compile support.

### Related Issues
- PyTorch: torch.compile limitations with pytree operations
- Keras Issue #21647
### Future Improvements
Potential solutions being explored:
1. Add torch.compile skip decorators for tree operations
2. Use a torch.compiler.disable() context for specific operations (see the sketch below)
3. Refactor to use pure torch operations where possible
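As an illustration of the second option, a minimal sketch (`flatten_inputs` is a hypothetical helper standing in for Keras's internal tree utilities, not an actual Keras function):

```python
import torch


# Hypothetical stand-in for an internal tree-flattening utility.
# torch.compiler.disable tells torch.compile to run this function
# eagerly instead of attempting to trace through it.
@torch.compiler.disable
def flatten_inputs(structure):
    flat = []
    stack = [structure]
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            stack.extend(reversed(item))
        elif isinstance(item, dict):
            stack.extend(reversed(list(item.values())))
        else:
            flat.append(item)
    return flat
```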
| """Test for Issue #21647: jit_compile=True with EfficientNetV2 on torch backend.""" | ||
|
|
||
| import os | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from keras.src import backend | ||
| from keras.src import layers | ||
| from keras.src import models | ||
| from keras.src import testing | ||
| from keras.src.applications import EfficientNetV2B2 | ||
| from keras.src.losses import CategoricalCrossentropy | ||
| from keras.src.optimizers import Adam | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| backend.backend() != "torch", | ||
| reason="This test is specifically for torch backend", | ||
| ) | ||
| class EfficientNetV2JitCompileTest(testing.TestCase): | ||
| """Test EfficientNetV2 models with jit_compile=True on torch backend.""" | ||
|
|
||
| def test_efficientnet_v2_b2_with_jit_compile(self): | ||
| """Test that EfficientNetV2B2 works with jit_compile=True.""" | ||
| num_classes = 10 | ||
| batch_size = 2 # Small batch for testing | ||
| steps_per_epoch = 1 | ||
| epochs = 1 | ||
|
|
||
| # Generate random data (small for testing) | ||
| data_shape = (64, 64, 3) # Smaller image size for faster testing | ||
| x_train = np.random.rand( | ||
| batch_size * steps_per_epoch, *data_shape | ||
| ).astype(np.float32) | ||
| y_train = np.random.randint( | ||
| 0, num_classes, size=(batch_size * steps_per_epoch,) | ||
| ) | ||
| y_train = np.eye(num_classes)[y_train] | ||
|
|
||
| # Create model | ||
| base_model = EfficientNetV2B2( | ||
| include_top=False, | ||
| input_shape=(64, 64, 3), # Fixed shape for jit_compile | ||
| pooling="avg", | ||
| include_preprocessing=True, | ||
| weights=None, # Don't load weights for faster testing | ||
| ) | ||
| x = base_model.output | ||
| output = layers.Dense(num_classes, activation="softmax")(x) | ||
| model = models.Model(inputs=base_model.input, outputs=output) | ||
|
|
||
| # Compile with jit_compile=True | ||
| model.compile( | ||
| optimizer=Adam(learning_rate=0.001), | ||
| loss=CategoricalCrossentropy(), | ||
| metrics=["accuracy"], | ||
| jit_compile=True, | ||
| ) | ||
|
|
||
| # This should not raise InternalTorchDynamoError | ||
| history = model.fit( | ||
| x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=0 | ||
| ) | ||
|
|
||
| # Basic sanity check | ||
| self.assertIsNotNone(history) | ||
| self.assertIn("loss", history.history) | ||
|
|
||
| def test_efficientnet_v2_b0_with_jit_compile(self): | ||
| """Test that EfficientNetV2B0 also works with jit_compile=True.""" | ||
| from keras.src.applications import EfficientNetV2B0 | ||
|
|
||
| num_classes = 5 | ||
| batch_size = 2 | ||
|
|
||
| # Generate random data | ||
| x_train = np.random.rand(batch_size, 64, 64, 3).astype(np.float32) | ||
| y_train = np.eye(num_classes)[ | ||
| np.random.randint(0, num_classes, size=(batch_size,)) | ||
| ] | ||
|
|
||
| # Create model | ||
| base_model = EfficientNetV2B0( | ||
| include_top=False, | ||
| input_shape=(64, 64, 3), | ||
| pooling="avg", | ||
| weights=None, | ||
| ) | ||
| x = base_model.output | ||
| output = layers.Dense(num_classes, activation="softmax")(x) | ||
| model = models.Model(inputs=base_model.input, outputs=output) | ||
|
|
||
| # Compile with jit_compile=True | ||
| model.compile( | ||
| optimizer=Adam(learning_rate=0.001), | ||
| loss=CategoricalCrossentropy(), | ||
| metrics=["accuracy"], | ||
| jit_compile=True, | ||
| ) | ||
|
|
||
| # Should work without errors | ||
| predictions = model.predict(x_train, verbose=0) | ||
| self.assertEqual(predictions.shape, (batch_size, num_classes)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__]) |
A review comment on the diff:

> The current implementation redefines the `_convert_arg` function on every call to the decorated function, which is inefficient. This helper function can be defined once outside the `wrapper` to avoid this overhead. Additionally, renaming it to `_convert_variable_to_value` would make its purpose clearer, following the style guide's preference for descriptive names.[^1]

[^1]: Style Guide Reference: Argument names should be intuitive and easy to remember, and their meaning should be clear from the name. Overly generic names should be avoided.
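A minimal sketch of the refactor the reviewer suggests (not code from the PR; `Variable`, `tree`, and `jax` are module-level names in `keras/src/backend/jax/core.py`):

```python
def _convert_variable_to_value(arg):
    # Defined once at module level instead of being recreated inside
    # `wrapper` on every call; descriptive name per the style guide.
    if isinstance(arg, Variable):
        return arg.value
    return arg


def custom_gradient(fun):
    def wrapper(*args, **kwargs):
        args = tree.map_structure(_convert_variable_to_value, args)
        kwargs = tree.map_structure(_convert_variable_to_value, kwargs)
        return fun(*args, **kwargs)

    return jax.custom_gradient(fun=wrapper)
```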