|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +import jax.typing as jpt |
| 4 | +import mujoco |
| 5 | +from mujoco import mjx |
| 6 | +import numpy as np |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol |
| 9 | +from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters |
| 10 | +from time import perf_counter |
| 11 | +import optax |
| 12 | + |
| 13 | + |
@jax.jit
def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
    """Return ``model`` with the entries of ``parameters`` written into it.

    The flat vector is split at indices 10 and 11:

    * entries ``[0:10]`` go through ``logchol2theta`` and are installed as the
      dynamic parameters of body 1 via ``set_dynamic_parameters``;
    * entry ``10`` overwrites ``dof_damping[0]``;
    * the first entry from index ``11`` on overwrites ``dof_frictionloss[0]``.

    Args:
        parameters: Flat parameter vector laid out as described above.
        model: MJX model used as the starting point.

    Returns:
        The model produced by ``tree_replace`` with the updated fields.
    """
    logchol, damping, friction_loss = jnp.split(parameters, [10, 11])
    updated = set_dynamic_parameters(model, 1, logchol2theta(logchol))

    # Only the first degree of freedom is touched; other entries are kept.
    joint_overrides = {
        "dof_damping": updated.dof_damping.at[0].set(damping[0]),
        "dof_frictionloss": updated.dof_frictionloss.at[0].set(friction_loss[0]),
    }
    return updated.tree_replace(joint_overrides)
| 26 | + |
| 27 | + |
@jax.jit
def parametric_step(parameters: jnp.ndarray, model: mjx.Model, state: jnp.ndarray, control: jnp.ndarray) -> jnp.ndarray:
    """Advance the re-parameterized model by a single ``mjx.step``.

    Args:
        parameters: Flat parameter vector forwarded to ``parameters_map``.
        model: Base MJX model the parameters are applied to.
        state: Concatenated ``[qpos, qvel]``; the first ``nq`` entries are
            used as positions and the remainder as velocities.
        control: Value assigned to ``ctrl`` for this step.

    Returns:
        The concatenated ``[qpos, qvel]`` after the step.
    """
    new_model = parameters_map(parameters, model)
    nq = new_model.nq
    qpos, qvel = state[:nq], state[nq:]

    # Fresh data is created on every call, so nothing but the state and the
    # control carries over between steps.
    start = mjx.make_data(new_model).replace(qpos=qpos, qvel=qvel, ctrl=control)
    stepped = mjx.step(new_model, start)
    return jnp.concatenate([stepped.qpos, stepped.qvel])
| 35 | + |
| 36 | + |
@jax.jit
def rollout_trajectory(
    parameters: jnp.ndarray, model: mjx.Model, initial_state: jnp.ndarray, control_inputs: jnp.ndarray
) -> jnp.ndarray:
    """Simulate one trajectory by scanning ``parametric_step`` over the controls.

    Args:
        parameters: Flat parameter vector forwarded to ``parametric_step``.
        model: Base MJX model.
        initial_state: State the scan starts from.
        control_inputs: Sequence of controls; one step is taken per leading
            entry.

    Returns:
        The stacked states produced after each step. The initial state itself
        is not part of the output.
    """

    def advance(carry, ctrl):
        # The new state is both the next carry and the per-step output.
        nxt = parametric_step(parameters, model, carry, ctrl)
        return nxt, nxt

    _final_state, trajectory = jax.lax.scan(advance, initial_state, control_inputs)
    return trajectory
| 49 | + |
| 50 | + |
# Initialize random key
key = jax.random.PRNGKey(0)

# Load the model
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)
model.opt.integrator = mujoco.mjtIntegrator.mjINT_RK4  # same value as the literal 1

# Setting up constraint solver to ensure differentiability and faster simulations
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON  # same value as the literal 2
model.opt.iterations = 2
model.opt.ls_iterations = 10

# Load test data
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
velocity = data_array[:, 2]
control = data_array[:, 3]

# Use the mean sampling period of the log as the simulation step. The MJX model
# is created only AFTER this assignment: options changed on `model` after
# `mjx.put_model` has run are not reflected in the already-created MJX model.
model.opt.timestep = sampling
mjx_model = mjx.put_model(model)

# Trim the log to a whole number of HORIZON-long intervals (one interval's
# worth of samples is left out by the `- 1`).
HORIZON = 100
N_INTERVALS = len(timespan) // HORIZON - 1
n_samples = N_INTERVALS * HORIZON
timespan = timespan[:n_samples]
angle = angle[:n_samples]
velocity = velocity[:n_samples]
control = control[:n_samples]

# Prepare data for simulation and optimization
initial_state = jnp.array([angle[0], velocity[0]])
true_trajectory = jnp.column_stack((angle, velocity))
control_inputs = jnp.array(control)

# Every HORIZON-th measured state and the controls grouped per interval
interval_true_trajectory = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)

# Get default parameters from the model: converted dynamic parameters of body 1
# followed by the joint damping and friction loss arrays
default_parameters = jnp.concatenate(
    [theta2logchol(get_dynamic_parameters(mjx_model, 1)), mjx_model.dof_damping, mjx_model.dof_frictionloss]
)
| 97 | + |
# //////////////////////////////////////
# SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION

# Vectorize over both initial states and control inputs; parameters and model are shared
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))

# Create a batch of initial states perturbed around the first measured state
key, subkey = jax.random.split(key)
batch_initial_states = jax.random.uniform(subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
# Create a batch of random control input sequences
key, subkey = jax.random.split(key)
batch_control_inputs = jax.random.normal(subkey, (N_INTERVALS, HORIZON)) * 0.1

# Run warm up for batched rollout (this first call also includes compilation).
# JAX dispatches work asynchronously, so wait for the result before stopping
# the timer; otherwise only the dispatch time would be measured.
t1 = perf_counter()
batched_trajectories = batched_rollout(
    default_parameters, mjx_model, batch_initial_states, batch_control_inputs
).block_until_ready()
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

# Run batched rollout on short horizon data from pendulum: each interval starts
# from the measured state at its beginning and replays the logged controls
interval_initial_states = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
t1 = perf_counter()
batched_states_trajectories = batched_rollout(
    default_parameters * 0.7, mjx_model, interval_initial_states, interval_controls
).block_until_ready()
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

# Stitch the per-interval rollouts back into one (samples, 2) array for plotting
batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)

# Plotting simulation results for batched state trajectories
plt.figure(figsize=(10, 5))

plt.subplot(2, 2, 1)
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
plt.ylabel("Angle (rad)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()
plt.title("Pendulum Dynamics - Batched State Trajectories")

plt.subplot(2, 2, 3)
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
plt.xlabel("Time (s)")
plt.ylabel("Velocity (rad/s)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

# Add phase portrait
plt.subplot(1, 2, 2)
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
plt.plot(
    batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
)
plt.xlabel("Angle (rad)")
plt.ylabel("Angular Velocity (rad/s)")
plt.title("Phase Portrait")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

plt.tight_layout()
plt.show()
| 161 | + |
# //////////////////////////////////////////////////
# PARAMETRIC BATCHES
# Create a batch of 200 randomized parameters
num_batches = 200
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)

default_log_cholesky_first = default_parameters[0]
default_damping = default_parameters[-2]
default_dry_friction = default_parameters[-1]


def _scaled_uniform(rng, center, low_scale, high_scale):
    """Draw ``num_batches`` uniform samples between two scaled copies of ``center``."""
    bound_a = center * low_scale
    bound_b = center * high_scale
    # Scaling a negative center flips the order of the two bounds, so sort
    # them to always hand a valid (minval <= maxval) range to the sampler.
    return jax.random.uniform(
        rng, (num_batches,), minval=jnp.minimum(bound_a, bound_b), maxval=jnp.maximum(bound_a, bound_b)
    )


randomized_log_cholesky_first = _scaled_uniform(subkey1, default_log_cholesky_first, 0.8, 1.1)
randomized_damping = _scaled_uniform(subkey2, default_damping, 0.9, 1.5)
randomized_dry_friction = _scaled_uniform(subkey3, default_dry_friction, 0.9, 1.5)

# Create a batch of parameters with randomized first log-Cholesky parameter, damping, and dry frictions
batch_parameters = jnp.tile(default_parameters, (num_batches, 1))
batch_parameters = batch_parameters.at[:, 0].set(randomized_log_cholesky_first)
batch_parameters = batch_parameters.at[:, -2].set(randomized_damping)
batch_parameters = batch_parameters.at[:, -1].set(randomized_dry_friction)


# Define a batched version of rollout_trajectory using vmap (batched over parameters only)
batched_parameters_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))

# Simulation with XML parameters
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)

# Simulate trajectories with randomized parameters using vmap. This first call
# also includes compilation. JAX dispatches asynchronously, so wait for the
# result before stopping the timer.
t1 = perf_counter()
randomized_trajectories = batched_parameters_rollout(
    batch_parameters, mjx_model, initial_state, control_inputs
).block_until_ready()
t2 = perf_counter()

print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")

# Convert once to NumPy so the per-trajectory plotting loops below do not
# slice device arrays one by one.
randomized_np = np.asarray(randomized_trajectories)
xml_np = np.asarray(xml_trajectory)

# Plotting simulation results (XML vs Randomized)
plt.figure(figsize=(10, 5))

plt.subplot(2, 2, 1)
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
for trajectory in randomized_np:
    plt.plot(timespan, trajectory[:, 0], alpha=0.02, color="blue")
plt.plot(timespan, xml_np[:, 0], label="XML Model Angle", color="red", linewidth=2)
plt.ylabel("Angle (rad)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()
plt.title("Pendulum Dynamics - Randomized Parameters")

plt.subplot(2, 2, 3)
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
for trajectory in randomized_np:
    plt.plot(timespan, trajectory[:, 1], alpha=0.02, color="blue")
plt.plot(timespan, xml_np[:, 1], label="XML Model Velocity", color="red", linewidth=2)
plt.xlabel("Time (s)")
plt.ylabel("Velocity (rad/s)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

# Add phase portrait
plt.subplot(1, 2, 2)
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
for trajectory in randomized_np:
    plt.plot(trajectory[:, 0], trajectory[:, 1], alpha=0.02, color="blue")
plt.plot(xml_np[:, 0], xml_np[:, 1], label="XML Model", color="red", linewidth=2)
plt.xlabel("Angle (rad)")
plt.ylabel("Angular Velocity (rad/s)")
plt.title("Phase Portrait")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

plt.tight_layout()
plt.show()


# Second timed run on the same parameter batch. The rollout is already
# compiled at this point, so this measurement excludes compilation time.
# (Re-drawing the batch from the same subkeys would reproduce identical
# values, so the existing `batch_parameters` is reused.)
t1 = perf_counter()
randomized_trajectories = batched_parameters_rollout(
    batch_parameters, mjx_model, initial_state, control_inputs
).block_until_ready()
t2 = perf_counter()
print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
| 268 | + |
| 269 | + |
| 270 | +# TODO: OPTIMIZATION |
| 271 | + |
| 272 | + |
| 273 | +# Optimization |
| 274 | + |
| 275 | +# # Error function |
| 276 | +# def trajectory_error(parameters: jnp.ndarray, model: mjx.Model, initial_state: jnp.ndarray, control_inputs: jnp.ndarray, true_trajectory: jnp.ndarray) -> jnp.ndarray: |
| 277 | +# predicted_trajectory = rollout_trajectory(parameters, model, initial_state, control_inputs) |
| 278 | +# return jnp.mean(jnp.square(predicted_trajectory - true_trajectory)) |
| 279 | + |
| 280 | +# # Optimization |
| 281 | +# @jax.jit |
| 282 | +# def update_step(parameters, opt_state, model, initial_state, control_inputs, true_trajectory): |
| 283 | +# loss, grads = jax.value_and_grad(trajectory_error)(parameters, model, initial_state, control_inputs, true_trajectory) |
| 284 | +# updates, opt_state = optimizer.update(grads, opt_state, parameters) |
| 285 | +# parameters = optax.apply_updates(parameters, updates) |
| 286 | +# return parameters, opt_state, loss |
| 287 | + |
| 288 | +# # Initial parameters for optimization (using randomized parameters as starting point) |
| 289 | +# initial_parameters = randomized_parameters |
| 290 | + |
| 291 | +# # Optimization setup |
| 292 | +# learning_rate = 1e-3 |
| 293 | +# optimizer = optax.adam(learning_rate) |
| 294 | +# opt_state = optimizer.init(initial_parameters) |
| 295 | + |
| 296 | +# # Optimization loop |
| 297 | +# num_iterations = 1000 |
| 298 | +# for i in range(num_iterations): |
| 299 | +# initial_parameters, opt_state, loss = update_step(initial_parameters, opt_state, mjx_model, initial_state, control_inputs, true_trajectory) |
| 300 | +# if i % 100 == 0: |
| 301 | +# print(f"Iteration {i}, Loss: {loss}") |
| 302 | + |
| 303 | +# # Final simulation with learned parameters |
| 304 | +# final_trajectory = rollout_trajectory(initial_parameters, mjx_model, initial_state, control_inputs) |
| 305 | + |
| 306 | +# # Plotting optimization results |
| 307 | +# plt.figure(figsize=(12, 9)) |
| 308 | + |
| 309 | +# plt.subplot(2, 1, 1) |
| 310 | +# plt.plot(timespan, angle, label='Actual Angle') |
| 311 | +# plt.plot(timespan, xml_trajectory[:, 0], label='XML Model Angle', linestyle='dashed') |
| 312 | +# plt.plot(timespan, randomized_trajectory[:, 0], label='Initial Randomized Angle', linestyle='dotted') |
| 313 | +# plt.plot(timespan, final_trajectory[:, 0], label='Learned Model Angle', linestyle='dashdot') |
| 314 | +# plt.ylabel('Angle (rad)') |
| 315 | +# plt.legend() |
| 316 | +# plt.title('Pendulum Dynamics Comparison - Optimization') |
| 317 | + |
| 318 | +# plt.subplot(2, 1, 2) |
| 319 | +# plt.plot(timespan, velocity, label='Actual Velocity') |
| 320 | +# plt.plot(timespan, xml_trajectory[:, 1], label='XML Model Velocity', linestyle='dashed') |
| 321 | +# plt.plot(timespan, randomized_trajectory[:, 1], label='Initial Randomized Velocity', linestyle='dotted') |
| 322 | +# plt.plot(timespan, final_trajectory[:, 1], label='Learned Model Velocity', linestyle='dashdot') |
| 323 | +# plt.xlabel('Time (s)') |
| 324 | +# plt.ylabel('Velocity (rad/s)') |
| 325 | +# plt.legend() |
| 326 | + |
| 327 | +# plt.tight_layout() |
| 328 | +# plt.show() |
| 329 | + |
| 330 | +# # Print learned parameters |
| 331 | +# learned_log_cholesky, learned_damping, learned_friction_loss = jnp.split(initial_parameters, [10, 11]) |
| 332 | +# learned_theta = logchol2theta(learned_log_cholesky) |
| 333 | + |
| 334 | +# print("\nParameters comparison:") |
| 335 | +# print("Parameter\t\tXML\t\tRandomized\tLearned") |
| 336 | +# print(f"Inertia:\t\t{get_dynamic_parameters(mjx_model, 1)[0]:.6f}\t{logchol2theta(randomized_parameters[:10])[0]:.6f}\t{learned_theta[0]:.6f}") |
| 337 | +# print(f"Damping:\t\t{mjx_model.dof_damping[0]:.6f}\t{randomized_parameters[10]:.6f}\t{learned_damping[0]:.6f}") |
| 338 | +# print(f"Friction loss:\t{mjx_model.dof_frictionloss[0]:.6f}\t{randomized_parameters[11]:.6f}\t{learned_friction_loss[0]:.6f}") |
| 339 | + |
| 340 | +# TODO: Save the learned parameters to a new XML file |
| 341 | +# This would require additional code to modify the XML file with the learned parameters |
0 commit comments