|
| 1 | +# Copyright 2024 DeepMind Technologies Limited |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +r"""Script to regenerate reference values for sawtooth model tests. |
| 16 | +
|
| 17 | +This script recalculates the sawtooth crash reference values and saves them |
| 18 | +to a local JSON file: `sawtooth_references.json` in this test directory. |
| 19 | +
|
| 20 | +Usage Examples: |
| 21 | +
|
| 22 | +# To regenerate references and print a summary: |
| 23 | +python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs |
| 24 | +
|
| 25 | +# To regenerate and save to JSON file: |
| 26 | +python -m torax._src.mhd.sawtooth.tests.regenerate_sawtooth_refs --write_to_file |
| 27 | +""" |
| 28 | + |
| 29 | +from collections.abc import Sequence |
| 30 | +import json |
| 31 | +import logging |
| 32 | +import pathlib |
| 33 | +import pprint |
| 34 | +from typing import Any |
| 35 | + |
| 36 | +from absl import app |
| 37 | +from absl import flags |
| 38 | +import numpy as np |
| 39 | + |
| 40 | +from torax._src.config import build_runtime_params |
| 41 | +from torax._src.orchestration import initial_state as initial_state_lib |
| 42 | +from torax._src.orchestration import step_function |
| 43 | +from torax._src.torax_pydantic import model_config |
| 44 | + |
| 45 | + |
| 46 | +FLAGS = flags.FLAGS |
| 47 | + |
| 48 | +_WRITE_TO_FILE = flags.DEFINE_bool( |
| 49 | + 'write_to_file', |
| 50 | + False, |
| 51 | + 'If True, saves the new reference values to sawtooth_references.json.', |
| 52 | +) |
| 53 | +_PRINT_SUMMARY = flags.DEFINE_bool( |
| 54 | + 'print_summary', |
| 55 | + True, |
| 56 | + 'If True, prints the arrays to the console.', |
| 57 | +) |
| 58 | + |
| 59 | +# Test configuration constants |
| 60 | +NRHO = 10 |
| 61 | +CRASH_STEP_DURATION = 1e-3 |
| 62 | +FIXED_DT = 0.1 |
| 63 | + |
| 64 | +# Path to the local references file |
| 65 | +REFERENCES_FILE = 'sawtooth_references.json' |
| 66 | + |
| 67 | + |
| 68 | +class NumpyEncoder(json.JSONEncoder): |
| 69 | + """Custom JSON encoder for NumPy types.""" |
| 70 | + |
| 71 | + def default(self, o): |
| 72 | + if isinstance(o, np.ndarray): |
| 73 | + return o.tolist() |
| 74 | + if isinstance(o, np.integer): |
| 75 | + return int(o) |
| 76 | + if isinstance(o, np.floating): |
| 77 | + return float(o) |
| 78 | + return super().default(o) |
| 79 | + |
| 80 | + |
| 81 | +def get_sawtooth_test_config() -> dict[str, Any]: |
| 82 | + """Returns the test configuration dictionary for sawtooth tests. |
| 83 | + |
| 84 | + This configuration is shared between the test and reference generation. |
| 85 | + """ |
| 86 | + return { |
| 87 | + 'numerics': { |
| 88 | + 'evolve_current': True, |
| 89 | + 'evolve_density': True, |
| 90 | + 'evolve_ion_heat': True, |
| 91 | + 'evolve_electron_heat': True, |
| 92 | + 'fixed_dt': FIXED_DT, |
| 93 | + }, |
| 94 | + # Default initial current will lead to a sawtooth being triggered. |
| 95 | + 'profile_conditions': { |
| 96 | + 'Ip': 13e6, |
| 97 | + 'initial_j_is_total_current': True, |
| 98 | + 'initial_psi_from_j': True, |
| 99 | + 'current_profile_nu': 3, |
| 100 | + 'n_e_nbar_is_fGW': True, |
| 101 | + 'normalize_n_e_to_nbar': True, |
| 102 | + 'nbar': 0.85, |
| 103 | + 'n_e': {0: {0.0: 1.5, 1.0: 1.0}}, |
| 104 | + }, |
| 105 | + 'plasma_composition': {}, |
| 106 | + 'geometry': {'geometry_type': 'circular', 'n_rho': NRHO}, |
| 107 | + 'pedestal': {}, |
| 108 | + 'sources': {'ohmic': {}}, |
| 109 | + 'solver': { |
| 110 | + 'solver_type': 'linear', |
| 111 | + 'use_pereverzev': False, |
| 112 | + }, |
| 113 | + 'time_step_calculator': {'calculator_type': 'fixed'}, |
| 114 | + 'transport': {'model_name': 'constant'}, |
| 115 | + 'mhd': { |
| 116 | + 'sawtooth': { |
| 117 | + 'trigger_model': { |
| 118 | + 'model_name': 'simple', |
| 119 | + 'minimum_radius': 0.2, |
| 120 | + 's_critical': 0.2, |
| 121 | + }, |
| 122 | + 'redistribution_model': { |
| 123 | + 'model_name': 'simple', |
| 124 | + 'flattening_factor': 1.01, |
| 125 | + 'mixing_radius_multiplier': 1.5, |
| 126 | + }, |
| 127 | + 'crash_step_duration': CRASH_STEP_DURATION, |
| 128 | + } |
| 129 | + }, |
| 130 | + } |
| 131 | + |
| 132 | + |
| 133 | +def calculate_sawtooth_crash_references() -> dict[str, Any]: |
| 134 | + """Calculates sawtooth crash reference values by running a simulation step. |
| 135 | +
|
| 136 | + This function: |
| 137 | + 1. Builds a test configuration that triggers a sawtooth crash |
| 138 | + 2. Runs one simulation step |
| 139 | + 3. Verifies the crash occurred |
| 140 | + 4. Returns post-crash profile values |
| 141 | +
|
| 142 | + Returns: |
| 143 | + Dictionary containing post-crash reference values. |
| 144 | +
|
| 145 | + Raises: |
| 146 | + ValueError: If sawtooth crash did not occur. |
| 147 | + """ |
| 148 | + test_config_dict = get_sawtooth_test_config() |
| 149 | + torax_config = model_config.ToraxConfig.from_dict(test_config_dict) |
| 150 | + |
| 151 | + # Build solver and step function |
| 152 | + solver = torax_config.solver.build_solver( |
| 153 | + physics_models=torax_config.build_physics_models(), |
| 154 | + ) |
| 155 | + geometry_provider = torax_config.geometry.build_provider |
| 156 | + runtime_params_provider = ( |
| 157 | + build_runtime_params.RuntimeParamsProvider.from_config(torax_config) |
| 158 | + ) |
| 159 | + |
| 160 | + step_fn = step_function.SimulationStepFn( |
| 161 | + solver=solver, |
| 162 | + time_step_calculator=torax_config.time_step_calculator.time_step_calculator, |
| 163 | + geometry_provider=geometry_provider, |
| 164 | + runtime_params_provider=runtime_params_provider, |
| 165 | + ) |
| 166 | + |
| 167 | + # Get initial state |
| 168 | + initial_state, initial_post_processed_outputs = ( |
| 169 | + initial_state_lib.get_initial_state_and_post_processed_outputs( |
| 170 | + t=torax_config.numerics.t_initial, |
| 171 | + runtime_params_provider=runtime_params_provider, |
| 172 | + geometry_provider=geometry_provider, |
| 173 | + step_fn=step_fn, |
| 174 | + ) |
| 175 | + ) |
| 176 | + |
| 177 | + # Run one step - this should trigger sawtooth crash |
| 178 | + output_state, _ = step_fn( |
| 179 | + input_state=initial_state, |
| 180 | + previous_post_processed_outputs=initial_post_processed_outputs, |
| 181 | + ) |
| 182 | + |
| 183 | + # Verify sawtooth crash occurred |
| 184 | + if not output_state.solver_numeric_outputs.sawtooth_crash: |
| 185 | + raise ValueError( |
| 186 | + 'Sawtooth crash did not occur! Check sawtooth model configuration.' |
| 187 | + ) |
| 188 | + |
| 189 | + # Extract post-crash profiles |
| 190 | + return { |
| 191 | + 'post_crash_temperature': np.asarray(output_state.core_profiles.T_e.value), |
| 192 | + 'post_crash_n': np.asarray(output_state.core_profiles.n_e.value), |
| 193 | + 'post_crash_psi': np.asarray(output_state.core_profiles.psi.value), |
| 194 | + } |
| 195 | + |
| 196 | + |
| 197 | +def _print_full_summary(new_values: dict[str, np.ndarray]): |
| 198 | + """Prints the full regenerated reference values for inspection.""" |
| 199 | + pretty_printer = pprint.PrettyPrinter(indent=4, width=100) |
| 200 | + logging.info('Sawtooth crash reference values:') |
| 201 | + for name, value in new_values.items(): |
| 202 | + logging.info(' %s:', name) |
| 203 | + pretty_printer.pprint(value) |
| 204 | + print('-' * 20) |
| 205 | + |
| 206 | + |
| 207 | +def get_references_path() -> pathlib.Path: |
| 208 | + """Returns the path to the local references file.""" |
| 209 | + return pathlib.Path(__file__).parent / REFERENCES_FILE |
| 210 | + |
| 211 | + |
| 212 | +def main(argv: Sequence[str]) -> None: |
| 213 | + if len(argv) > 1: |
| 214 | + raise app.UsageError('Too many command-line arguments.') |
| 215 | + |
| 216 | + output_path = get_references_path() |
| 217 | + |
| 218 | + np.set_printoptions( |
| 219 | + precision=12, suppress=True, threshold=np.inf, linewidth=100 |
| 220 | + ) |
| 221 | + |
| 222 | + logging.info('Regenerating sawtooth crash references...') |
| 223 | + new_values = calculate_sawtooth_crash_references() |
| 224 | + |
| 225 | + if _PRINT_SUMMARY.value: |
| 226 | + _print_full_summary(new_values) |
| 227 | + |
| 228 | + if _WRITE_TO_FILE.value: |
| 229 | + logging.info('Writing regenerated data to %s...', output_path) |
| 230 | + with open(output_path, 'w') as f: |
| 231 | + json.dump(new_values, f, indent=2, cls=NumpyEncoder) |
| 232 | + logging.info('Done.') |
| 233 | + else: |
| 234 | + logging.info( |
| 235 | + 'Finished dry run. Not writing to file. Use --write_to_file to save.' |
| 236 | + ) |
| 237 | + |
| 238 | + |
| 239 | +if __name__ == '__main__': |
| 240 | + app.run(main) |
0 commit comments