1919
2020import dataclasses
2121from typing import Callable , Sequence
22-
2322import jax
2423import jax .numpy as jnp
24+ from torax ._src import jax_utils
2525from torax ._src import state
2626from torax ._src .config import runtime_params as runtime_params_lib
2727from torax ._src .geometry import geometry
2828from torax ._src .pedestal_model import pedestal_model as pedestal_model_lib
29+ from torax ._src .transport_model import enums
2930from torax ._src .transport_model import runtime_params as transport_runtime_params_lib
3031from torax ._src .transport_model import transport_model as transport_model_lib
3132
@@ -60,7 +61,7 @@ def __call__(
6061
6162 # Calculate the transport coefficients - includes contribution from pedestal
6263 # and core transport models.
63- transport_coeffs = self ._call_implementation (
64+ transport_coeffs = self .call_implementation (
6465 transport_runtime_params ,
6566 runtime_params ,
6667 geo ,
@@ -70,7 +71,7 @@ def __call__(
7071
7172 # In contrast to the base TransportModel, we do not apply domain restriction
7273 # or output masking (enabled/disabled channels) as these are handled at the
73- # component model level in _call_implementation here.
74+ # component model level in call_implementation here.
7475
7576 # Apply min/max clipping
7677 transport_coeffs = self ._apply_clipping (
@@ -93,7 +94,7 @@ def __call__(
9394 pedestal_model_output ,
9495 )
9596
96- def _call_implementation (
97+ def call_implementation (
9798 self ,
9899 transport_runtime_params : transport_runtime_params_lib .RuntimeParams ,
99100 runtime_params : runtime_params_lib .RuntimeParams ,
@@ -117,95 +118,142 @@ def _call_implementation(
117118 # Required for pytype
118119 assert isinstance (transport_runtime_params , RuntimeParams )
119120
120- def apply_and_restrict (
121- component_model : transport_model_lib .TransportModel ,
122- component_params : transport_runtime_params_lib .RuntimeParams ,
123- restriction_fn : Callable [
124- [
125- transport_runtime_params_lib .RuntimeParams ,
126- geometry .Geometry ,
127- transport_model_lib .TurbulentTransport ,
128- pedestal_model_lib .PedestalModelOutput ,
129- ],
130- transport_model_lib .TurbulentTransport ,
131- ],
132- ) -> transport_model_lib .TurbulentTransport :
133- # TODO(b/434175682): Consider only computing transport coefficients for
134- # the active domain, rather than masking them out later. This could be
135- # significantly more efficient especially for pedestal models, as these
136- # are only active in a small region of the domain.
137- component_transport_coeffs = component_model ._call_implementation (
138- component_params ,
139- runtime_params ,
140- geo ,
141- core_profiles ,
142- pedestal_model_output ,
143- )
144- component_transport_coeffs = component_model ._apply_output_mask (
145- component_params ,
146- component_transport_coeffs ,
147- )
148- component_transport_coeffs = restriction_fn (
149- component_params ,
150- geo ,
151- component_transport_coeffs ,
152- pedestal_model_output ,
153- )
154- return component_transport_coeffs
155-
156- pedestal_coeffs = [
157- apply_and_restrict (
158- model , params , self ._apply_pedestal_domain_restriction
159- )
160- for model , params in zip (
161- self .pedestal_transport_models ,
162- transport_runtime_params .pedestal_transport_model_params ,
163- )
164- ]
121+ core_coeffs = self ._combine (
122+ self .transport_models ,
123+ transport_runtime_params .transport_model_params ,
124+ runtime_params ,
125+ geo ,
126+ core_profiles ,
127+ pedestal_model_output ,
128+ transport_model_lib .compute_core_domain_mask ,
129+ )
165130
166- core_coeffs = [
167- apply_and_restrict (model , params , model ._apply_domain_restriction )
168- for model , params in zip (
169- self .transport_models ,
170- transport_runtime_params .transport_model_params ,
171- )
172- ]
131+ pedestal_coeffs = self ._combine (
132+ self .pedestal_transport_models ,
133+ transport_runtime_params .pedestal_transport_model_params ,
134+ runtime_params ,
135+ geo ,
136+ core_profiles ,
137+ pedestal_model_output ,
138+ _pedestal_domain_mask ,
139+ )
173140
174141 # Combine the transport coefficients from core and pedestal models.
175- def _combine_maybe_none_coeffs (* leaves ):
176- non_none_leaves = [leaf for leaf in leaves if leaf is not None ]
177- return sum (non_none_leaves ) if non_none_leaves else None
178-
179142 combined_transport_coeffs = jax .tree .map (
180- _combine_maybe_none_coeffs ,
181- * pedestal_coeffs ,
182- * core_coeffs ,
183- # Needed to handle the case where some coefficients are None and others
184- # are not.
185- is_leaf = lambda x : x is None ,
143+ _add_optional , core_coeffs , pedestal_coeffs
186144 )
187145
188146 return combined_transport_coeffs
189147
190- def _apply_pedestal_domain_restriction (
148+ def _combine (
191149 self ,
192- unused_transport_runtime_params : transport_runtime_params_lib .RuntimeParams ,
150+ models : tuple [transport_model_lib .TransportModel , ...],
151+ params_list : Sequence [transport_runtime_params_lib .RuntimeParams ],
152+ runtime_params : runtime_params_lib .RuntimeParams ,
193153 geo : geometry .Geometry ,
194- transport_coeffs : transport_model_lib . TurbulentTransport ,
154+ core_profiles : state . CoreProfiles ,
195155 pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
156+ domain_mask_fn : Callable [
157+ [
158+ transport_runtime_params_lib .RuntimeParams ,
159+ geometry .Geometry ,
160+ pedestal_model_lib .PedestalModelOutput ,
161+ ],
162+ jax .Array ,
163+ ],
196164 ) -> transport_model_lib .TurbulentTransport :
197- del unused_transport_runtime_params
198- active_mask = geo .rho_face_norm > pedestal_model_output .rho_norm_ped_top
199-
200- chi_face_ion = jnp .where (active_mask , transport_coeffs .chi_face_ion , 0.0 )
201- chi_face_el = jnp .where (active_mask , transport_coeffs .chi_face_el , 0.0 )
202- d_face_el = jnp .where (active_mask , transport_coeffs .d_face_el , 0.0 )
203- v_face_el = jnp .where (active_mask , transport_coeffs .v_face_el , 0.0 )
165+ """Calculates and combines transport coefficients from a list of models."""
204166
205- return dataclasses .replace (
206- transport_coeffs ,
207- chi_face_ion = chi_face_ion ,
208- chi_face_el = chi_face_el ,
209- d_face_el = d_face_el ,
210- v_face_el = v_face_el ,
167+ # Initialize accumulators with zeros. Will be iteratively updated based on
168+ # model outputs and merge modes.
169+ zero_profile = jnp .zeros_like (
170+ geo .rho_face_norm , dtype = jax_utils .get_dtype ()
211171 )
172+ accumulators = {}
173+ locks = {}
174+
175+ for channel , config in transport_model_lib .CHANNEL_CONFIG_STRUCT .items ():
176+ accumulators [channel ] = zero_profile
177+ locks [channel ] = jnp .zeros_like (geo .rho_face_norm , dtype = bool )
178+ for sub in config ['sub_channels' ]:
179+ accumulators [sub ] = None
180+
181+ # TODO(b/344023668) explore batching or fori_loop for performance.
182+ for model , params in zip (models , params_list , strict = True ):
183+ # 1. Calculate raw coefficients
184+ coeffs = model .call_implementation (
185+ params , runtime_params , geo , core_profiles , pedestal_model_output
186+ )
187+
188+ # 2. Zero out disabled channels. Unused subchannels returned as None.
189+ coeffs = model .zero_out_disabled_channels (params , coeffs )
190+
191+ # 3. Calculate active domain mask. Values outside this are set to 0.
192+ domain_mask = domain_mask_fn (params , geo , pedestal_model_output )
193+
194+ coeffs_dict = dataclasses .asdict (coeffs )
195+ for k in coeffs_dict :
196+ # Apply domain restriction to values.
197+ if coeffs_dict [k ] is not None :
198+ coeffs_dict [k ] = jnp .where (domain_mask , coeffs_dict [k ], 0.0 )
199+
200+ for channel , config in transport_model_lib .CHANNEL_CONFIG_STRUCT .items ():
201+ disable_flag_name = config ['disable_flag' ]
202+ is_disabled = getattr (params , disable_flag_name )
203+
204+ # A channel is active for this model if it's in the domain AND enabled.
205+ # Note that this is a boolean array over the face grid.
206+ channel_active = jnp .logical_and (
207+ domain_mask , jnp .logical_not (is_disabled )
208+ )
209+
210+ val = coeffs_dict [channel ]
211+ if params .merge_mode == enums .MergeMode .OVERWRITE :
212+ # Wiping: Replace accumulator values where active.
213+ accumulators [channel ] = jnp .where (
214+ channel_active , val , accumulators [channel ]
215+ )
216+ # Update lock.
217+ locks [channel ] = jnp .logical_or (locks [channel ], channel_active )
218+ else : # ADD
219+ # Add where not locked.
220+ factor = jnp .where (locks [channel ], 0.0 , 1.0 )
221+ accumulators [channel ] = accumulators [channel ] + val * factor
222+
223+ # Handle sub-channels.
224+ for sub in config ['sub_channels' ]:
225+ sub_val = coeffs_dict [sub ]
226+ if sub_val is not None :
227+ if accumulators [sub ] is None :
228+ accumulators [sub ] = zero_profile
229+
230+ if params .merge_mode == enums .MergeMode .OVERWRITE :
231+ accumulators [sub ] = jnp .where (
232+ channel_active , sub_val , accumulators [sub ]
233+ )
234+ else : # ADD
235+ # Add where not locked (using main channel lock).
236+ factor = jnp .where (locks [channel ], 0.0 , 1.0 )
237+ accumulators [sub ] = accumulators [sub ] + sub_val * factor
238+
239+ return transport_model_lib .TurbulentTransport (** accumulators )
240+
241+
242+ def _add_optional (
243+ core_value : jax .Array | None , pedestal_value : jax .Array | None
244+ ) -> jax .Array | None :
245+ """Adds two values, treating None as zero. Returns None if both are None."""
246+ if core_value is None :
247+ return pedestal_value
248+ if pedestal_value is None :
249+ return core_value
250+ return core_value + pedestal_value
251+
252+
253+ def _pedestal_domain_mask (
254+ unused_params : transport_runtime_params_lib .RuntimeParams ,
255+ geo : geometry .Geometry ,
256+ pedestal_output : pedestal_model_lib .PedestalModelOutput ,
257+ ) -> jax .Array :
258+ """Calculates the active domain mask for pedestal transport models."""
259+ return jnp .asarray (geo .rho_face_norm > pedestal_output .rho_norm_ped_top )
0 commit comments