Skip to content

Commit d950df2

Browse files
committed
refactor(sawtooth): move test references to local file
- Create regenerate_sawtooth_refs.py next to sawtooth tests - Create sawtooth_references.json for local reference storage - Update sawtooth_model_test.py to use local references - Remove sawtooth code from central regenerate_torax_refs.py - Remove sawtooth_references from torax_refs.py and references.json This keeps sawtooth test references co-located with the tests, making them easier to maintain independently. Fixes #1762
1 parent d8788c5 commit d950df2

File tree

5 files changed

+280
-275
lines changed

5 files changed

+280
-275
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
r"""Script to regenerate reference values for sawtooth model tests.
16+
17+
This script recalculates the sawtooth crash reference values and saves them
18+
to a local JSON file: `sawtooth_references.json` in this test directory.
19+
20+
Usage Examples:
21+
22+
# To regenerate references and print a summary:
23+
python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs
24+
25+
# To regenerate and save to JSON file:
26+
python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs --write_to_file
27+
"""
28+
29+
from collections.abc import Sequence
30+
import json
31+
import logging
32+
import pathlib
33+
import pprint
34+
from typing import Any
35+
36+
from absl import app
37+
from absl import flags
38+
import numpy as np
39+
40+
from torax._src.config import build_runtime_params
41+
from torax._src.orchestration import initial_state as initial_state_lib
42+
from torax._src.orchestration import step_function
43+
from torax._src.torax_pydantic import model_config
44+
45+
46+
FLAGS = flags.FLAGS
47+
48+
_WRITE_TO_FILE = flags.DEFINE_bool(
49+
'write_to_file',
50+
False,
51+
'If True, saves the new reference values to sawtooth_references.json.',
52+
)
53+
_PRINT_SUMMARY = flags.DEFINE_bool(
54+
'print_summary',
55+
True,
56+
'If True, prints the arrays to the console.',
57+
)
58+
59+
# Test configuration constants
60+
NRHO = 10
61+
CRASH_STEP_DURATION = 1e-3
62+
FIXED_DT = 0.1
63+
64+
# Path to the local references file
65+
REFERENCES_FILE = 'sawtooth_references.json'
66+
67+
68+
class NumpyEncoder(json.JSONEncoder):
69+
"""Custom JSON encoder for NumPy types."""
70+
71+
def default(self, o):
72+
if isinstance(o, np.ndarray):
73+
return o.tolist()
74+
if isinstance(o, np.integer):
75+
return int(o)
76+
if isinstance(o, np.floating):
77+
return float(o)
78+
return super().default(o)
79+
80+
81+
def get_sawtooth_test_config() -> dict[str, Any]:
82+
"""Returns the test configuration dictionary for sawtooth tests.
83+
84+
This configuration is shared between the test and reference generation.
85+
"""
86+
return {
87+
'numerics': {
88+
'evolve_current': True,
89+
'evolve_density': True,
90+
'evolve_ion_heat': True,
91+
'evolve_electron_heat': True,
92+
'fixed_dt': FIXED_DT,
93+
},
94+
# Default initial current will lead to a sawtooth being triggered.
95+
'profile_conditions': {
96+
'Ip': 13e6,
97+
'initial_j_is_total_current': True,
98+
'initial_psi_from_j': True,
99+
'current_profile_nu': 3,
100+
'n_e_nbar_is_fGW': True,
101+
'normalize_n_e_to_nbar': True,
102+
'nbar': 0.85,
103+
'n_e': {0: {0.0: 1.5, 1.0: 1.0}},
104+
},
105+
'plasma_composition': {},
106+
'geometry': {'geometry_type': 'circular', 'n_rho': NRHO},
107+
'pedestal': {},
108+
'sources': {'ohmic': {}},
109+
'solver': {
110+
'solver_type': 'linear',
111+
'use_pereverzev': False,
112+
},
113+
'time_step_calculator': {'calculator_type': 'fixed'},
114+
'transport': {'model_name': 'constant'},
115+
'mhd': {
116+
'sawtooth': {
117+
'trigger_model': {
118+
'model_name': 'simple',
119+
'minimum_radius': 0.2,
120+
's_critical': 0.2,
121+
},
122+
'redistribution_model': {
123+
'model_name': 'simple',
124+
'flattening_factor': 1.01,
125+
'mixing_radius_multiplier': 1.5,
126+
},
127+
'crash_step_duration': CRASH_STEP_DURATION,
128+
}
129+
},
130+
}
131+
132+
133+
def calculate_sawtooth_crash_references() -> dict[str, Any]:
134+
"""Calculates sawtooth crash reference values by running a simulation step.
135+
136+
This function:
137+
1. Builds a test configuration that triggers a sawtooth crash
138+
2. Runs one simulation step
139+
3. Verifies the crash occurred
140+
4. Returns post-crash profile values
141+
142+
Returns:
143+
Dictionary containing post-crash reference values.
144+
145+
Raises:
146+
ValueError: If sawtooth crash did not occur.
147+
"""
148+
test_config_dict = get_sawtooth_test_config()
149+
torax_config = model_config.ToraxConfig.from_dict(test_config_dict)
150+
151+
# Build solver and step function
152+
solver = torax_config.solver.build_solver(
153+
physics_models=torax_config.build_physics_models(),
154+
)
155+
geometry_provider = torax_config.geometry.build_provider
156+
runtime_params_provider = (
157+
build_runtime_params.RuntimeParamsProvider.from_config(torax_config)
158+
)
159+
160+
step_fn = step_function.SimulationStepFn(
161+
solver=solver,
162+
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
163+
geometry_provider=geometry_provider,
164+
runtime_params_provider=runtime_params_provider,
165+
)
166+
167+
# Get initial state
168+
initial_state, initial_post_processed_outputs = (
169+
initial_state_lib.get_initial_state_and_post_processed_outputs(
170+
t=torax_config.numerics.t_initial,
171+
runtime_params_provider=runtime_params_provider,
172+
geometry_provider=geometry_provider,
173+
step_fn=step_fn,
174+
)
175+
)
176+
177+
# Run one step - this should trigger sawtooth crash
178+
output_state, _ = step_fn(
179+
input_state=initial_state,
180+
previous_post_processed_outputs=initial_post_processed_outputs,
181+
)
182+
183+
# Verify sawtooth crash occurred
184+
if not output_state.solver_numeric_outputs.sawtooth_crash:
185+
raise ValueError(
186+
'Sawtooth crash did not occur! Check sawtooth model configuration.'
187+
)
188+
189+
# Extract post-crash profiles
190+
return {
191+
'post_crash_temperature': np.asarray(output_state.core_profiles.T_e.value),
192+
'post_crash_n': np.asarray(output_state.core_profiles.n_e.value),
193+
'post_crash_psi': np.asarray(output_state.core_profiles.psi.value),
194+
}
195+
196+
197+
def _print_full_summary(new_values: dict[str, np.ndarray]):
198+
"""Prints the full regenerated reference values for inspection."""
199+
pretty_printer = pprint.PrettyPrinter(indent=4, width=100)
200+
logging.info('Sawtooth crash reference values:')
201+
for name, value in new_values.items():
202+
logging.info(' %s:', name)
203+
pretty_printer.pprint(value)
204+
print('-' * 20)
205+
206+
207+
def get_references_path() -> pathlib.Path:
208+
"""Returns the path to the local references file."""
209+
return pathlib.Path(__file__).parent / REFERENCES_FILE
210+
211+
212+
def main(argv: Sequence[str]) -> None:
213+
if len(argv) > 1:
214+
raise app.UsageError('Too many command-line arguments.')
215+
216+
output_path = get_references_path()
217+
218+
np.set_printoptions(
219+
precision=12, suppress=True, threshold=np.inf, linewidth=100
220+
)
221+
222+
logging.info('Regenerating sawtooth crash references...')
223+
new_values = calculate_sawtooth_crash_references()
224+
225+
if _PRINT_SUMMARY.value:
226+
_print_full_summary(new_values)
227+
228+
if _WRITE_TO_FILE.value:
229+
logging.info('Writing regenerated data to %s...', output_path)
230+
with open(output_path, 'w') as f:
231+
json.dump(new_values, f, indent=2, cls=NumpyEncoder)
232+
logging.info('Done.')
233+
else:
234+
logging.info(
235+
'Finished dry run. Not writing to file. Use --write_to_file to save.'
236+
)
237+
238+
239+
if __name__ == '__main__':
240+
app.run(main)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"post_crash_temperature": [
3+
9.80214763867072,
4+
9.774495567770007,
5+
9.746821543811965,
6+
9.719125386103409,
7+
9.69140691196268,
8+
8.179370745839305,
9+
6.225896595777405,
10+
4.499999999999998,
11+
3.0999999999999988,
12+
1.699999999999999
13+
],
14+
"post_crash_n": [
15+
9.290543832529517e+19,
16+
9.265262111341111e+19,
17+
9.239980390152702e+19,
18+
9.214698668964295e+19,
19+
9.189416947775886e+19,
20+
8.817802446668299e+19,
21+
8.345056959984748e+19,
22+
7.921901446691185e+19,
23+
7.569816937949354e+19,
24+
7.2177324292075225e+19
25+
],
26+
"post_crash_psi": [
27+
9.778742408396024,
28+
11.342102036297204,
29+
14.360383604641752,
30+
18.737049105346237,
31+
24.378127742662265,
32+
31.058184699891584,
33+
38.12617379019407,
34+
44.844898859093725,
35+
50.74281519306403,
36+
55.72986584479574
37+
]
38+
}

0 commit comments

Comments
 (0)