Skip to content

Commit c146f05

Browse files
jcitrinTorax team
authored andcommitted
Add non-inductive current outputs to post-processing.
This change introduces several new outputs related to non-inductive current: total non-inductive current density (toroidal and parallel), total external current, total non-inductive current, and the non-inductive current fraction. These are derived from existing bootstrap and external current sources. A new test case is added to verify these calculations, and the output documentation is updated. Sim tests are regenerated due to the new outputs. PiperOrigin-RevId: 861313287
1 parent ee68401 commit c146f05

File tree

54 files changed

+148
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+148
-0
lines changed

docs/output.rst

Lines changed: 24 additions & 0 deletions

torax/_src/output_tools/post_processing.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,14 @@ class PostProcessedOutputs:
172172
source [Am^-2]
173173
j_parallel_ecrh: Toroidal current density from electron cyclotron heating
174174
and current source [Am^-2]
175+
j_non_inductive: Total toroidal non-inductive current density [Am^-2]
176+
j_parallel_non_inductive: Total parallel non-inductive current density
177+
[Am^-2]
178+
I_external: Total external current [A]
179+
I_non_inductive: Total non-inductive current [A]
180+
f_non_inductive: Non-inductive current fraction of the total current
181+
[dimensionless]
182+
f_bootstrap: Bootstrap current fraction of the total current [dimensionless]
175183
S_gas_puff: Integrated gas puff source [s^-1]
176184
S_pellet: Integrated pellet source [s^-1]
177185
S_generic_particle: Integrated generic particle source [s^-1]
@@ -267,12 +275,19 @@ class PostProcessedOutputs:
267275
# TODO(b/434175938): rename j_* to j_toroidal_* for clarity
268276
j_parallel_total: array_typing.FloatVector
269277
j_external: array_typing.FloatVector
278+
j_parallel_external: array_typing.FloatVector
270279
j_ohmic: array_typing.FloatVector
271280
j_parallel_ohmic: array_typing.FloatVector
272281
j_bootstrap: array_typing.FloatVector
273282
j_bootstrap_face: array_typing.FloatVector
274283
j_generic_current: array_typing.FloatVector
275284
j_ecrh: array_typing.FloatVector
285+
j_non_inductive: array_typing.FloatVector
286+
j_parallel_non_inductive: array_typing.FloatVector
287+
I_external: array_typing.FloatScalar
288+
I_non_inductive: array_typing.FloatScalar
289+
f_non_inductive: array_typing.FloatScalar
290+
f_bootstrap: array_typing.FloatScalar
276291
S_gas_puff: array_typing.FloatScalar
277292
S_pellet: array_typing.FloatScalar
278293
S_generic_particle: array_typing.FloatScalar
@@ -374,6 +389,13 @@ def zeros(cls, geo: geometry.Geometry) -> typing_extensions.Self:
374389
j_external=jnp.zeros(geo.rho_face.shape),
375390
j_generic_current=jnp.zeros(geo.rho_face.shape),
376391
j_ecrh=jnp.zeros(geo.rho_face.shape),
392+
j_non_inductive=jnp.zeros(geo.rho_face.shape),
393+
j_parallel_external=jnp.zeros(geo.rho_face.shape),
394+
j_parallel_non_inductive=jnp.zeros(geo.rho_face.shape),
395+
I_external=jnp.array(0.0, dtype=jax_utils.get_dtype()),
396+
I_non_inductive=jnp.array(0.0, dtype=jax_utils.get_dtype()),
397+
f_non_inductive=jnp.array(0.0, dtype=jax_utils.get_dtype()),
398+
f_bootstrap=jnp.array(0.0, dtype=jax_utils.get_dtype()),
377399
S_gas_puff=jnp.array(0.0, dtype=jax_utils.get_dtype()),
378400
S_pellet=jnp.array(0.0, dtype=jax_utils.get_dtype()),
379401
S_generic_particle=jnp.array(0.0, dtype=jax_utils.get_dtype()),
@@ -925,6 +947,10 @@ def cumulative_values():
925947
I_bootstrap = math_utils.area_integration(
926948
j_toroidal_bootstrap, sim_state.geometry
927949
)
950+
I_external = math_utils.area_integration(
951+
j_toroidal_external, sim_state.geometry
952+
)
953+
I_non_inductive = I_bootstrap + I_external
928954

929955
beta_tor, beta_pol, beta_N = formulas.calculate_betas(
930956
sim_state.core_profiles, sim_state.geometry
@@ -998,6 +1024,17 @@ def cumulative_values():
9981024
j_external=j_toroidal_external,
9991025
j_ecrh=j_toroidal_sources['j_ecrh'],
10001026
j_generic_current=j_toroidal_sources['j_generic_current'],
1027+
j_non_inductive=j_toroidal_bootstrap + j_toroidal_external,
1028+
j_parallel_external=j_parallel_external,
1029+
j_parallel_non_inductive=j_parallel_bootstrap + j_parallel_external,
1030+
I_external=I_external,
1031+
I_non_inductive=I_non_inductive,
1032+
f_non_inductive=math_utils.safe_divide(
1033+
I_non_inductive, sim_state.core_profiles.Ip_profile_face[-1]
1034+
),
1035+
f_bootstrap=math_utils.safe_divide(
1036+
I_bootstrap, sim_state.core_profiles.Ip_profile_face[-1]
1037+
),
10011038
beta_tor=beta_tor,
10021039
beta_pol=beta_pol,
10031040
beta_N=beta_N,

torax/_src/output_tools/tests/post_processing_test.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dataclasses
1516
from absl.testing import absltest
1617
from absl.testing import parameterized
1718
from jax import numpy as jnp
1819
import numpy as np
1920
import scipy
2021
from torax._src import jax_utils
22+
from torax._src import math_utils
2123
from torax._src import state
2224
from torax._src.config import build_runtime_params
2325
from torax._src.core_profiles import initialization
@@ -31,6 +33,8 @@
3133
from torax._src.test_utils import sim_test_case
3234
from torax._src.torax_pydantic import model_config
3335

36+
# pylint: disable=invalid-name
37+
3438

3539
class PostProcessingTest(parameterized.TestCase):
3640

@@ -234,6 +238,89 @@ def test_zero_sources_do_not_make_nans(self):
234238
post_processed_outputs.check_for_errors(), state.SimError.NO_ERROR
235239
)
236240

241+
def test_current_outputs(self):
242+
"""Checks calculation of current-related outputs."""
243+
# Setup non-zero bootstrap current
244+
ones = np.ones_like(self.geo.rho)
245+
bootstrap_current = bootstrap_current_base.BootstrapCurrent(
246+
j_parallel_bootstrap=1.0 * ones,
247+
j_parallel_bootstrap_face=1.0 * np.ones_like(self.geo.rho_face),
248+
)
249+
250+
# Source profiles with parallel currents (normalized <j.B>/B0)
251+
# generic: 2.0, ecrh: 2.0. Sum external = 4.0.
252+
source_profiles = dataclasses.replace(
253+
self.source_profiles, bootstrap_current=bootstrap_current
254+
)
255+
256+
# Mock j_total (toroidal).
257+
# We leave core_profiles as is (from setUp), but update source_profiles.
258+
259+
input_state = sim_state.SimState(
260+
t=jnp.array(0.0),
261+
dt=jnp.array(1e-3),
262+
core_profiles=self.core_profiles,
263+
core_transport=state.CoreTransport.zeros(self.geo),
264+
core_sources=source_profiles,
265+
geometry=self.geo,
266+
solver_numeric_outputs=state.SolverNumericOutputs(
267+
solver_error_state=np.array(0, jax_utils.get_int_dtype()),
268+
outer_solver_iterations=np.array(0, jax_utils.get_int_dtype()),
269+
inner_solver_iterations=np.array(0, jax_utils.get_int_dtype()),
270+
sawtooth_crash=False,
271+
),
272+
edge_outputs=None,
273+
)
274+
275+
outputs = post_processing.make_post_processed_outputs(
276+
sim_state=input_state,
277+
runtime_params=self.runtime_params,
278+
previous_post_processed_outputs=post_processing.PostProcessedOutputs.zeros(
279+
self.geo
280+
),
281+
)
282+
283+
# Check parallel currents
284+
# external = sum(psi) = 2 + 2 = 4
285+
np.testing.assert_allclose(outputs.j_parallel_external, 4.0 * ones)
286+
287+
# non_inductive = bootstrap (1.0) + external (4.0) = 5.0
288+
np.testing.assert_allclose(outputs.j_parallel_non_inductive, 5.0 * ones)
289+
290+
# Check toroidal currents relation
291+
# j_non_inductive should be j_bootstrap + j_external
292+
np.testing.assert_allclose(
293+
outputs.j_non_inductive, outputs.j_bootstrap + outputs.j_external
294+
)
295+
296+
# Check integrated currents
297+
# I_non_inductive should be I_bootstrap + area_int(j_external)
298+
I_external = math_utils.area_integration(outputs.j_external, self.geo)
299+
np.testing.assert_allclose(outputs.I_external, I_external)
300+
301+
np.testing.assert_allclose(
302+
outputs.I_non_inductive, outputs.I_bootstrap + I_external
303+
)
304+
305+
# Check fraction
306+
# f_non_inductive = I_non_inductive / Ip
307+
# Ip comes from core_profiles.Ip_profile_face[-1]
308+
ip = self.core_profiles.Ip_profile_face[-1]
309+
# Code uses constants.CONSTANTS.eps for division guard
310+
np.testing.assert_allclose(
311+
outputs.f_non_inductive,
312+
math_utils.safe_divide(outputs.I_non_inductive, ip),
313+
rtol=1e-5,
314+
)
315+
316+
# Check bootstrap fraction
317+
# f_bootstrap = I_bootstrap / Ip
318+
np.testing.assert_allclose(
319+
outputs.f_bootstrap,
320+
math_utils.safe_divide(outputs.I_bootstrap, ip),
321+
rtol=1e-5,
322+
)
323+
237324

238325
class PostProcessingSimTest(sim_test_case.SimTestCase):
239326
"""Tests for the cumulative outputs."""
15.3 KB
Binary file not shown.
21.5 KB
Binary file not shown.
7.26 KB
Binary file not shown.
11.8 KB
Binary file not shown.
12.4 KB
Binary file not shown.
10.2 KB
Binary file not shown.
21.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)