Open
Description
The TGLFNN model is incompatible with the new rotation model. Running TGLFNN on the latest commit on main produces the following stack trace:
Traceback (most recent call last):
  File "/home/saperstein/onsims/workflow/run_torax.py", line 54, in <module>
    torax_output = launch_torax(torax_input, save_output_path)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/src/onsims/set_torax_config.py", line 223, in launch_torax
    data_tree, state_history = torax.run_simulation(torax_config, log_timestep_info=False)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/orchestration/run_simulation.py", line 124, in run_simulation
    ) = prepare_simulation(torax_config)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/orchestration/run_simulation.py", line 89, in prepare_simulation
    initial_state_lib.get_initial_state_and_post_processed_outputs(
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[51] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _get_v_ExB_shear at /home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/transport_model/tglf_based_transport_model.py:319 traced for cond.
------------------------------
The leaked intermediate value was created on line /home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:137:11 (CellVariable.cell_centers).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/transport_model/tglf_based_transport_model.py:320:20 (TGLFBasedTransportModel._prepare_tglf_inputs.<locals>._get_v_ExB_shear)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/physics/rotation.py:141:7 (calculate_rotation)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/physics/rotation.py:55:11 (_calculate_radial_electric_field)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:208:8 (CellVariable.face_grad)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:137:11 (CellVariable.cell_centers)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.UnexpectedTracerError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
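For reference, the class of error reported above can be reproduced outside TORAX with a few lines of JAX. The sketch below is not the TORAX code path and does not claim to identify the actual cause. It only illustrates how a value created while tracing a `cond` branch can escape through a side effect. The names `_cache`, `leaky_branch`, `clean_branch`, and `leak_detected` are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical global state, standing in for any cache or attribute that
# outlives the traced function.
_cache = {}


def leaky_branch(x):
    # Side effect: a traced intermediate is stored outside the function.
    _cache["centers"] = x * 2.0
    return _cache["centers"].sum()


def clean_branch(x):
    # No side effect: the intermediate stays local to the branch.
    centers = x * 2.0
    return centers.sum()


def leak_detected(branch):
    """Returns True if reusing the cached value raises UnexpectedTracerError."""
    _cache.clear()
    x = jnp.arange(3.0)
    # Both branches of lax.cond are traced, so `branch` runs on tracers.
    jax.lax.cond(x.sum() > 0, branch, lambda v: v.sum(), x)
    if "centers" not in _cache:
        return False
    try:
        # Using the stored tracer after tracing has finished triggers the error.
        _ = _cache["centers"] + 1.0
    except jax.errors.UnexpectedTracerError:
        return True
    return False
```

Under these assumptions, `leak_detected(leaky_branch)` should report the leak while `leak_detected(clean_branch)` should not. Setting `JAX_CHECK_TRACER_LEAKS`, as the error message suggests, may help locate the equivalent side effect in the real code path.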