Skip to content

Commit e798d98

Browse files
Nush395Torax team
authored andcommitted
Fix cache misses on second simulation step due to weakly type Phi_b_dot.
This affects cases with calcphibdot=True. PiperOrigin-RevId: 834316997
1 parent 42aac33 commit e798d98

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

torax/_src/geometry/geometry_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,11 @@ def _get_geometry_base(
232232
continue
233233
if attr.name == 'Phi_b_dot':
234234
if self.calcphibdot:
235-
kwargs[attr.name] = _Phi_b_grad(self.Phi_face, t)
235+
kwargs[attr.name] = jnp.asarray(
236+
_Phi_b_grad(self.Phi_face, t), dtype=jax_utils.get_dtype()
237+
)
236238
else:
237-
kwargs[attr.name] = 0.0
239+
kwargs[attr.name] = jnp.zeros((), dtype=jax_utils.get_dtype())
238240
continue
239241
provider_attr = getattr(self, attr.name)
240242
if isinstance(

torax/_src/geometry/standard_geometry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src import array_typing
3333
from torax._src import constants
3434
from torax._src import interpolated_param
35+
from torax._src import jax_utils
3536
from torax._src.geometry import geometry
3637
from torax._src.geometry import geometry_loader
3738
from torax._src.geometry import geometry_provider
@@ -1385,7 +1386,7 @@ def build_standard_geometry(
13851386
# always initialize Phibdot as zero. It will be replaced once both geo_t
13861387
# and geo_t_plus_dt are provided, and set to be the same for geo_t and
13871388
# geo_t_plus_dt for each given time interval.
1388-
Phi_b_dot=np.asarray(0.0),
1389+
Phi_b_dot=np.zeros((), dtype=jax_utils.get_int_dtype()),
13891390
_z_magnetic_axis=intermediate.z_magnetic_axis,
13901391
diverted=intermediate.diverted,
13911392
connection_length_target=intermediate.connection_length_target,

0 commit comments

Comments
 (0)