Skip to content

Commit 86ab118

Browse files
jcitrinTorax team
authored andcommitted
Implement ADD and OVERWRITE merge modes in CombinedTransportModel.
Drive-by: fixed a bug where previously sub-channels (like chi_bohm_e) were leaking outside their configured radial domain in combined_transport_model. The combined transport model now supports two merge modes for combining coefficients from multiple sub-models: ADD and OVERWRITE. OVERWRITE mode replaces existing coefficients in the model's active domain and locks those regions, preventing subsequent ADD models from contributing. Disabled channels in an OVERWRITE model are transparent and do not affect the accumulated coefficients. The implementation uses an accumulator and lock system within a new _accumulate helper function. PiperOrigin-RevId: 858780479
1 parent a2919e0 commit 86ab118

18 files changed

+503
-149
lines changed

torax/_src/output_tools/output.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -659,16 +659,17 @@ def _save_core_transport(
659659

660660
# Save optional BohmGyroBohm attributes if present.
661661
core_transport = self._stacked_core_transport
662-
if (
663-
core_transport.chi_face_el_bohm is not None
664-
or core_transport.chi_face_el_gyrobohm is not None
665-
or core_transport.chi_face_ion_bohm is not None
666-
or core_transport.chi_face_ion_gyrobohm is not None
667-
):
668-
xr_dict[CHI_BOHM_E] = core_transport.chi_face_el_bohm
669-
xr_dict[CHI_GYROBOHM_E] = core_transport.chi_face_el_gyrobohm
670-
xr_dict[CHI_BOHM_I] = core_transport.chi_face_ion_bohm
671-
xr_dict[CHI_GYROBOHM_I] = core_transport.chi_face_ion_gyrobohm
662+
optional_transport_map = {
663+
CHI_BOHM_E: core_transport.chi_face_el_bohm,
664+
CHI_GYROBOHM_E: core_transport.chi_face_el_gyrobohm,
665+
CHI_BOHM_I: core_transport.chi_face_ion_bohm,
666+
CHI_GYROBOHM_I: core_transport.chi_face_ion_gyrobohm,
667+
}
668+
669+
for name, data in optional_transport_map.items():
670+
# Skip if None or an array of Nones from stack
671+
if data is not None and data.dtype != object:
672+
xr_dict[name] = data
672673

673674
xr_dict = {
674675
name: self._pack_into_data_array(

torax/_src/transport_model/bohm_gyrobohm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class RuntimeParams(transport_runtime_params_lib.RuntimeParams):
5050
class BohmGyroBohmTransportModel(transport_model_lib.TransportModel):
5151
"""Calculates various coefficients related to particle transport according to the Bohm + gyro-Bohm Model."""
5252

53-
def _call_implementation(
53+
def call_implementation(
5454
self,
5555
transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
5656
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/transport_model/combined.py

Lines changed: 130 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919

2020
import dataclasses
2121
from typing import Callable, Sequence
22-
2322
import jax
2423
import jax.numpy as jnp
24+
from torax._src import jax_utils
2525
from torax._src import state
2626
from torax._src.config import runtime_params as runtime_params_lib
2727
from torax._src.geometry import geometry
2828
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
29+
from torax._src.transport_model import enums
2930
from torax._src.transport_model import runtime_params as transport_runtime_params_lib
3031
from torax._src.transport_model import transport_model as transport_model_lib
3132

@@ -60,7 +61,7 @@ def __call__(
6061

6162
# Calculate the transport coefficients - includes contribution from pedestal
6263
# and core transport models.
63-
transport_coeffs = self._call_implementation(
64+
transport_coeffs = self.call_implementation(
6465
transport_runtime_params,
6566
runtime_params,
6667
geo,
@@ -70,7 +71,7 @@ def __call__(
7071

7172
# In contrast to the base TransportModel, we do not apply domain restriction
7273
# or output masking (enabled/disabled channels) as these are handled at the
73-
# component model level in _call_implementation here.
74+
# component model level in call_implementation here.
7475

7576
# Apply min/max clipping
7677
transport_coeffs = self._apply_clipping(
@@ -93,7 +94,7 @@ def __call__(
9394
pedestal_model_output,
9495
)
9596

96-
def _call_implementation(
97+
def call_implementation(
9798
self,
9899
transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
99100
runtime_params: runtime_params_lib.RuntimeParams,
@@ -117,95 +118,142 @@ def _call_implementation(
117118
# Required for pytype
118119
assert isinstance(transport_runtime_params, RuntimeParams)
119120

120-
def apply_and_restrict(
121-
component_model: transport_model_lib.TransportModel,
122-
component_params: transport_runtime_params_lib.RuntimeParams,
123-
restriction_fn: Callable[
124-
[
125-
transport_runtime_params_lib.RuntimeParams,
126-
geometry.Geometry,
127-
transport_model_lib.TurbulentTransport,
128-
pedestal_model_lib.PedestalModelOutput,
129-
],
130-
transport_model_lib.TurbulentTransport,
131-
],
132-
) -> transport_model_lib.TurbulentTransport:
133-
# TODO(b/434175682): Consider only computing transport coefficients for
134-
# the active domain, rather than masking them out later. This could be
135-
# significantly more efficient especially for pedestal models, as these
136-
# are only active in a small region of the domain.
137-
component_transport_coeffs = component_model._call_implementation(
138-
component_params,
139-
runtime_params,
140-
geo,
141-
core_profiles,
142-
pedestal_model_output,
143-
)
144-
component_transport_coeffs = component_model._apply_output_mask(
145-
component_params,
146-
component_transport_coeffs,
147-
)
148-
component_transport_coeffs = restriction_fn(
149-
component_params,
150-
geo,
151-
component_transport_coeffs,
152-
pedestal_model_output,
153-
)
154-
return component_transport_coeffs
155-
156-
pedestal_coeffs = [
157-
apply_and_restrict(
158-
model, params, self._apply_pedestal_domain_restriction
159-
)
160-
for model, params in zip(
161-
self.pedestal_transport_models,
162-
transport_runtime_params.pedestal_transport_model_params,
163-
)
164-
]
121+
core_coeffs = self._combine(
122+
self.transport_models,
123+
transport_runtime_params.transport_model_params,
124+
runtime_params,
125+
geo,
126+
core_profiles,
127+
pedestal_model_output,
128+
transport_model_lib.compute_core_domain_mask,
129+
)
165130

166-
core_coeffs = [
167-
apply_and_restrict(model, params, model._apply_domain_restriction)
168-
for model, params in zip(
169-
self.transport_models,
170-
transport_runtime_params.transport_model_params,
171-
)
172-
]
131+
pedestal_coeffs = self._combine(
132+
self.pedestal_transport_models,
133+
transport_runtime_params.pedestal_transport_model_params,
134+
runtime_params,
135+
geo,
136+
core_profiles,
137+
pedestal_model_output,
138+
_pedestal_domain_mask,
139+
)
173140

174141
# Combine the transport coefficients from core and pedestal models.
175-
def _combine_maybe_none_coeffs(*leaves):
176-
non_none_leaves = [leaf for leaf in leaves if leaf is not None]
177-
return sum(non_none_leaves) if non_none_leaves else None
178-
179142
combined_transport_coeffs = jax.tree.map(
180-
_combine_maybe_none_coeffs,
181-
*pedestal_coeffs,
182-
*core_coeffs,
183-
# Needed to handle the case where some coefficients are None and others
184-
# are not.
185-
is_leaf=lambda x: x is None,
143+
_add_optional, core_coeffs, pedestal_coeffs
186144
)
187145

188146
return combined_transport_coeffs
189147

190-
def _apply_pedestal_domain_restriction(
148+
def _combine(
191149
self,
192-
unused_transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
150+
models: tuple[transport_model_lib.TransportModel, ...],
151+
params_list: Sequence[transport_runtime_params_lib.RuntimeParams],
152+
runtime_params: runtime_params_lib.RuntimeParams,
193153
geo: geometry.Geometry,
194-
transport_coeffs: transport_model_lib.TurbulentTransport,
154+
core_profiles: state.CoreProfiles,
195155
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
156+
domain_mask_fn: Callable[
157+
[
158+
transport_runtime_params_lib.RuntimeParams,
159+
geometry.Geometry,
160+
pedestal_model_lib.PedestalModelOutput,
161+
],
162+
jax.Array,
163+
],
196164
) -> transport_model_lib.TurbulentTransport:
197-
del unused_transport_runtime_params
198-
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top
199-
200-
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
201-
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
202-
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
203-
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
165+
"""Calculates and combines transport coefficients from a list of models."""
204166

205-
return dataclasses.replace(
206-
transport_coeffs,
207-
chi_face_ion=chi_face_ion,
208-
chi_face_el=chi_face_el,
209-
d_face_el=d_face_el,
210-
v_face_el=v_face_el,
167+
# Initialize accumulators with zeros. Will be iteratively updated based on
168+
# model outputs and merge modes.
169+
zero_profile = jnp.zeros_like(
170+
geo.rho_face_norm, dtype=jax_utils.get_dtype()
211171
)
172+
accumulators = {}
173+
locks = {}
174+
175+
for channel, config in transport_model_lib.CHANNEL_CONFIG_STRUCT.items():
176+
accumulators[channel] = zero_profile
177+
locks[channel] = jnp.zeros_like(geo.rho_face_norm, dtype=bool)
178+
for sub in config['sub_channels']:
179+
accumulators[sub] = None
180+
181+
# TODO(b/344023668) explore batching or fori_loop for performance.
182+
for model, params in zip(models, params_list, strict=True):
183+
# 1. Calculate raw coefficients
184+
coeffs = model.call_implementation(
185+
params, runtime_params, geo, core_profiles, pedestal_model_output
186+
)
187+
188+
# 2. Zero out disabled channels. Unused subchannels returned as None.
189+
coeffs = model.zero_out_disabled_channels(params, coeffs)
190+
191+
# 3. Calculate active domain mask. Values outside this are set to 0.
192+
domain_mask = domain_mask_fn(params, geo, pedestal_model_output)
193+
194+
coeffs_dict = dataclasses.asdict(coeffs)
195+
for k in coeffs_dict:
196+
# Apply domain restriction to values.
197+
if coeffs_dict[k] is not None:
198+
coeffs_dict[k] = jnp.where(domain_mask, coeffs_dict[k], 0.0)
199+
200+
for channel, config in transport_model_lib.CHANNEL_CONFIG_STRUCT.items():
201+
disable_flag_name = config['disable_flag']
202+
is_disabled = getattr(params, disable_flag_name)
203+
204+
# A channel is active for this model if it's in the domain AND enabled.
205+
# Note that this is a boolean array over the face grid.
206+
channel_active = jnp.logical_and(
207+
domain_mask, jnp.logical_not(is_disabled)
208+
)
209+
210+
val = coeffs_dict[channel]
211+
if params.merge_mode == enums.MergeMode.OVERWRITE:
212+
# Wiping: Replace accumulator values where active.
213+
accumulators[channel] = jnp.where(
214+
channel_active, val, accumulators[channel]
215+
)
216+
# Update lock.
217+
locks[channel] = jnp.logical_or(locks[channel], channel_active)
218+
else: # ADD
219+
# Add where not locked.
220+
factor = jnp.where(locks[channel], 0.0, 1.0)
221+
accumulators[channel] = accumulators[channel] + val * factor
222+
223+
# Handle sub-channels.
224+
for sub in config['sub_channels']:
225+
sub_val = coeffs_dict[sub]
226+
if sub_val is not None:
227+
if accumulators[sub] is None:
228+
accumulators[sub] = zero_profile
229+
230+
if params.merge_mode == enums.MergeMode.OVERWRITE:
231+
accumulators[sub] = jnp.where(
232+
channel_active, sub_val, accumulators[sub]
233+
)
234+
else: # ADD
235+
# Add where not locked (using main channel lock).
236+
factor = jnp.where(locks[channel], 0.0, 1.0)
237+
accumulators[sub] = accumulators[sub] + sub_val * factor
238+
239+
return transport_model_lib.TurbulentTransport(**accumulators)
240+
241+
242+
def _add_optional(
243+
core_value: jax.Array | None, pedestal_value: jax.Array | None
244+
) -> jax.Array | None:
245+
"""Adds two values, treating None as zero. Returns None if both are None."""
246+
if core_value is None:
247+
return pedestal_value
248+
if pedestal_value is None:
249+
return core_value
250+
return core_value + pedestal_value
251+
252+
253+
def _pedestal_domain_mask(
254+
unused_params: transport_runtime_params_lib.RuntimeParams,
255+
geo: geometry.Geometry,
256+
pedestal_output: pedestal_model_lib.PedestalModelOutput,
257+
) -> jax.Array:
258+
"""Calculates the active domain mask for pedestal transport models."""
259+
return jnp.asarray(geo.rho_face_norm > pedestal_output.rho_norm_ped_top)

torax/_src/transport_model/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class RuntimeParams(transport_runtime_params_lib.RuntimeParams):
5050
class ConstantTransportModel(transport_model_lib.TransportModel):
5151
"""Calculates various coefficients related to particle transport."""
5252

53-
def _call_implementation(
53+
def call_implementation(
5454
self,
5555
transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
5656
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/transport_model/critical_gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class RuntimeParams(transport_runtime_params_lib.RuntimeParams):
4444
class CriticalGradientTransportModel(transport_model.TransportModel):
4545
"""Calculates various coefficients related to particle transport."""
4646

47-
def _call_implementation(
47+
def call_implementation(
4848
self,
4949
transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
5050
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/transport_model/qlknn_transport_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ class QLKNNTransportModel(
217217
path: str
218218
name: str
219219

220-
def _call_implementation(
220+
def call_implementation(
221221
self,
222222
transport_runtime_params: RuntimeParams,
223223
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/transport_model/qualikiz_transport_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(self):
8181
)
8282
self._runpath = os.path.join(self._qlkrun_parentdir.name, self._qlkrun_name)
8383

84-
def _call_implementation(
84+
def call_implementation(
8585
self,
8686
transport_runtime_params: transport_runtime_params_lib.RuntimeParams,
8787
runtime_params: runtime_params_lib.RuntimeParams,

torax/_src/transport_model/tests/bohm_gyrobohm_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ def test_coeff_multiplier_feature(self):
134134
chi_i_gyrobohm_multiplier=5.0,
135135
)
136136

137-
output_A = self.model._call_implementation(
137+
output_A = self.model.call_implementation(
138138
dyn_params_A.transport,
139139
dyn_params_A,
140140
self.geo,
141141
self.core_profiles,
142142
self.pedestal_outputs,
143143
)
144-
output_B = self.model._call_implementation(
144+
output_B = self.model.call_implementation(
145145
dyn_params_B.transport,
146146
dyn_params_B,
147147
self.geo,
@@ -179,14 +179,14 @@ def test_raw_bohm_and_gyrobohm_fields(self):
179179
chi_i_gyrobohm_multiplier=5.0,
180180
)
181181

182-
output_A = self.model._call_implementation(
182+
output_A = self.model.call_implementation(
183183
dyn_params_A.transport,
184184
dyn_params_A,
185185
self.geo,
186186
self.core_profiles,
187187
self.pedestal_outputs,
188188
)
189-
output_B = self.model._call_implementation(
189+
output_B = self.model.call_implementation(
190190
dyn_params_B.transport,
191191
dyn_params_B,
192192
self.geo,

0 commit comments

Comments
 (0)