|
| 1 | +from time import perf_counter |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import mujoco |
| 7 | +import numpy as np |
| 8 | +from mujoco import mjx |
| 9 | +from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol |
| 10 | +from mujoco_sysid.mjx.model import create_rollout |
| 11 | +from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters |
| 12 | +import os |
| 13 | +import optax |
| 14 | +from mujoco.mjx._src.types import IntegratorType |
| 15 | + |
| 16 | +# SHOULD WE MOVE THIS IN TO MODULE INIT? |
| 17 | +xla_flags = os.environ.get("XLA_FLAGS", "") |
| 18 | +xla_flags += " --xla_gpu_triton_gemm_any=True" |
| 19 | +os.environ["XLA_FLAGS"] = xla_flags |
| 20 | + |
| 21 | + |
| 22 | +@jax.jit |
| 23 | +def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model: |
| 24 | + """Map new parameters to the model.""" |
| 25 | + log_cholesky, log_damping, log_friction_loss = jnp.split(parameters, [10, 11]) |
| 26 | + inertial_parameters = logchol2theta(log_cholesky) |
| 27 | + model = set_dynamic_parameters(model, 0, inertial_parameters) |
| 28 | + damping = jnp.exp(log_damping[0]) |
| 29 | + friction_loss = jnp.exp(log_friction_loss[0]) |
| 30 | + return model.tree_replace( |
| 31 | + { |
| 32 | + "dof_damping": model.dof_damping.at[0].set(damping), |
| 33 | + "dof_frictionloss": model.dof_frictionloss.at[0].set(friction_loss), |
| 34 | + } |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +rollout_trajectory = jax.jit(create_rollout(parameters_map)) |
| 39 | + |
| 40 | + |
| 41 | +# Initialize random key |
| 42 | +key = jax.random.PRNGKey(0) |
| 43 | + |
| 44 | +# Load the model |
| 45 | +MJCF_PATH = "../../data/models/pendulum/pendulum.xml" |
| 46 | +model = mujoco.MjModel.from_xml_path(MJCF_PATH) |
| 47 | +data = mujoco.MjData(model) |
| 48 | +model.opt.integrator = IntegratorType.EULER |
| 49 | + |
| 50 | +# Setting up constraint solver to ensure differentiability and faster simulations |
| 51 | +model.opt.solver = 2 # 2 corresponds to Newton solver |
| 52 | +model.opt.iterations = 1 |
| 53 | +model.opt.ls_iterations = 10 |
| 54 | + |
| 55 | +mjx_model = mjx.put_model(model) |
| 56 | + |
| 57 | +# Load test data |
| 58 | +TEST_DATA_PATH = "../../data/trajectories/pendulum/free_fall_2.csv" |
| 59 | +data_array = np.genfromtxt( |
| 60 | + TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500) |
| 61 | +timespan = data_array[:, 0] - data_array[0, 0] |
| 62 | +sampling = np.mean(np.diff(timespan)) |
| 63 | +angle = data_array[:, 1] |
| 64 | +velocity = data_array[:, 2] |
| 65 | +control = data_array[:, 3] |
| 66 | + |
| 67 | +model.opt.timestep = sampling |
| 68 | + |
| 69 | +HORIZON = 10 |
| 70 | +N_INTERVALS = len(timespan) // HORIZON - 1 |
| 71 | +timespan = timespan[: N_INTERVALS * HORIZON] |
| 72 | +angle = angle[: N_INTERVALS * HORIZON] |
| 73 | +velocity = velocity[: N_INTERVALS * HORIZON] |
| 74 | +control = control[: N_INTERVALS * HORIZON] |
| 75 | + |
| 76 | +# Prepare data for simulation and optimization |
| 77 | +initial_state = jnp.array([angle[0], velocity[0]]) |
| 78 | +true_trajectory = jnp.column_stack((angle, velocity)) |
| 79 | +control_inputs = jnp.array(control) |
| 80 | + |
| 81 | +interval_true_trajectory = true_trajectory[::HORIZON] |
| 82 | +interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON) |
| 83 | + |
| 84 | +# Get default parameters from the model |
| 85 | +default_parameters = jnp.concatenate( |
| 86 | + [theta2logchol(get_dynamic_parameters(mjx_model, 1)), |
| 87 | + jnp.log(mjx_model.dof_damping), jnp.log(mjx_model.dof_frictionloss)] |
| 88 | +) |
| 89 | + |
| 90 | +# print() |
| 91 | +@jax.jit |
| 92 | +def rollout_errors(parameters, states, controls): |
| 93 | + interval_initial_states = states[::HORIZON] |
| 94 | + interval_terminal_states = states[HORIZON+1:][::HORIZON] |
| 95 | + interval_controls = jnp.reshape(controls, (N_INTERVALS, HORIZON)) |
| 96 | + batched_rollout = jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)) |
| 97 | + batched_states_trajectories = batched_rollout(parameters, mjx_model, interval_initial_states, interval_controls) |
| 98 | + predicted_terminal_points = batched_states_trajectories[:,-1,:] |
| 99 | + loss = jnp.mean(optax.l2_loss(predicted_terminal_points[:-1], interval_terminal_states)) + 0.05*jnp.mean(optax.huber_loss(parameters, jnp.zeros_like(parameters))) |
| 100 | + return loss |
| 101 | + |
| 102 | +start_learning_rate = 1e-3 |
| 103 | +optimizer = optax.adam(learning_rate = start_learning_rate) |
| 104 | + |
| 105 | + |
| 106 | +# Initialize parameters of the model + optimizer. |
| 107 | +params = jnp.array(0.5*default_parameters) |
| 108 | +opt_state = optimizer.init(params) |
| 109 | +val_and_grad = jax.jit(jax.value_and_grad(rollout_errors)) |
| 110 | +loss_val, loss_grad = val_and_grad(params, true_trajectory, control_inputs) |
| 111 | + |
| 112 | +# A simple update loop. |
| 113 | +for _ in range(100): |
| 114 | + loss_val, loss_grad = val_and_grad(params, true_trajectory, control_inputs) |
| 115 | + updates, opt_state = optimizer.update(loss_grad, opt_state) |
| 116 | + params = optax.apply_updates(params, updates) |
| 117 | + print(loss_val, params) |
| 118 | + |
| 119 | +# assert jnp.allclose(params, target_params), \ |
| 120 | +# 'Optimization should retrive the target params used to generate the data.' |
| 121 | +# # ////////////////////////////////////// |
| 122 | +# # SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION |
| 123 | + |
| 124 | +# # Vectorize over both initial states and control inputs |
| 125 | + |
| 126 | + |
| 127 | +# # Create a batch of initial states |
| 128 | +# key, subkey = jax.random.split(key) |
| 129 | +# batch_initial_states = jax.random.uniform( |
| 130 | +# subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state |
| 131 | +# # Create a batch of control input sequences |
| 132 | +# key, subkey = jax.random.split(key) |
| 133 | +# batch_control_inputs = jax.random.normal( |
| 134 | +# subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs |
| 135 | +# # Run warm up for batched rollout |
| 136 | +# t1 = perf_counter() |
| 137 | +# batched_trajectories = batched_rollout( |
| 138 | +# default_parameters, mjx_model, batch_initial_states, batch_control_inputs) |
| 139 | +# t2 = perf_counter() |
| 140 | +# print(f"Batch simulation time: {t2 - t1} seconds") |
| 141 | + |
| 142 | +# # Run batched rollout on shor horizon data from pendulum |
| 143 | +# interval_initial_states = true_trajectory[::HORIZON] |
| 144 | +# interval_terminal_states = true_trajectory[HORIZON+1:][::HORIZON] |
| 145 | +# interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON) |
| 146 | +# batched_states_trajectories = batched_rollout( |
| 147 | +# default_parameters*0.999999, mjx_model, interval_initial_states, interval_controls) |
| 148 | +# t1 = perf_counter() |
| 149 | +# t2 = perf_counter() |
| 150 | +# print(f"Batch simulation time: {t2 - t1} seconds") |
| 151 | + |
| 152 | +# predicted_terminal_points = np.array(batched_states_trajectories)[:,-1,:] |
| 153 | +# batched_states_trajectories = np.array( |
| 154 | +# batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2) |
| 155 | +# # Plotting simulation results for batсhed state trajectories |
| 156 | +# plt.figure(figsize=(10, 5)) |
| 157 | + |
| 158 | +# plt.subplot(2, 2, 1) |
| 159 | +# plt.plot(timespan, angle, label="Actual Angle", |
| 160 | +# color="black", linestyle="dashed", linewidth=2) |
| 161 | +# plt.plot(timespan, batched_states_trajectories[:, 0], |
| 162 | +# alpha=0.5, color="blue", label="Simulated Angle") |
| 163 | +# plt.plot(timespan, angle, label="Actual Angle", |
| 164 | +# color="black", linestyle="dashed", linewidth=2) |
| 165 | +# plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,0], 'ob') |
| 166 | +# plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 0], 'or') |
| 167 | +# plt.ylabel("Angle (rad)") |
| 168 | +# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4) |
| 169 | +# plt.legend() |
| 170 | +# plt.title("Pendulum Dynamics - Bathed State Trajectories") |
| 171 | + |
| 172 | +# plt.subplot(2, 2, 3) |
| 173 | +# plt.plot(timespan, velocity, label="Actual Velocity", |
| 174 | +# color="black", linestyle="dashed", linewidth=2) |
| 175 | +# plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,1], 'ob') |
| 176 | +# plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 1], 'or') |
| 177 | +# plt.plot(timespan, batched_states_trajectories[:, 1], |
| 178 | +# alpha=0.5, color="blue", label="Simulated Velocity") |
| 179 | +# plt.xlabel("Time (s)") |
| 180 | +# plt.ylabel("Velocity (rad/s)") |
| 181 | +# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4) |
| 182 | +# plt.legend() |
| 183 | + |
| 184 | +# # Add phase portrait |
| 185 | +# plt.subplot(1, 2, 2) |
| 186 | +# plt.plot(angle, velocity, label="Actual", |
| 187 | +# color="black", linestyle="dashed", linewidth=2) |
| 188 | +# plt.plot( |
| 189 | +# batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated" |
| 190 | +# ) |
| 191 | +# plt.plot(predicted_terminal_points[:-1,0], predicted_terminal_points[:-1,1], 'ob') |
| 192 | +# plt.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], 'or') |
| 193 | +# plt.xlabel("Angle (rad)") |
| 194 | +# plt.ylabel("Angular Velocity (rad/s)") |
| 195 | +# plt.title("Phase Portrait") |
| 196 | +# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4) |
| 197 | +# plt.legend() |
| 198 | + |
| 199 | +# plt.tight_layout() |
| 200 | +# plt.show() |
| 201 | +# # TODO: |
| 202 | +# # Optimization |
| 203 | + |
0 commit comments