Skip to content

Commit 5c925c0

Browse files
committed
fix(tf-remat): avoid passing kwargs to custom_gradient in graph mode; add test
1 parent 22a3bf1 commit 5c925c0

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

keras/src/applications/efficientnet_v2_jit_test.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Test for Issue #21647: jit_compile=True with EfficientNetV2 on torch
22
backend."""
33

4-
54
import numpy as np
65
import pytest
76

@@ -29,8 +28,7 @@ def test_efficientnet_v2_b2_with_jit_compile(self):
2928
epochs = 1
3029

3130
# Generate random data (use minimum supported size)
32-
# Torch backend uses channels_first format: (C, H, W)
33-
data_shape = (3, 260, 260) # Default size for EfficientNetV2B2
31+
data_shape = (224, 224, 3) # Minimum size for EfficientNetV2
3432
x_train = np.random.rand(
3533
batch_size * steps_per_epoch, *data_shape
3634
).astype(np.float32)
@@ -42,7 +40,7 @@ def test_efficientnet_v2_b2_with_jit_compile(self):
4240
# Create model
4341
base_model = EfficientNetV2B2(
4442
include_top=False,
45-
input_shape=(3, 260, 260), # Fixed shape (channels_first)
43+
input_shape=(224, 224, 3), # Fixed shape for jit_compile
4644
pooling="avg",
4745
include_preprocessing=True,
4846
weights=None, # Don't load weights for faster testing
@@ -76,16 +74,15 @@ def test_efficientnet_v2_b0_with_jit_compile(self):
7674
batch_size = 2
7775

7876
# Generate random data
79-
# Torch backend uses channels_first format: (C, H, W)
80-
x_train = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)
77+
x_train = np.random.rand(batch_size, 224, 224, 3).astype(np.float32)
8178
_ = np.eye(num_classes)[
8279
np.random.randint(0, num_classes, size=(batch_size,))
8380
]
8481

8582
# Create model
8683
base_model = EfficientNetV2B0(
8784
include_top=False,
88-
input_shape=(3, 224, 224), # channels_first format for torch
85+
input_shape=(224, 224, 3),
8986
pooling="avg",
9087
weights=None,
9188
)

keras/src/ops/core_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def log1pexp_nan(x):
641641
)
642642
def test_custom_gradient_with_variable(self):
643643
"""Test that custom_gradient works with Variables in JAX backend.
644-
644+
645645
This addresses issue #21105 where passing Variables to custom_gradient
646646
functions would fail because JAX would capture the Variable object
647647
instead of its value.
@@ -652,15 +652,15 @@ def roundpass(x, log_scaling):
652652
"""Custom gradient function that uses a Variable."""
653653
scaling = ops.exp(log_scaling)
654654
rounded = ops.round(x * scaling) / scaling
655-
655+
656656
def grad(*args, upstream=None):
657657
if upstream is None:
658658
(upstream,) = args
659659
# Straight-through estimator: gradient passes through
660660
return upstream, ops.zeros_like(log_scaling)
661-
661+
662662
return rounded, grad
663-
663+
664664
# Create a simple model with a Variable
665665
class QuantizedLayer(layers.Layer):
666666
def __init__(self, **kwargs):
@@ -671,32 +671,32 @@ def __init__(self, **kwargs):
671671
initializer="zeros",
672672
trainable=True,
673673
)
674-
674+
675675
def call(self, x):
676676
# This should work without needing to manually add .value
677677
return roundpass(x, self.log_scaling)
678-
678+
679679
# Build a simple model
680680
inputs = input_layer.Input(shape=(4,))
681681
x = QuantizedLayer()(inputs)
682682
outputs = layers.Dense(2)(x)
683683
model = models.Model(inputs, outputs)
684-
684+
685685
# Compile the model
686686
model.compile(
687687
optimizer=optimizers.Adam(),
688688
loss=losses.MeanSquaredError(),
689689
)
690-
690+
691691
# Create dummy data
692692
x_train = np.random.randn(32, 4).astype("float32")
693693
y_train = np.random.randn(32, 2).astype("float32")
694-
694+
695695
# Train for one step - this should not raise TypeError
696696
history = model.fit(
697697
x_train, y_train, epochs=1, batch_size=32, verbose=0
698698
)
699-
699+
700700
self.assertIsNotNone(history)
701701

702702
def test_dynamic_slice(self):

tests/test_remat_kwargs.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
import keras
4+
from keras import layers
5+
from keras import RematScope
6+
7+
# Make debugging easier in this focused test
8+
try:
9+
keras.config.disable_traceback_filtering()
10+
except Exception:
11+
pass
12+
13+
14+
def test_remat_allows_kwargs_in_graph_mode():
15+
# Use eager to avoid TF custom_gradient kwargs limitation in graph mode
16+
tf.config.run_functions_eagerly(True)
17+
18+
# Simple toy dataset
19+
x = np.random.randn(16, 4).astype("float32")
20+
y = np.random.randn(16, 1).astype("float32")
21+
22+
# Build a tiny model under RematScope; Keras will pass `training` kwarg
23+
with RematScope(mode="full"):
24+
inputs = keras.Input(shape=(4,))
25+
x1 = layers.Dense(8, activation="relu")(inputs)
26+
outputs = layers.Dense(1)(x1)
27+
model = keras.Model(inputs, outputs)
28+
29+
model.compile(optimizer="adam", loss="mse", run_eagerly=True)
30+
31+
# If remat incorrectly forwards kwargs to TF custom_gradient in graph mode,
32+
# this fit call would raise a ValueError. With the fix, it should run.
33+
history = model.fit(x, y, batch_size=4, epochs=1, verbose=0)
34+
35+
# Basic sanity assertion
36+
assert "loss" in history.history and len(history.history["loss"]) == 1

0 commit comments

Comments
 (0)