Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions torax/_src/fvm/cell_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
[https://www.ctcms.nist.gov/fipy/]
"""
import dataclasses
import functools

import chex
import jax
Expand Down Expand Up @@ -130,7 +131,7 @@ class CellVariable:
# Can't make the above default values be jax zeros because that would be a
# call to jax before absl.app.run

@property
@functools.cached_property
def cell_centers(self) -> jt.Float[chex.Array, 'cell']:
"""Locations of the cell centers."""
return (self.face_centers[..., 1:] + self.face_centers[..., :-1]) / 2.0
Expand Down Expand Up @@ -280,7 +281,7 @@ def left_face_value(self) -> jt.Float[chex.Array, '']:
value = self.value[..., 0:1]
return value

@property
@functools.cached_property
def right_face_value(self) -> jt.Float[chex.Array, '']:
"""Calculates the value of the rightmost face."""
if self.right_face_constraint is not None:
Expand Down
38 changes: 0 additions & 38 deletions torax/_src/fvm/tests/cell_variable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax import numpy as jnp
import numpy as np
from torax._src.fvm import cell_variable
Expand Down Expand Up @@ -442,43 +441,6 @@ def test_compute_face_grad_compared_to_forward_diff(self):
accurate_method_error, forward_difference_error
)

def test_jax_tracer_leak_in_cond(self):
"""Tests that accessing properties inside jax.cond does not leak tracers."""
# This test verifies that we don't have side effects (like cached_property
# updates) that leak tracers from inside a jax.cond to the outer scope.

n = 10
value = jnp.zeros((n,))
face_centers = jnp.linspace(0, 1, n + 1)

# Create the object OUTSIDE the cond
var = cell_variable.CellVariable(value=value, face_centers=face_centers)

def true_branch(_):
# Access properties INSIDE the cond
# If these are cached_property and write to var.__dict__, it might leak
_ = var.cell_centers
_ = var.right_face_value
return var.cell_centers

def false_branch(_):
return jnp.zeros((n,))

jax.lax.cond(jnp.array(True), true_branch, false_branch, None)

# Access the property AFTER the cond to ensure no leaked tracer is stored.
centers = var.cell_centers
right_face = var.right_face_value

# Force use of the values in a way that triggers UnexpectedTracerError for
# leaked tracers
val_check = centers + 1.0
right_check = right_face + 1.0

# Casting to numpy also forces evaluation
np.array(val_check)
np.array(right_check)


if __name__ == '__main__':
absltest.main()
13 changes: 11 additions & 2 deletions torax/_src/transport_model/tglf_based_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def _prepare_tglf_inputs(
# Normalized toroidal ExB velocity Doppler shift gradient.
# Calculated on the face grid.
# https://gacode.io/tglf/tglf_list.html#vexb-shear
def _get_v_ExB_shear():
def _get_v_ExB_shear(
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
poloidal_velocity_multiplier: array_typing.FloatScalar,
):
v_ExB, _, _ = rotation.calculate_rotation(
T_i=core_profiles.T_i,
psi=core_profiles.psi,
Expand Down Expand Up @@ -345,7 +349,12 @@ def _get_v_ExB_shear():
v_ExB_shear = jax.lax.cond(
transport.use_rotation,
_get_v_ExB_shear,
lambda: jnp.zeros_like(core_profiles.q_face),
lambda core_profiles, geo, poloidal_velocity_multiplier: jnp.zeros_like(
core_profiles.q_face
),
core_profiles,
geo,
poloidal_velocity_multiplier,
)

return TGLFInputs(
Expand Down
Loading