Skip to content

Commit d6c3714

Browse files
committed
add bathed short simulation
1 parent c8be148 commit d6c3714

File tree

3 files changed

+124
-39
lines changed

3 files changed

+124
-39
lines changed

examples/opti_loss.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@
281281
}
282282
],
283283
"source": [
284-
"theta2logchol(true_parameters)\n"
284+
"theta2logchol(true_parameters)"
285285
]
286286
},
287287
{
@@ -413,7 +413,7 @@
413413
"plt.title(\"Losses over time\")\n",
414414
"plt.fill_between(jnp.arange(horizon, len(log)), results[1].min(axis=0), results[1].max(axis=0), alpha=0.5)\n",
415415
"plt.plot(jnp.arange(horizon, len(log)), results[1].mean(axis=0))\n",
416-
"plt.show()\n"
416+
"plt.show()"
417417
]
418418
},
419419
{
@@ -469,7 +469,7 @@
469469
"for i in range(parameters_history.shape[0]):\n",
470470
" history.append(vmap_logchol2theta(parameters_history[i]))\n",
471471
"\n",
472-
"history = jnp.array(history)\n"
472+
"history = jnp.array(history)"
473473
]
474474
},
475475
{
@@ -523,7 +523,7 @@
523523
" # plt.fill_between(jnp.arange(horizon, len(log)), min_parameter, max_parameter, alpha=0.5)\n",
524524
" plt.plot(jnp.arange(horizon, len(log)), current_parameter.mean(axis=0))\n",
525525
" plt.plot(jnp.arange(horizon, len(log)), jnp.repeat(true_parameters[i], len(log) - horizon), label=\"True\")\n",
526-
" plt.legend()\n"
526+
" plt.legend()"
527527
]
528528
},
529529
{

examples/pendulum.py

Lines changed: 112 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,6 @@
1111
import optax
1212

1313

14-
# Initialize random key
15-
key = jax.random.PRNGKey(0)
16-
17-
# Load the model
18-
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
19-
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
20-
data = mujoco.MjData(model)
21-
model.opt.integrator = 1
22-
23-
# Setting up constraint solver to ensure differentiability and faster simulations
24-
model.opt.solver = 2 # 2 corresponds to Newton solver
25-
model.opt.iterations = 2
26-
model.opt.ls_iterations = 10
27-
28-
mjx_model = mjx.put_model(model)
29-
30-
# Load test data
31-
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
32-
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
33-
timespan = data_array[:, 0] - data_array[0, 0]
34-
sampling = np.mean(np.diff(timespan))
35-
angle = data_array[:, 1]
36-
velocity = data_array[:, 2]
37-
control = data_array[:, 3]
38-
39-
model.opt.timestep = sampling
40-
41-
4214
@jax.jit
4315
def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
4416
"""Map new parameters to the model."""
@@ -76,19 +48,119 @@ def step_fn(state, control):
7648
return states
7749

7850

51+
# Initialize random key
52+
key = jax.random.PRNGKey(0)
53+
54+
# Load the model
55+
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
56+
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
57+
data = mujoco.MjData(model)
58+
model.opt.integrator = 1
59+
60+
# Setting up constraint solver to ensure differentiability and faster simulations
61+
model.opt.solver = 2 # 2 corresponds to Newton solver
62+
model.opt.iterations = 2
63+
model.opt.ls_iterations = 10
64+
65+
mjx_model = mjx.put_model(model)
66+
67+
# Load test data
68+
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
69+
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
70+
timespan = data_array[:, 0] - data_array[0, 0]
71+
sampling = np.mean(np.diff(timespan))
72+
angle = data_array[:, 1]
73+
velocity = data_array[:, 2]
74+
control = data_array[:, 3]
75+
76+
model.opt.timestep = sampling
77+
78+
HORIZON = 100
79+
N_INTERVALS = len(timespan) // HORIZON - 1
80+
timespan = timespan[: N_INTERVALS * HORIZON]
81+
angle = angle[: N_INTERVALS * HORIZON]
82+
velocity = velocity[: N_INTERVALS * HORIZON]
83+
control = control[: N_INTERVALS * HORIZON]
84+
7985
# Prepare data for simulation and optimization
8086
initial_state = jnp.array([angle[0], velocity[0]])
8187
true_trajectory = jnp.column_stack((angle, velocity))
8288
control_inputs = jnp.array(control)
8389

90+
interval_true_trajectory = true_trajectory[::HORIZON]
91+
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
92+
8493
# Get default parameters from the model
8594
default_parameters = jnp.concatenate(
8695
[theta2logchol(get_dynamic_parameters(mjx_model, 1)), mjx_model.dof_damping, mjx_model.dof_frictionloss]
8796
)
8897

89-
# Simulation with XML parameters
90-
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)
98+
# //////////////////////////////////////
99+
# SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION
100+
101+
# Vectorize over both initial states and control inputs
102+
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))
103+
104+
# Create a batch of initial states
105+
key, subkey = jax.random.split(key)
106+
batch_initial_states = jax.random.uniform(subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
107+
# Create a batch of control input sequences
108+
key, subkey = jax.random.split(key)
109+
batch_control_inputs = jax.random.normal(subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs
110+
# Run warm up for batched rollout
111+
t1 = perf_counter()
112+
batched_trajectories = batched_rollout(default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
113+
t2 = perf_counter()
114+
print(f"Batch simulation time: {t2 - t1} seconds")
115+
116+
# Run batched rollout on shor horizon data from pendulum
117+
interval_initial_states = true_trajectory[::HORIZON]
118+
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
119+
t1 = perf_counter()
120+
batched_states_trajectories = batched_rollout(
121+
default_parameters * 0.7, mjx_model, interval_initial_states, interval_controls
122+
)
123+
t2 = perf_counter()
124+
print(f"Batch simulation time: {t2 - t1} seconds")
125+
126+
batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
91127

128+
# Plotting simulation results for batсhed state trajectories
129+
plt.figure(figsize=(10, 5))
130+
131+
plt.subplot(2, 2, 1)
132+
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
133+
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
134+
plt.ylabel("Angle (rad)")
135+
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
136+
plt.legend()
137+
plt.title("Pendulum Dynamics - Bathed State Trajectories")
138+
139+
plt.subplot(2, 2, 3)
140+
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
141+
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
142+
plt.xlabel("Time (s)")
143+
plt.ylabel("Velocity (rad/s)")
144+
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
145+
plt.legend()
146+
147+
# Add phase portrait
148+
plt.subplot(1, 2, 2)
149+
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
150+
plt.plot(
151+
batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
152+
)
153+
plt.xlabel("Angle (rad)")
154+
plt.ylabel("Angular Velocity (rad/s)")
155+
plt.title("Phase Portrait")
156+
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
157+
plt.legend()
158+
159+
plt.tight_layout()
160+
plt.show()
161+
162+
# //////////////////////////////////////////////////
163+
# PARAMETRIC BATCHES
92164
# Create a batch of 200 randomized parameters
93165
num_batches = 200
94166
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
@@ -115,13 +187,16 @@ def step_fn(state, control):
115187
batch_parameters = batch_parameters.at[:, -2].set(randomized_damping)
116188
batch_parameters = batch_parameters.at[:, -1].set(randomized_dry_friction)
117189

190+
118191
# Define a batched version of rollout_trajectory using vmap
119-
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))
192+
batched_parameters_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))
120193

194+
# Simulation with XML parameters
195+
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)
121196

122197
# Simulate trajectories with randomized parameters using vmap
123198
t1 = perf_counter()
124-
randomized_trajectories = batched_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
199+
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
125200
t2 = perf_counter()
126201

127202
print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
@@ -187,10 +262,14 @@ def step_fn(state, control):
187262

188263
# Simulate trajectories with randomized parameters using vmap
189264
t1 = perf_counter()
190-
randomized_trajectories = batched_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
265+
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
191266
t2 = perf_counter()
192267
print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
193268

269+
270+
# TODO: OPTIMIZATION
271+
272+
194273
# Optimization
195274

196275
# # Error function

mujoco_sysid/mjx/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,13 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
126126
return model
127127

128128

129+
# @jax.jit(static_argnames=['parameters_map'])
129130
def parametric_step(
130-
parameters: jnp.ndarray, parameters_map: Callable, model: mjx.Model, state: jnp.ndarray, control: jnp.ndarray
131+
parameters: jnp.ndarray,
132+
model: mjx.Model,
133+
state: jnp.ndarray,
134+
control: jnp.ndarray,
135+
parameters_map: Callable,
131136
) -> jnp.ndarray:
132137
"""
133138
Perform a step with new parameter mapping.
@@ -150,12 +155,13 @@ def parametric_step(
150155
return jnp.concatenate([data.qpos, data.qvel])
151156

152157

158+
# @jax.jit(static_argnames=['parameters_map'])
153159
def rollout_trajectory(
154160
parameters: jnp.ndarray,
155-
parameters_map: Callable,
156161
model: mjx.Model,
157162
initial_state: jnp.ndarray,
158163
control_inputs: jnp.ndarray,
164+
parameters_map: Callable,
159165
) -> jnp.ndarray:
160166
"""
161167
Rollout a trajectory given parameters, initial state, and control inputs.

0 commit comments

Comments
 (0)