
TGLFNN/rotation compatibility #1939

@AlexSaperstein

Description


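The TGLFNN transport model was found to be incompatible with the new rotation model. Running TGLFNN on the latest commit on main fails with a `jax.errors.UnexpectedTracerError`, raised while JAX traces `_get_v_ExB_shear` (in `tglf_based_transport_model.py`) for a `cond`. The full stack trace is: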

Traceback (most recent call last):
  File "/home/saperstein/onsims/workflow/run_torax.py", line 54, in <module>
    torax_output = launch_torax(torax_input, save_output_path)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/src/onsims/set_torax_config.py", line 223, in launch_torax
    data_tree, state_history = torax.run_simulation(torax_config, log_timestep_info=False)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/orchestration/run_simulation.py", line 124, in run_simulation
    ) = prepare_simulation(torax_config)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/orchestration/run_simulation.py", line 89, in prepare_simulation
    initial_state_lib.get_initial_state_and_post_processed_outputs(
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[51] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _get_v_ExB_shear at /home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/transport_model/tglf_based_transport_model.py:319 traced for cond.
------------------------------
The leaked intermediate value was created on line /home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:137:11 (CellVariable.cell_centers). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/transport_model/tglf_based_transport_model.py:320:20 (TGLFBasedTransportModel._prepare_tglf_inputs.<locals>._get_v_ExB_shear)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/physics/rotation.py:141:7 (calculate_rotation)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/physics/rotation.py:55:11 (_calculate_radial_electric_field)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:208:8 (CellVariable.face_grad)
/home/saperstein/onsims/.venv/lib/python3.12/site-packages/torax/_src/fvm/cell_variable.py:137:11 (CellVariable.cell_centers)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.UnexpectedTracerError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
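The error message says that a value created while JAX traced a function for `cond` escaped through a side effect. The following is a minimal, self-contained sketch of that failure class in plain JAX, not the TORAX code path. The module-level `_cache` dict and the functions `leaky_branch`, `leaky`, `fixed_branch`, `fixed_other`, `fixed` and `reproduce_leak` are hypothetical stand-ins for whatever state in `calculate_rotation` or `CellVariable` actually holds onto the tracer, which the traceback alone does not identify.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level cache standing in for any state that outlives a
# traced function (a global, an attribute set on an object, a memoized result).
_cache = {}


def leaky_branch(x):
    y = x * 2.0      # while lax.cond traces this branch, y is a tracer
    _cache["y"] = y  # side effect: the tracer escapes the branch's scope
    return y


def other_branch(x):
    return x


@jax.jit
def leaky(pred, x):
    return jax.lax.cond(pred, leaky_branch, other_branch, x)


def fixed_branch(x):
    y = x * 2.0
    return y, y      # return the intermediate explicitly instead of storing it


def fixed_other(x):
    return x, jnp.zeros_like(x)  # same output structure as fixed_branch


@jax.jit
def fixed(pred, x):
    out, intermediate = jax.lax.cond(pred, fixed_branch, fixed_other, x)
    return out, intermediate


def reproduce_leak():
    """Return True if reusing the leaked value raises UnexpectedTracerError."""
    leaky(True, jnp.ones(3))  # tracing succeeds; the leak goes unnoticed here
    try:
        jnp.sum(_cache["y"])  # reuse the escaped tracer outside its trace
    except jax.errors.UnexpectedTracerError:
        return True
    return False


if __name__ == "__main__":
    print(reproduce_leak())
    out, intermediate = fixed(True, jnp.ones(3))
    print(out, intermediate)
```

In the sketch, storing the intermediate does not fail while `leaky` is traced; the error is only raised when the stale tracer is reused outside the trace that created it. The fix is the one the error message prescribes: return the value from the traced function (here as a second `cond` output) rather than saving it to state outside the transformation.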
