1313# limitations under the License.
1414
1515"""Plasma composition parameters used throughout TORAX simulations."""
16+ import copy
1617import dataclasses
1718import functools
18-
19+ import logging
20+ from typing import Annotated , Any , Literal
1921import chex
2022import jax
2123from jax import numpy as jnp
24+ import pydantic
2225from torax ._src import array_typing
2326from torax ._src import constants
2427from torax ._src .config import runtime_validation_utils
@@ -50,6 +53,16 @@ class DynamicIonMixture:
5053 Z_override : array_typing .ScalarFloat | None = None
5154
5255
56+ @jax .tree_util .register_dataclass
57+ @dataclasses .dataclass (frozen = True )
58+ class DynamicImpurityFractions (DynamicIonMixture ):
59+ """Extends DynamicIonMixture to include (static) impurity_mode."""
60+
61+ impurity_mode : str = dataclasses .field (
62+ default = 'fractions' , metadata = {'static' : True }
63+ )
64+
65+
5366class IonMixture (torax_pydantic .BaseModelFrozen ):
5467 """Represents a mixture of ion species. The mixture can depend on time.
5568
@@ -102,11 +115,44 @@ def build_dynamic_params(
102115 )
103116
104117
118+ class ImpurityFractionsModel (IonMixture ):
119+ """Impurity content defined by fractional abundances."""
120+
121+ impurity_mode : Annotated [Literal ['fractions' ], torax_pydantic .JAX_STATIC ] = (
122+ 'fractions'
123+ )
124+ # Default impurity setting. Parent class has species without a default.
125+ species : runtime_validation_utils .IonMapping = (
126+ torax_pydantic .ValidatedDefault ({'Ne' : 1.0 })
127+ )
128+
129+ def build_dynamic_params (self , t : chex .Numeric ) -> DynamicImpurityFractions :
130+ # Call the parent IonMixture's builder
131+ dynamic_impurity_mixture = super ().build_dynamic_params (t )
132+ # Use the result to construct the specialized DynamicFractions dataclass
133+ return DynamicImpurityFractions (
134+ fractions = dynamic_impurity_mixture .fractions ,
135+ avg_A = dynamic_impurity_mixture .avg_A ,
136+ Z_override = dynamic_impurity_mixture .Z_override ,
137+ )
138+
139+ @pydantic .model_validator (mode = 'before' )
140+ @classmethod
141+ def _conform_impurity_data (cls , data : dict [str , Any ]) -> dict [str , Any ]:
142+ """Ensures backward compatibility if infered that data in legacy format."""
143+
144+ # Maps legacy inputs to the new API format.
145+ # TODO(b/434175938): Remove this once V1 API is deprecated.
146+ if 'species' not in data and 'impurity_mode' not in data :
147+ return {'species' : data , 'impurity_mode' : 'fractions' }
148+ return data
149+
150+
105151@jax .tree_util .register_dataclass
106152@dataclasses .dataclass
107153class DynamicPlasmaComposition :
108154 main_ion : DynamicIonMixture
109- impurity : DynamicIonMixture
155+ impurity : DynamicImpurityFractions
110156 Z_eff : array_typing .ArrayFloat
111157 Z_eff_face : array_typing .ArrayFloat
112158
@@ -125,7 +171,27 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
125171 `ION_SYMBOLS`. For mixtures the input is an IonMixture object, constructed
126172 from a dict mapping ion symbols to their fractional concentration in the
127173 mixture.
128- impurity: Impurity ion species. Same format as main_ion.
174+ impurity: Impurity species. This is a dictionary that configures the
175+ impurity ions in the plasma. It has the following keys: `impurity_mode`:
176+ Determines how the impurity species are defined. Currently, only
177+ 'fractions' is implemented. * `'fractions'`: The impurities are defined
178+ by their fractional concentrations within the total impurity population.
179+ `species`: A dictionary mapping ion symbols (e.g., 'Ne', 'W') to their
180+ respective values. The interpretation of these values depends on the
181+ `impurity_mode`. * If `impurity_mode` is `'fractions'`, the values
182+ represent the fraction of each ion species within the total impurity
183+ density. These fractions must sum to 1.0. `Z_override`: Optional. If
184+ provided, this value overrides the calculated average charge (Z) of the
185+ impurity mixture. `A_override`: Optional. If provided, this value
186+ overrides the calculated average mass (A) of the impurity mixture. Other
187+ modes (`'n_e_density_ratios'` and `'n_e_density_ratios_and_Z_eff'`) are
188+ not yet implemented. Finally, backwards compatibility is provided for
189+ legacy inputs to `'impurity'`, e.g. string or dict inputs similar to
190+ `main_ion`, such as `'Ar'` or `{'Ar': 0.6, 'Ne': 0.4}`. In these cases,
191+ `impurity_mode` is inferred to be `'fractions'`, and `Z_override` and
192+ `A_override` are set to the top-level `Z_impurity_override` and
193+ `A_impurity_override` if provided, or `None` otherwise. A pydantic before
194+ validator handles this.
129195 Z_eff: Constraint for impurity densities.
130196 Z_i_override: Optional arbitrary masses and charges which can be used to
131197 override the data for the average Z and A of each IonMixture for main_ions
@@ -135,15 +201,17 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
135201 override the data for the average Z and A of each IonMixture for main_ions
136202 or impurities. Useful for testing or testing physical sensitivities,
137203 outside the constraint of allowed impurity species.
138- Z_impurity_override: Optional arbitrary masses and charges which can
204+ Z_impurity_override: DEPRECATED. As Z_i_override, but for the impurities.
205+ A_impurity_override: DEPRECATED. As A_i_override, but for the impurities.
139206 """
140207
208+ impurity : Annotated [
209+ ImpurityFractionsModel ,
210+ pydantic .Field (discriminator = 'impurity_mode' ),
211+ ]
141212 main_ion : runtime_validation_utils .IonMapping = (
142213 torax_pydantic .ValidatedDefault ({'D' : 0.5 , 'T' : 0.5 })
143214 )
144- impurity : runtime_validation_utils .IonMapping = (
145- torax_pydantic .ValidatedDefault ('Ne' )
146- )
147215 Z_eff : (
148216 runtime_validation_utils .TimeVaryingArrayDefinedAtRightBoundaryAndBounded
149217 ) = torax_pydantic .ValidatedDefault (1.0 )
@@ -152,10 +220,60 @@ class PlasmaComposition(torax_pydantic.BaseModelFrozen):
152220 Z_impurity_override : torax_pydantic .TimeVaryingScalar | None = None
153221 A_impurity_override : torax_pydantic .TimeVaryingScalar | None = None
154222
155- # Generate the IonMixture objects from the input for either a mixture (dict)
156- # or the shortcut for a single ion (string). IonMixture objects with a
157- # single key and fraction=1.0 is used also for the single ion case to reduce
158- # code duplication.
223+ # For main_ions, IonMixture objects are generated by either a fractional
224+ # mixture (dict[str, TimeVaryingScalar]) or the shortcut for a single constant
225+ # ion (string).
226+ # For impurities, this input is legacy but still supported. A new API is also
227+ # available with different impurity_modes, e.g. fractions or n_e_ratios.
228+ # A pydantic before validator infers the API format and handles conversions.
229+
230+ @pydantic .model_validator (mode = 'before' )
231+ @classmethod
232+ def _conform_impurity_data (cls , data : dict [str , Any ]) -> dict [str , Any ]:
233+ """Sets defaults and ensures backward compatibility for impurity inputs."""
234+ configurable_data = copy .deepcopy (data )
235+
236+ Z_impurity_override = configurable_data .get ('Z_impurity_override' )
237+ A_impurity_override = configurable_data .get ('A_impurity_override' )
238+
239+ # Set defaults for impurity if not specified. To maintain same default
240+ # behaviour as before, the top-level Z_impurity_override and
241+ # A_impurity_override are used as overrides for the impurity fractions.
242+ # TODO(b/434175938): Remove this once V1 API is deprecated and the top-level
243+ # overrides are removed, and set default directly in class attribute.
244+ if 'impurity' not in configurable_data :
245+ configurable_data ['impurity' ] = {
246+ 'impurity_mode' : 'fractions' ,
247+ 'Z_override' : Z_impurity_override ,
248+ 'A_override' : A_impurity_override ,
249+ }
250+ return configurable_data
251+
252+ impurity_data = configurable_data ['impurity' ]
253+
254+ # New API format: impurity_mode is specified.
255+ if isinstance (impurity_data , dict ) and 'impurity_mode' in impurity_data :
256+ if Z_impurity_override is not None or A_impurity_override is not None :
257+ logging .warning (
258+ 'Z_impurity_override and/or A_impurity_override are set at the'
259+ ' plasma_composition level, but the new impurity API is being used'
260+ ' (impurity_mode is set). These top-level overrides are deprecated'
261+ ' and will be ignored. Use Z_override and A_override within the'
262+ ' impurity dictionary instead.'
263+ )
264+ return configurable_data
265+
266+ # Legacy format from here on.
267+ # This handles conformant V1 inputs like 'Ne' or {'Ne': 0.8, 'Ar': 0.2}.
268+ # Non-conformant inputs are caught by ImpurityFractionsModel validation.
269+ # TODO(b/434175938): Remove this once V1 API is deprecated.
270+ configurable_data ['impurity' ] = {
271+ 'impurity_mode' : 'fractions' ,
272+ 'species' : impurity_data ,
273+ 'Z_override' : Z_impurity_override ,
274+ 'A_override' : A_impurity_override ,
275+ }
276+ return configurable_data
159277
160278 def tree_flatten (self ):
161279 # Override the default tree_flatten to also save out the cached
@@ -169,7 +287,6 @@ def tree_flatten(self):
169287 self .Z_impurity_override ,
170288 self .A_impurity_override ,
171289 self ._main_ion_mixture ,
172- self ._impurity_mixture ,
173290 )
174291 aux_data = ()
175292 return children , aux_data
@@ -186,7 +303,6 @@ def tree_unflatten(cls, aux_data, children):
186303 A_impurity_override = children [6 ],
187304 )
188305 obj ._main_ion_mixture = children [7 ] # pylint: disable=protected-access
189- obj ._impurity_mixture = children [8 ] # pylint: disable=protected-access
190306 return obj
191307
192308 @functools .cached_property
@@ -199,28 +315,18 @@ def _main_ion_mixture(self) -> IonMixture:
199315 A_override = self .A_i_override ,
200316 )
201317
202- @functools .cached_property
203- def _impurity_mixture (self ) -> IonMixture :
204- """Returns the IonMixture object for the impurity ions."""
205- # Use `model_construct` as no validation required.
206- return IonMixture .model_construct (
207- species = self .impurity ,
208- Z_override = self .Z_impurity_override ,
209- A_override = self .A_impurity_override ,
210- )
211-
212318 def get_main_ion_names (self ) -> tuple [str , ...]:
213319 """Returns the main ion symbol strings from the input."""
214320 return tuple (self ._main_ion_mixture .species .keys ())
215321
216322 def get_impurity_names (self ) -> tuple [str , ...]:
217323 """Returns the impurity symbol strings from the input."""
218- return tuple (self ._impurity_mixture .species .keys ())
324+ return tuple (self .impurity .species .keys ())
219325
220326 def build_dynamic_params (self , t : chex .Numeric ) -> DynamicPlasmaComposition :
221327 return DynamicPlasmaComposition (
222328 main_ion = self ._main_ion_mixture .build_dynamic_params (t ),
223- impurity = self ._impurity_mixture .build_dynamic_params (t ),
329+ impurity = self .impurity .build_dynamic_params (t ),
224330 Z_eff = self .Z_eff .get_value (t ),
225331 Z_eff_face = self .Z_eff .get_value (t , grid_type = 'face' ),
226332 )
0 commit comments