File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -232,9 +232,11 @@ def _get_geometry_base(
232232 continue
233233 if attr .name == 'Phi_b_dot' :
234234 if self .calcphibdot :
235- kwargs [attr .name ] = _Phi_b_grad (self .Phi_face , t )
235+ kwargs [attr .name ] = jnp .asarray (
236+ _Phi_b_grad (self .Phi_face , t ), dtype = jax_utils .get_dtype ()
237+ )
236238 else :
237- kwargs [attr .name ] = 0.0
239+ kwargs [attr .name ] = jnp . zeros ((), dtype = jax_utils . get_dtype ())
238240 continue
239241 provider_attr = getattr (self , attr .name )
240242 if isinstance (
Original file line number Diff line number Diff line change 3232from torax ._src import array_typing
3333from torax ._src import constants
3434from torax ._src import interpolated_param
35+ from torax ._src import jax_utils
3536from torax ._src .geometry import geometry
3637from torax ._src .geometry import geometry_loader
3738from torax ._src .geometry import geometry_provider
@@ -1385,7 +1386,7 @@ def build_standard_geometry(
13851386 # always initialize Phibdot as zero. It will be replaced once both geo_t
13861387 # and geo_t_plus_dt are provided, and set to be the same for geo_t and
13871388 # geo_t_plus_dt for each given time interval.
1388- Phi_b_dot = np .asarray ( 0.0 ),
1389+ Phi_b_dot = np .zeros ((), dtype = jax_utils . get_int_dtype () ),
13891390 _z_magnetic_axis = intermediate .z_magnetic_axis ,
13901391 diverted = intermediate .diverted ,
13911392 connection_length_target = intermediate .connection_length_target ,
You can’t perform that action at this time.
0 commit comments