Skip to content

Commit fa9286c

Browse files
jcitrinTorax team
authored andcommitted
Bugfix: apply domain restrictions to optional transport sub-channels.
Drive-by: avoid explicitly writing out sub-channels elsewhere and use CHANNEL_CONFIG_STRUCT for iterations. Will help with subsequent additions to optional transport coefficients. bohmgyrobohm test updated due to the bohm and gyrobohm transport coefficient outputs being different in the pedestal region due to the now correct domain restriction. No impact on actual results (since these coefficients were not actually used directly) PiperOrigin-RevId: 860323344
1 parent ff9010c commit fa9286c

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

torax/_src/transport_model/tests/transport_model_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,53 @@ def test_preserves_none_disabled(self):
596596
)
597597
self.assertIsNone(new_coeffs_disabled.chi_face_ion_bohm)
598598

599+
def test_sub_channel_domain_restriction(self):
600+
"""Tests that sub-channels are masked by domain restriction."""
601+
config = default_configs.get_default_config_dict()
602+
config['transport'] = {
603+
'model_name': 'fixed',
604+
'rho_max': 0.8,
605+
'smoothing_width': 0.0,
606+
'chi_min': 0.0,
607+
'D_e_min': 0.0,
608+
'V_e_min': 0.0,
609+
}
610+
torax_config = model_config.ToraxConfig.from_dict(config)
611+
runtime_params = build_runtime_params.RuntimeParamsProvider.from_config(
612+
torax_config
613+
)(t=0.0)
614+
geo = torax_config.geometry.build_provider(t=0.0)
615+
# We need a pedestal model even if unused by the fixed transport
616+
pedestal_model = torax_config.pedestal.build_pedestal_model()
617+
# Mock core profiles (not used by FixedTransportModel but needed for API)
618+
core_profiles = initialization.initial_core_profiles(
619+
runtime_params,
620+
geo,
621+
torax_config.sources.build_models(),
622+
torax_config.neoclassical.build_models(),
623+
)
624+
pedestal_outputs = pedestal_model(runtime_params, geo, core_profiles)
625+
626+
transport_model = torax_config.transport.build_transport_model()
627+
coeffs = transport_model(
628+
runtime_params, geo, core_profiles, pedestal_outputs
629+
)
630+
631+
# Find index where rho > 0.8
632+
cutoff_idx = np.searchsorted(geo.rho_face_norm, 0.8, side='right')
633+
634+
# Verify main channel is zeroed
635+
np.testing.assert_allclose(coeffs.chi_face_ion[cutoff_idx:], 0.0)
636+
637+
# Verify sub-channels are also zeroed
638+
# FixedTransportModel sets chi_face_ion_bohm = chi_face_ion * 0.3
639+
# If not masked, it would be non-zero because FixedTransportModel computes
640+
# it everywhere
641+
self.assertIsNotNone(coeffs.chi_face_ion_bohm)
642+
np.testing.assert_allclose(coeffs.chi_face_ion_bohm[cutoff_idx:], 0.0)
643+
self.assertIsNotNone(coeffs.chi_face_ion_gyrobohm)
644+
np.testing.assert_allclose(coeffs.chi_face_ion_gyrobohm[cutoff_idx:], 0.0)
645+
599646

600647
if __name__ == '__main__':
601648
absltest.main()

torax/_src/transport_model/transport_model.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,21 @@ def _apply_domain_restriction(
193193
transport_runtime_params, geo, pedestal_model_output
194194
)
195195

196-
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
197-
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
198-
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
199-
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
196+
coeffs_dict = dataclasses.asdict(transport_coeffs)
197+
to_replace = {}
200198

201-
return dataclasses.replace(
202-
transport_coeffs,
203-
chi_face_ion=chi_face_ion,
204-
chi_face_el=chi_face_el,
205-
d_face_el=d_face_el,
206-
v_face_el=v_face_el,
207-
)
199+
for channel_name, config in CHANNEL_CONFIG_STRUCT.items():
200+
# Mask main channel
201+
val = coeffs_dict[channel_name]
202+
to_replace[channel_name] = jnp.where(active_mask, val, 0.0)
203+
204+
# Mask sub-channels
205+
for sub_channel in config['sub_channels']:
206+
sub_val = coeffs_dict[sub_channel]
207+
if sub_val is not None:
208+
to_replace[sub_channel] = jnp.where(active_mask, sub_val, 0.0)
209+
210+
return dataclasses.replace(transport_coeffs, **to_replace)
208211

209212
def _apply_clipping(
210213
self,
872 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)