Skip to content

Commit 0782562

Browse files
hamelphiTorax team
authored andcommitted
Replace functools.cached_property with property in CellVariable.
Using functools.cached_property can cause issues with JAX transformations like jax.lax.cond, as it modifies the object's internal state upon first access, potentially leading to UnexpectedTracerError. Switching to standard properties avoids these side effects within JAX contexts. This unfortunately means that we won't be able to benefit from using `cached_properties` in CellVariable. The operations to get these properties are simple, so the impact to performance should be negligible. I added a unittest that exposed the tracer leak to prevent regression in the future. PiperOrigin-RevId: 861205793
1 parent ed6b853 commit 0782562

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

torax/_src/fvm/cell_variable.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
[https://www.ctcms.nist.gov/fipy/]
2020
"""
2121
import dataclasses
22-
import functools
2322

2423
import chex
2524
import jax
@@ -131,7 +130,7 @@ class CellVariable:
131130
# Can't make the above default values be jax zeros because that would be a
132131
# call to jax before absl.app.run
133132

134-
@functools.cached_property
133+
@property
135134
def cell_centers(self) -> jt.Float[chex.Array, 'cell']:
136135
"""Locations of the cell centers."""
137136
return (self.face_centers[..., 1:] + self.face_centers[..., :-1]) / 2.0
@@ -281,7 +280,7 @@ def left_face_value(self) -> jt.Float[chex.Array, '']:
281280
value = self.value[..., 0:1]
282281
return value
283282

284-
@functools.cached_property
283+
@property
285284
def right_face_value(self) -> jt.Float[chex.Array, '']:
286285
"""Calculates the value of the rightmost face."""
287286
if self.right_face_constraint is not None:

torax/_src/fvm/tests/cell_variable_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from absl.testing import absltest
1515
from absl.testing import parameterized
1616
import chex
17+
import jax
1718
from jax import numpy as jnp
1819
import numpy as np
1920
from torax._src.fvm import cell_variable
@@ -441,6 +442,43 @@ def test_compute_face_grad_compared_to_forward_diff(self):
441442
accurate_method_error, forward_difference_error
442443
)
443444

445+
def test_jax_tracer_leak_in_cond(self):
446+
"""Tests that accessing properties inside jax.cond does not leak tracers."""
447+
# This test verifies that we don't have side effects (like cached_property
448+
# updates) that leak tracers from inside a jax.cond to the outer scope.
449+
450+
n = 10
451+
value = jnp.zeros((n,))
452+
face_centers = jnp.linspace(0, 1, n + 1)
453+
454+
# Create the object OUTSIDE the cond
455+
var = cell_variable.CellVariable(value=value, face_centers=face_centers)
456+
457+
def true_branch(_):
458+
# Access properties INSIDE the cond
459+
# If these are cached_property and write to var.__dict__, it might leak
460+
_ = var.cell_centers
461+
_ = var.right_face_value
462+
return var.cell_centers
463+
464+
def false_branch(_):
465+
return jnp.zeros((n,))
466+
467+
jax.lax.cond(jnp.array(True), true_branch, false_branch, None)
468+
469+
# Access the property AFTER the cond to ensure no leaked tracer is stored.
470+
centers = var.cell_centers
471+
right_face = var.right_face_value
472+
473+
# Force use of the values in a way that triggers UnexpectedTracerError for
474+
# leaked tracers
475+
val_check = centers + 1.0
476+
right_check = right_face + 1.0
477+
478+
# Casting to numpy also forces evaluation
479+
np.array(val_check)
480+
np.array(right_check)
481+
444482

445483
if __name__ == '__main__':
446484
absltest.main()

0 commit comments

Comments
 (0)