Skip to content

Commit f01509f

Browse files
committed
refactor(sawtooth): move test references to local file
- Create regenerate_sawtooth_refs.py next to sawtooth tests - Create sawtooth_references.json for local reference storage - Update sawtooth_model_test.py to use local references - Remove sawtooth code from central regenerate_torax_refs.py - Remove sawtooth_references from torax_refs.py and references.json This keeps sawtooth test references co-located with the tests, making them easier to maintain independently. Fixes #1762
1 parent ee8a0f8 commit f01509f

File tree

6 files changed

+306
-340
lines changed

6 files changed

+306
-340
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
r"""Script to regenerate reference values for sawtooth model tests.
16+
17+
This script recalculates the sawtooth crash reference values and saves them
18+
to a local JSON file: `sawtooth_references.json` in this test directory.
19+
20+
Usage Examples:
21+
22+
# To regenerate references and print a summary:
23+
python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs
24+
25+
# To regenerate and save to JSON file:
26+
python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs --write_to_file
27+
"""
28+
29+
from collections.abc import Sequence
30+
import json
31+
import logging
32+
import pathlib
33+
import pprint
34+
from typing import Any
35+
36+
from absl import app
37+
from absl import flags
38+
import numpy as np
39+
40+
from torax._src.config import build_runtime_params
41+
from torax._src.orchestration import initial_state as initial_state_lib
42+
from torax._src.orchestration import step_function
43+
from torax._src.torax_pydantic import model_config
44+
45+
46+
FLAGS = flags.FLAGS
47+
48+
_WRITE_TO_FILE = flags.DEFINE_bool(
49+
'write_to_file',
50+
False,
51+
'If True, saves the new reference values to sawtooth_references.json.',
52+
)
53+
_PRINT_SUMMARY = flags.DEFINE_bool(
54+
'print_summary',
55+
True,
56+
'If True, prints the arrays to the console.',
57+
)
58+
59+
# Test configuration constants
60+
NRHO = 10
61+
CRASH_STEP_DURATION = 1e-3
62+
FIXED_DT = 0.1
63+
64+
# Path to the local references file
65+
REFERENCES_FILE = 'sawtooth_references.json'
66+
67+
68+
class NumpyEncoder(json.JSONEncoder):
69+
"""Custom JSON encoder for NumPy types."""
70+
71+
def default(self, o):
72+
if isinstance(o, np.ndarray):
73+
return o.tolist()
74+
if isinstance(o, np.integer):
75+
return int(o)
76+
if isinstance(o, np.floating):
77+
return float(o)
78+
return super().default(o)
79+
80+
81+
def get_sawtooth_test_config() -> dict[str, Any]:
82+
"""Returns the test configuration dictionary for sawtooth tests.
83+
84+
This configuration is shared between the test and reference generation.
85+
"""
86+
return {
87+
'numerics': {
88+
'evolve_current': True,
89+
'evolve_density': True,
90+
'evolve_ion_heat': True,
91+
'evolve_electron_heat': True,
92+
'fixed_dt': FIXED_DT,
93+
},
94+
# Default initial current will lead to a sawtooth being triggered.
95+
'profile_conditions': {
96+
'Ip': 13e6,
97+
'initial_j_is_total_current': True,
98+
'initial_psi_from_j': True,
99+
'current_profile_nu': 3,
100+
'n_e_nbar_is_fGW': True,
101+
'normalize_n_e_to_nbar': True,
102+
'nbar': 0.85,
103+
'n_e': {0: {0.0: 1.5, 1.0: 1.0}},
104+
},
105+
'plasma_composition': {},
106+
'geometry': {'geometry_type': 'circular', 'n_rho': NRHO},
107+
'pedestal': {},
108+
'sources': {'ohmic': {}},
109+
'solver': {
110+
'solver_type': 'linear',
111+
'use_pereverzev': False,
112+
},
113+
'time_step_calculator': {'calculator_type': 'fixed'},
114+
'transport': {'model_name': 'constant'},
115+
'mhd': {
116+
'sawtooth': {
117+
'trigger_model': {
118+
'model_name': 'simple',
119+
'minimum_radius': 0.2,
120+
's_critical': 0.2,
121+
},
122+
'redistribution_model': {
123+
'model_name': 'simple',
124+
'flattening_factor': 1.01,
125+
'mixing_radius_multiplier': 1.5,
126+
},
127+
'crash_step_duration': CRASH_STEP_DURATION,
128+
}
129+
},
130+
}
131+
132+
133+
def calculate_sawtooth_crash_references() -> dict[str, Any]:
134+
"""Calculates sawtooth crash reference values by running a simulation step.
135+
136+
This function:
137+
1. Builds a test configuration that triggers a sawtooth crash
138+
2. Runs one simulation step
139+
3. Verifies the crash occurred
140+
4. Returns post-crash profile values
141+
142+
Returns:
143+
Dictionary containing post-crash reference values.
144+
145+
Raises:
146+
ValueError: If sawtooth crash did not occur.
147+
"""
148+
test_config_dict = get_sawtooth_test_config()
149+
torax_config = model_config.ToraxConfig.from_dict(test_config_dict)
150+
151+
# Build solver and step function
152+
solver = torax_config.solver.build_solver(
153+
physics_models=torax_config.build_physics_models(),
154+
)
155+
geometry_provider = torax_config.geometry.build_provider
156+
runtime_params_provider = (
157+
build_runtime_params.RuntimeParamsProvider.from_config(torax_config)
158+
)
159+
160+
step_fn = step_function.SimulationStepFn(
161+
solver=solver,
162+
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
163+
geometry_provider=geometry_provider,
164+
runtime_params_provider=runtime_params_provider,
165+
)
166+
167+
# Get initial state
168+
initial_state, initial_post_processed_outputs = (
169+
initial_state_lib.get_initial_state_and_post_processed_outputs(
170+
t=torax_config.numerics.t_initial,
171+
runtime_params_provider=runtime_params_provider,
172+
geometry_provider=geometry_provider,
173+
step_fn=step_fn,
174+
)
175+
)
176+
177+
# Run one step - this should trigger sawtooth crash
178+
output_state, _ = step_fn(
179+
input_state=initial_state,
180+
previous_post_processed_outputs=initial_post_processed_outputs,
181+
)
182+
183+
# Verify sawtooth crash occurred
184+
if not output_state.solver_numeric_outputs.sawtooth_crash:
185+
raise ValueError(
186+
'Sawtooth crash did not occur! Check sawtooth model configuration.'
187+
)
188+
189+
# Extract post-crash profiles
190+
return {
191+
'post_crash_temperature': np.asarray(output_state.core_profiles.T_e.value),
192+
'post_crash_n': np.asarray(output_state.core_profiles.n_e.value),
193+
'post_crash_psi': np.asarray(output_state.core_profiles.psi.value),
194+
}
195+
196+
197+
def _print_full_summary(new_values: dict[str, np.ndarray]):
198+
"""Prints the full regenerated reference values for inspection."""
199+
pretty_printer = pprint.PrettyPrinter(indent=4, width=100)
200+
logging.info('Sawtooth crash reference values:')
201+
for name, value in new_values.items():
202+
logging.info(' %s:', name)
203+
pretty_printer.pprint(value)
204+
print('-' * 20)
205+
206+
207+
def get_references_path() -> pathlib.Path:
208+
"""Returns the path to the local references file."""
209+
return pathlib.Path(__file__).parent / REFERENCES_FILE
210+
211+
212+
def main(argv: Sequence[str]) -> None:
213+
if len(argv) > 1:
214+
raise app.UsageError('Too many command-line arguments.')
215+
216+
output_path = get_references_path()
217+
218+
np.set_printoptions(
219+
precision=12, suppress=True, threshold=np.inf, linewidth=100
220+
)
221+
222+
logging.info('Regenerating sawtooth crash references...')
223+
new_values = calculate_sawtooth_crash_references()
224+
225+
if _PRINT_SUMMARY.value:
226+
_print_full_summary(new_values)
227+
228+
if _WRITE_TO_FILE.value:
229+
logging.info('Writing regenerated data to %s...', output_path)
230+
with open(output_path, 'w') as f:
231+
json.dump(new_values, f, indent=2, cls=NumpyEncoder)
232+
logging.info('Done.')
233+
else:
234+
logging.info(
235+
'Finished dry run. Not writing to file. Use --write_to_file to save.'
236+
)
237+
238+
239+
if __name__ == '__main__':
240+
app.run(main)

torax/_src/mhd/sawtooth/tests/sawtooth_model_test.py

Lines changed: 26 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
### FILEPATH: torax/mhd/sawtooth/tests/sawtooth_model_test.py
21
# Copyright 2024 DeepMind Technologies Limited
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,79 +12,48 @@
1312
# See the License for the specific language governing permissions and
1413
# limitations under the License.
1514

15+
"""Sawtooth model integration tests."""
16+
1617
import dataclasses
18+
import json
19+
import pathlib
1720

1821
from absl.testing import absltest
1922
from absl.testing import parameterized
2023
import jax
2124
import numpy as np
2225
from torax._src import state
2326
from torax._src.config import build_runtime_params
27+
from torax._src.mhd.sawtooth.tests import regenerate_sawtooth_refs
2428
from torax._src.orchestration import initial_state as initial_state_lib
2529
from torax._src.orchestration import step_function
2630
from torax._src.torax_pydantic import model_config
27-
import json
28-
from torax._src import path_utils
2931

30-
_NRHO = 10
31-
_CRASH_STEP_DURATION = 1e-3
32-
_FIXED_DT = 0.1
32+
33+
# Import shared constants from the regeneration script
34+
_NRHO = regenerate_sawtooth_refs.NRHO
35+
_CRASH_STEP_DURATION = regenerate_sawtooth_refs.CRASH_STEP_DURATION
36+
_FIXED_DT = regenerate_sawtooth_refs.FIXED_DT
3337

3438
# Needed since we do not call torax.__init__ in this test, which normally sets
3539
# this.
3640
jax.config.update('jax_enable_x64', True)
3741

3842

43+
def _load_sawtooth_references() -> dict:
44+
"""Loads sawtooth reference values from the local JSON file."""
45+
json_path = pathlib.Path(__file__).parent / regenerate_sawtooth_refs.REFERENCES_FILE
46+
with open(json_path, 'r') as f:
47+
return json.load(f)
48+
49+
3950
class SawtoothModelTest(parameterized.TestCase):
4051
"""Sawtooth model integration tests by running the SimulationStepFn."""
4152

4253
def setUp(self):
4354
super().setUp()
44-
test_config_dict = {
45-
'numerics': {
46-
'evolve_current': True,
47-
'evolve_density': True,
48-
'evolve_ion_heat': True,
49-
'evolve_electron_heat': True,
50-
'fixed_dt': _FIXED_DT,
51-
},
52-
# Default initial current will lead to a sawtooth being triggered.
53-
'profile_conditions': {
54-
'Ip': 13e6,
55-
'initial_j_is_total_current': True,
56-
'initial_psi_from_j': True,
57-
'current_profile_nu': 3,
58-
'n_e_nbar_is_fGW': True,
59-
'normalize_n_e_to_nbar': True,
60-
'nbar': 0.85,
61-
'n_e': {0: {0.0: 1.5, 1.0: 1.0}},
62-
},
63-
'plasma_composition': {},
64-
'geometry': {'geometry_type': 'circular', 'n_rho': _NRHO},
65-
'pedestal': {},
66-
'sources': {'ohmic': {}},
67-
'solver': {
68-
'solver_type': 'linear',
69-
'use_pereverzev': False,
70-
},
71-
'time_step_calculator': {'calculator_type': 'fixed'},
72-
'transport': {'model_name': 'constant'},
73-
'mhd': {
74-
'sawtooth': {
75-
'trigger_model': {
76-
'model_name': 'simple',
77-
'minimum_radius': 0.2,
78-
's_critical': 0.2,
79-
},
80-
'redistribution_model': {
81-
'model_name': 'simple',
82-
'flattening_factor': 1.01,
83-
'mixing_radius_multiplier': 1.5,
84-
},
85-
'crash_step_duration': _CRASH_STEP_DURATION,
86-
}
87-
},
88-
}
55+
# Use the shared test configuration from the regeneration script
56+
test_config_dict = regenerate_sawtooth_refs.get_sawtooth_test_config()
8957
torax_config = model_config.ToraxConfig.from_dict(test_config_dict)
9058
self._torax_config = torax_config
9159

@@ -114,15 +82,11 @@ def setUp(self):
11482
)
11583
)
11684

117-
# Load sawtooth crash reference values from JSON
118-
json_path = path_utils.torax_path() / '_src' / 'test_utils' / 'references.json'
119-
with open(json_path, 'r') as f:
120-
all_refs = json.load(f)
121-
sawtooth_refs = all_refs.get('sawtooth_references', {})
122-
123-
self._post_crash_temperature = np.array(sawtooth_refs.get('post_crash_temperature', []))
124-
self._post_crash_n = np.array(sawtooth_refs.get('post_crash_n', []))
125-
self._post_crash_psi = np.array(sawtooth_refs.get('post_crash_psi', []))
85+
# Load sawtooth crash reference values from local JSON
86+
sawtooth_refs = _load_sawtooth_references()
87+
self._post_crash_temperature = np.array(sawtooth_refs['post_crash_temperature'])
88+
self._post_crash_n = np.array(sawtooth_refs['post_crash_n'])
89+
self._post_crash_psi = np.array(sawtooth_refs['post_crash_psi'])
12690

12791
def test_sawtooth_crash(self):
12892
"""Tests that default values lead to crash and compares post-crash to ref."""
@@ -153,12 +117,12 @@ def test_sawtooth_crash(self):
153117
np.testing.assert_allclose(
154118
output_state.core_profiles.n_e.value,
155119
self._post_crash_n,
156-
rtol=1e-6
120+
rtol=1e-6,
157121
)
158122
np.testing.assert_allclose(
159123
output_state.core_profiles.psi.value,
160124
self._post_crash_psi,
161-
rtol=1e-6
125+
rtol=1e-6,
162126
)
163127

164128
def test_no_sawtooth_crash(self):
@@ -263,8 +227,5 @@ def test_no_subsequent_sawtooth_crashes(self):
263227
)
264228

265229

266-
267-
268-
269230
if __name__ == '__main__':
270231
absltest.main()

0 commit comments

Comments
 (0)