Skip to content

Commit 98e321b

Browse files
jcitrinTorax team
authored andcommitted
Refactor to specify impurity_mode as 'fractions'.
Part of series of PRs to allow specifying the impurity type as either fractions, or n_e ratios with optional Z_eff constraint. The impurity type is handled by discriminated unions in pydantic. This PR solely implements the changes for the 'fractions' option which is the present option, and is thus a refactor This is implemented with backward compatibility for the V1 API. A `model_validator` method on PlasmaComposition handles this conversion for backward compatibility, logging warnings for deprecated behaviour. PiperOrigin-RevId: 788590896
1 parent 8c4ae9e commit 98e321b

File tree

3 files changed

+296
-68
lines changed

3 files changed

+296
-68
lines changed

torax/_src/config/plasma_composition.py

Lines changed: 131 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
# limitations under the License.
1414

1515
"""Plasma composition parameters used throughout TORAX simulations."""
16+
import copy
1617
import dataclasses
1718
import functools
18-
19+
import logging
20+
from typing import Annotated, Any, Literal
1921
import chex
2022
import jax
2123
from jax import numpy as jnp
24+
import pydantic
2225
from torax._src import array_typing
2326
from torax._src import constants
2427
from torax._src.config import runtime_validation_utils
@@ -50,6 +53,16 @@ class DynamicIonMixture:
5053
Z_override: array_typing.ScalarFloat | None = None
5154

5255

56+
@jax.tree_util.register_dataclass
57+
@dataclasses.dataclass(frozen=True)
58+
class DynamicImpurityFractions(DynamicIonMixture):
59+
"""Extends DynamicIonMixture to include (static) impurity_mode."""
60+
61+
impurity_mode: str = dataclasses.field(
62+
default='fractions', metadata={'static': True}
63+
)
64+
65+
5366
class IonMixture(torax_pydantic.BaseModelFrozen):
5467
"""Represents a mixture of ion species. The mixture can depend on time.
5568
@@ -102,11 +115,44 @@ def build_dynamic_params(
102115
)
103116

104117

118+
class ImpurityFractionsModel(IonMixture):
119+
"""Impurity content defined by fractional abundances."""
120+
121+
impurity_mode: Annotated[Literal['fractions'], torax_pydantic.JAX_STATIC] = (
122+
'fractions'
123+
)
124+
# Default impurity setting. Parent class has species without a default.
125+
species: runtime_validation_utils.IonMapping = (
126+
torax_pydantic.ValidatedDefault({'Ne': 1.0})
127+
)
128+
129+
def build_dynamic_params(self, t: chex.Numeric) -> DynamicImpurityFractions:
130+
# Call the parent IonMixture's builder
131+
dynamic_impurity_mixture = super().build_dynamic_params(t)
132+
# Use the result to construct the specialized DynamicFractions dataclass
133+
return DynamicImpurityFractions(
134+
fractions=dynamic_impurity_mixture.fractions,
135+
avg_A=dynamic_impurity_mixture.avg_A,
136+
Z_override=dynamic_impurity_mixture.Z_override,
137+
)
138+
139+
@pydantic.model_validator(mode='before')
140+
@classmethod
141+
def _conform_impurity_data(cls, data: dict[str, Any]) -> dict[str, Any]:
142+
"""Ensures backward compatibility if infered that data in legacy format."""
143+
144+
# Maps legacy inputs to the new API format.
145+
# TODO(b/434175938): Remove this once V1 API is deprecated.
146+
if 'species' not in data and 'impurity_mode' not in data:
147+
return {'species': data, 'impurity_mode': 'fractions'}
148+
return data
149+
150+
105151
@jax.tree_util.register_dataclass
106152
@dataclasses.dataclass
107153
class DynamicPlasmaComposition:
108154
main_ion: DynamicIonMixture
109-
impurity: DynamicIonMixture
155+
impurity: DynamicImpurityFractions
110156
Z_eff: array_typing.ArrayFloat
111157
Z_eff_face: array_typing.ArrayFloat
112158

@@ -125,7 +171,27 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
125171
`ION_SYMBOLS`. For mixtures the input is an IonMixture object, constructed
126172
from a dict mapping ion symbols to their fractional concentration in the
127173
mixture.
128-
impurity: Impurity ion species. Same format as main_ion.
174+
impurity: Impurity species. This is a dictionary that configures the
175+
impurity ions in the plasma. It has the following keys: `impurity_mode`:
176+
Determines how the impurity species are defined. Currently, only
177+
'fractions' is implemented. * `'fractions'`: The impurities are defined
178+
by their fractional concentrations within the total impurity population.
179+
`species`: A dictionary mapping ion symbols (e.g., 'Ne', 'W') to their
180+
respective values. The interpretation of these values depends on the
181+
`impurity_mode`. * If `impurity_mode` is `'fractions'`, the values
182+
represent the fraction of each ion species within the total impurity
183+
density. These fractions must sum to 1.0. `Z_override`: Optional. If
184+
provided, this value overrides the calculated average charge (Z) of the
185+
impurity mixture. `A_override`: Optional. If provided, this value
186+
overrides the calculated average mass (A) of the impurity mixture. Other
187+
modes (`'n_e_density_ratios'` and `'n_e_density_ratios_and_Z_eff'`) are
188+
not yet implemented. Finally, backwards compatibility is provided for
189+
legacy inputs to `'impurity'`, e.g. string or dict inputs similar to
190+
`main_ion`, such as `'Ar'` or `{'Ar': 0.6, 'Ne': 0.4}`. In these cases,
191+
`impurity_mode` is inferred to be `'fractions'`, and `Z_override` and
192+
`A_override` are set to the top-level `Z_impurity_override` and
193+
`A_impurity_override` if provided, or `None` otherwise. A pydantic before
194+
validator handles this.
129195
Z_eff: Constraint for impurity densities.
130196
Z_i_override: Optional arbitrary masses and charges which can be used to
131197
override the data for the average Z and A of each IonMixture for main_ions
@@ -135,15 +201,17 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
135201
override the data for the average Z and A of each IonMixture for main_ions
136202
or impurities. Useful for testing or testing physical sensitivities,
137203
outside the constraint of allowed impurity species.
138-
Z_impurity_override: Optional arbitrary masses and charges which can
204+
Z_impurity_override: DEPRECATED. As Z_i_override, but for the impurities.
205+
A_impurity_override: DEPRECATED. As A_i_override, but for the impurities.
139206
"""
140207

208+
impurity: Annotated[
209+
ImpurityFractionsModel,
210+
pydantic.Field(discriminator='impurity_mode'),
211+
]
141212
main_ion: runtime_validation_utils.IonMapping = (
142213
torax_pydantic.ValidatedDefault({'D': 0.5, 'T': 0.5})
143214
)
144-
impurity: runtime_validation_utils.IonMapping = (
145-
torax_pydantic.ValidatedDefault('Ne')
146-
)
147215
Z_eff: (
148216
runtime_validation_utils.TimeVaryingArrayDefinedAtRightBoundaryAndBounded
149217
) = torax_pydantic.ValidatedDefault(1.0)
@@ -152,10 +220,60 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
152220
Z_impurity_override: torax_pydantic.TimeVaryingScalar | None = None
153221
A_impurity_override: torax_pydantic.TimeVaryingScalar | None = None
154222

155-
# Generate the IonMixture objects from the input for either a mixture (dict)
156-
# or the shortcut for a single ion (string). IonMixture objects with a
157-
# single key and fraction=1.0 is used also for the single ion case to reduce
158-
# code duplication.
223+
# For main_ions, IonMixture objects are generated by either a fractional
224+
# mixture (dict[str, TimeVaryingScalar]) or the shortcut for a single constant
225+
# ion (string).
226+
# For impurities, this input is legacy but still supported. A new API is also
227+
# available with different impurity_modes, e.g. fractions or n_e_ratios.
228+
# A pydantic before validator infers the API format and handles conversions.
229+
230+
@pydantic.model_validator(mode='before')
231+
@classmethod
232+
def _conform_impurity_data(cls, data: dict[str, Any]) -> dict[str, Any]:
233+
"""Sets defaults and ensures backward compatibility for impurity inputs."""
234+
configurable_data = copy.deepcopy(data)
235+
236+
Z_impurity_override = configurable_data.get('Z_impurity_override')
237+
A_impurity_override = configurable_data.get('A_impurity_override')
238+
239+
# Set defaults for impurity if not specified. To maintain same default
240+
# behaviour as before, the top-level Z_impurity_override and
241+
# A_impurity_override are used as overrides for the impurity fractions.
242+
# TODO(b/434175938): Remove this once V1 API is deprecated and the top-level
243+
# overrides are removed, and set default directly in class attribute.
244+
if 'impurity' not in configurable_data:
245+
configurable_data['impurity'] = {
246+
'impurity_mode': 'fractions',
247+
'Z_override': Z_impurity_override,
248+
'A_override': A_impurity_override,
249+
}
250+
return configurable_data
251+
252+
impurity_data = configurable_data['impurity']
253+
254+
# New API format: impurity_mode is specified.
255+
if isinstance(impurity_data, dict) and 'impurity_mode' in impurity_data:
256+
if Z_impurity_override is not None or A_impurity_override is not None:
257+
logging.warning(
258+
'Z_impurity_override and/or A_impurity_override are set at the'
259+
' plasma_composition level, but the new impurity API is being used'
260+
' (impurity_mode is set). These top-level overrides are deprecated'
261+
' and will be ignored. Use Z_override and A_override within the'
262+
' impurity dictionary instead.'
263+
)
264+
return configurable_data
265+
266+
# Legacy format from here on.
267+
# This handles conformant V1 inputs like 'Ne' or {'Ne': 0.8, 'Ar': 0.2}.
268+
# Non-conformant inputs are caught by ImpurityFractionsModel validation.
269+
# TODO(b/434175938): Remove this once V1 API is deprecated.
270+
configurable_data['impurity'] = {
271+
'impurity_mode': 'fractions',
272+
'species': impurity_data,
273+
'Z_override': Z_impurity_override,
274+
'A_override': A_impurity_override,
275+
}
276+
return configurable_data
159277

160278
def tree_flatten(self):
161279
# Override the default tree_flatten to also save out the cached
@@ -169,7 +287,6 @@ def tree_flatten(self):
169287
self.Z_impurity_override,
170288
self.A_impurity_override,
171289
self._main_ion_mixture,
172-
self._impurity_mixture,
173290
)
174291
aux_data = ()
175292
return children, aux_data
@@ -186,7 +303,6 @@ def tree_unflatten(cls, aux_data, children):
186303
A_impurity_override=children[6],
187304
)
188305
obj._main_ion_mixture = children[7] # pylint: disable=protected-access
189-
obj._impurity_mixture = children[8] # pylint: disable=protected-access
190306
return obj
191307

192308
@functools.cached_property
@@ -199,28 +315,18 @@ def _main_ion_mixture(self) -> IonMixture:
199315
A_override=self.A_i_override,
200316
)
201317

202-
@functools.cached_property
203-
def _impurity_mixture(self) -> IonMixture:
204-
"""Returns the IonMixture object for the impurity ions."""
205-
# Use `model_construct` as no validation required.
206-
return IonMixture.model_construct(
207-
species=self.impurity,
208-
Z_override=self.Z_impurity_override,
209-
A_override=self.A_impurity_override,
210-
)
211-
212318
def get_main_ion_names(self) -> tuple[str, ...]:
213319
"""Returns the main ion symbol strings from the input."""
214320
return tuple(self._main_ion_mixture.species.keys())
215321

216322
def get_impurity_names(self) -> tuple[str, ...]:
217323
"""Returns the impurity symbol strings from the input."""
218-
return tuple(self._impurity_mixture.species.keys())
324+
return tuple(self.impurity.species.keys())
219325

220326
def build_dynamic_params(self, t: chex.Numeric) -> DynamicPlasmaComposition:
221327
return DynamicPlasmaComposition(
222328
main_ion=self._main_ion_mixture.build_dynamic_params(t),
223-
impurity=self._impurity_mixture.build_dynamic_params(t),
329+
impurity=self.impurity.build_dynamic_params(t),
224330
Z_eff=self.Z_eff.get_value(t),
225331
Z_eff_face=self.Z_eff.get_value(t, grid_type='face'),
226332
)

torax/_src/config/tests/plasma_composition_test.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torax._src.config import plasma_composition
2222
from torax._src.geometry import pydantic_model as geometry_pydantic_model
2323
from torax._src.physics import charge_states
24+
from torax._src.torax_pydantic import model_config
2425
from torax._src.torax_pydantic import torax_pydantic
2526

2627

@@ -150,6 +151,136 @@ def f(pc_model: plasma_composition.PlasmaComposition, t: chex.Numeric):
150151
chex.assert_trees_all_close(output.Z_eff, updated_zeff)
151152
self.assertEqual(jax_utils.get_number_of_compiles(f), 1)
152153

154+
@parameterized.named_parameters(
155+
dict(
156+
testcase_name='default',
157+
config={},
158+
expected_impurity_names=('Ne',),
159+
expected_Z_override=None,
160+
expected_A_override=None,
161+
),
162+
dict(
163+
testcase_name='legacy_impurity_string',
164+
config={'impurity': 'Ar'},
165+
expected_impurity_names=('Ar',),
166+
expected_Z_override=None,
167+
expected_A_override=None,
168+
),
169+
dict(
170+
testcase_name='legacy_impurity_dict_single_species',
171+
config={'impurity': {'Be': 1.0}},
172+
expected_impurity_names=('Be',),
173+
expected_Z_override=None,
174+
expected_A_override=None,
175+
),
176+
dict(
177+
testcase_name='legacy_impurity_dict_multiple_species',
178+
config={'impurity': {'Ar': 0.6, 'Ne': 0.4}},
179+
expected_impurity_names=('Ar', 'Ne'),
180+
expected_Z_override=None,
181+
expected_A_override=None,
182+
),
183+
dict(
184+
testcase_name='legacy_with_overrides',
185+
config={'impurity': 'Ar', 'Z_impurity_override': 8.0},
186+
expected_impurity_names=('Ar',),
187+
expected_Z_override=8.0,
188+
expected_A_override=None,
189+
),
190+
dict(
191+
testcase_name='new_api_explicit',
192+
config={
193+
'impurity': {
194+
'impurity_mode': 'fractions',
195+
'species': {'C': 0.5, 'N': 0.5},
196+
'Z_override': 6.5,
197+
'A_override': 13.0,
198+
},
199+
},
200+
expected_impurity_names=('C', 'N'),
201+
expected_Z_override=6.5,
202+
expected_A_override=13.0,
203+
),
204+
)
205+
def test_impurity_api(
206+
self,
207+
config,
208+
expected_impurity_names,
209+
expected_Z_override,
210+
expected_A_override,
211+
):
212+
pc = plasma_composition.PlasmaComposition(**config)
213+
self.assertEqual(pc.get_impurity_names(), expected_impurity_names)
214+
if pc.impurity.Z_override is not None:
215+
self.assertEqual(
216+
pc.impurity.Z_override.get_value(0.0), expected_Z_override
217+
)
218+
else:
219+
self.assertIsNone(expected_Z_override)
220+
if pc.impurity.A_override is not None:
221+
self.assertEqual(
222+
pc.impurity.A_override.get_value(0.0), expected_A_override
223+
)
224+
else:
225+
self.assertIsNone(expected_A_override)
226+
227+
def test_impurity_api_warning(self):
228+
with self.assertLogs(level='WARNING') as log_output:
229+
plasma_composition.PlasmaComposition(
230+
impurity={
231+
'impurity_mode': 'fractions',
232+
'species': 'Ne',
233+
'Z_override': 5.0,
234+
},
235+
Z_impurity_override=6.0,
236+
)
237+
self.assertIn(
238+
'Z_impurity_override and/or A_impurity_override are set',
239+
log_output[0][0].message,
240+
)
241+
242+
def test_update_fields_with_legacy_impurity_input(self):
243+
"""Tests updating legacy impurity format via update_fields."""
244+
config_dict = {
245+
'profile_conditions': {},
246+
'plasma_composition': {'impurity': {'Ne': 0.99, 'W': 0.01}},
247+
'numerics': {},
248+
'geometry': {'geometry_type': 'circular', 'n_rho': 4},
249+
'sources': {},
250+
'solver': {},
251+
'transport': {},
252+
'pedestal': {},
253+
}
254+
255+
config_updates = {'plasma_composition.impurity': {'Ne': 0.98, 'W': 0.02}}
256+
257+
torax_config = model_config.ToraxConfig.from_dict(config_dict)
258+
self.assertEqual(
259+
torax_config.plasma_composition.get_impurity_names(), ('Ne', 'W')
260+
)
261+
assert(torax_config.plasma_composition.impurity.species['Ne'] is not None)
262+
assert(torax_config.plasma_composition.impurity.species['W'] is not None)
263+
self.assertEqual(
264+
torax_config.plasma_composition.impurity.species['Ne'].get_value(0.0),
265+
0.99,
266+
)
267+
self.assertEqual(
268+
torax_config.plasma_composition.impurity.species['W'].get_value(0.0),
269+
0.01,
270+
)
271+
torax_config.update_fields(config_updates)
272+
self.assertEqual(
273+
torax_config.plasma_composition.get_impurity_names(), ('Ne', 'W')
274+
)
275+
self.assertEqual(
276+
torax_config.plasma_composition.impurity.species['Ne'].get_value(0.0),
277+
0.98,
278+
)
279+
self.assertEqual(
280+
torax_config.plasma_composition.impurity.species['W'].get_value(0.0),
281+
0.02,
282+
)
283+
153284

154285
class IonMixtureTest(parameterized.TestCase):
155286
"""Unit tests for constructing the IonMixture class."""

0 commit comments

Comments
 (0)