Skip to content

Commit b17dc86

Browse files
Verify GPU memory consistency for Huber loss (delta=0.5)
1 parent 6d06085 commit b17dc86

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

keras/src/losses/losses.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,9 @@ def huber(y_true, y_pred, delta=1.0):
19721972
delta = ops.convert_to_tensor(delta, dtype=y_pred.dtype)
19731973
error = ops.subtract(y_pred, y_true)
19741974
abs_error = ops.abs(error)
1975-
half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype)
1975+
half = ops.cast(ops.convert_to_tensor(0.5), dtype=abs_error.dtype)
1976+
delta = ops.cast(delta, dtype=abs_error.dtype)
1977+
19761978
return ops.mean(
19771979
ops.where(
19781980
abs_error <= delta,

keras/src/losses/losses_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,35 @@ def test_dtype_arg(self):
708708
loss = h_obj(self.y_true, self.y_pred)
709709
self.assertDType(loss, "bfloat16")
710710

711+
def test_huber_memory_usage_debug_05(self):
712+
import tensorflow as tf, numpy as np, keras
713+
714+
print("\n[Huber GPU Memory Debug: delta=0.5]")
715+
gpus = tf.config.experimental.list_physical_devices('GPU')
716+
if not gpus:
717+
print("No GPU found. Skipping test.")
718+
return
719+
try:
720+
for gpu in gpus:
721+
tf.config.experimental.set_memory_growth(gpu, True)
722+
except RuntimeError as e:
723+
print(f"[Info] GPU memory growth already set or GPU initialized: {e}")
724+
725+
x = np.random.rand(1000, 1)
726+
y = ((3 * x) + 2) + np.random.randn(1000, 1)
727+
huber_loss = keras.losses.Huber(delta=0.5)
728+
loss = huber_loss(y, x)
729+
print(f"Huber loss: {loss.numpy():.6f}")
730+
731+
memory = sum(tf.config.experimental.get_memory_info(f'GPU:{i}')['current'] for i in range(len(gpus)))
732+
print(f"GPU memory usage: {memory} bytes")
733+
734+
# sanity check for stable GPU usage (adjust threshold as needed)
735+
assert memory > 0, (
736+
f"GPU memory not allocated or usage is zero. "
737+
f"Current usage: {memory} bytes"
738+
)
739+
711740

712741
class LogCoshTest(testing.TestCase):
713742
def setup(self):

0 commit comments

Comments
 (0)