How to vmap across a parameter or batch of configs?
#1686
I am trying to do a scan over values of a single parameter, for example trying to see how H98 varies with BgB scaling. My ideal conception of what this would look like is something like:

from torax.examples import step_flattop_bgb
import torax
import jax
import jax.numpy as jnp

base_config = torax.ToraxConfig.from_dict(step_flattop_bgb.CONFIG)

def evaluate_h98(bgb_scaling: float) -> float:
    config = base_config.model_copy(update={
        "transport.chi_e_bohm_multiplier": bgb_scaling,
        "transport.chi_e_gyrobohm_multiplier": bgb_scaling,
        "transport.chi_i_bohm_multiplier": bgb_scaling,
        "transport.chi_i_gyrobohm_multiplier": bgb_scaling,
    })
    _, state_history = torax.run_simulation(config)
    return state_history.post_processed_outputs[-1].H98

jax.vmap(evaluate_h98)(jnp.array([0.1, 0.5, 1.0]))

However, while this does run, it doesn't actually work: the H98 values are identical for all 3 simulations. A few other obstacles I've found along the way:
Possibly related to #1685?
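For comparison, a plain Python loop over the same closure sidesteps vmap tracing entirely; a minimal sketch, reusing only the evaluate_h98 function and public torax.run_simulation call from above, with each scaling value run as an independent, unbatched simulation:

# Unbatched fallback: call evaluate_h98 once per scaling value in plain
# Python, so no jax.vmap tracing is involved.
scalings = [0.1, 0.5, 1.0]
h98_per_scaling = {s: evaluate_h98(s) for s in scalings}
print(h98_per_scaling)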
Replies: 1 comment 1 reply
The answer is to use the JIT-compatible run loop. This hasn't been documented yet as it's part of the internal experimental API, so use at your own risk! Working example:

import dataclasses

import jax
import jax.numpy as jnp

import torax
from torax.examples import step_flattop_bgb
from torax._src.config import build_runtime_params
from torax._src.orchestration import initial_state as initial_state_lib
from torax._src.orchestration import jit_run_loop
from torax._src.orchestration import step_function
from torax._src.torax_pydantic import interpolated_param_1d

# Prepare the simulation objects we will modify
torax_config = torax.ToraxConfig.from_dict(step_flattop_bgb.CONFIG)
geometry_provider = torax_config.geometry.build_provider
solver = torax_config.solver.build_solver(
    physics_models=torax_config.build_physics_models()
)
runtime_params_provider = (
    build_runtime_params.RuntimeParamsProvider.from_config(torax_config)
)

# Objective function
def evaluate_h98(bgb_scaling: float) -> float:
    # Update the runtime params provider to use the new variable
    bgb_scaling_field = interpolated_param_1d.TimeVaryingScalar.model_construct(
        time=jnp.array([0.0]), value=jnp.array([bgb_scaling])
    )
    modified_transport_model = runtime_params_provider.transport_model.model_copy(
        update=dict(
            chi_e_bohm_multiplier=bgb_scaling_field,
            chi_e_gyrobohm_multiplier=bgb_scaling_field,
            chi_i_bohm_multiplier=bgb_scaling_field,
            chi_i_gyrobohm_multiplier=bgb_scaling_field,
        )
    )
    modified_runtime_params_provider = dataclasses.replace(
        runtime_params_provider, transport_model=modified_transport_model
    )
    # Create new step function using the updated runtime params provider
    step_fn = step_function.SimulationStepFn(
        solver=solver,
        time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
        geometry_provider=geometry_provider,
        runtime_params_provider=modified_runtime_params_provider,
    )
    # Create new initial state using the updated runtime params provider
    initial_state, post_processed_outputs = (
        initial_state_lib.get_initial_state_and_post_processed_outputs(
            t=torax_config.numerics.t_initial,
            runtime_params_provider=modified_runtime_params_provider,
            geometry_provider=geometry_provider,
            step_fn=step_fn,
        )
    )
    # Run
    (state_history, post_processed_outputs_history, n_completed_steps) = (
        jit_run_loop.run_loop_jit(
            initial_state=initial_state,
            initial_post_processed_outputs=post_processed_outputs,
            step_fn=step_fn,
            max_steps=10,
        )
    )
    # Extract the variable of interest - in this case, the H98 from the final step
    return post_processed_outputs_history.H98[n_completed_steps]

This lets you do jax.vmap(evaluate_h98)(jnp.linspace(0.25, 1.0, 10)).