Skip to content

Commit c9bfb1c

Browse files
committed
added learning example (WIP)
1 parent ebefda8 commit c9bfb1c

File tree

5 files changed

+247
-21
lines changed

5 files changed

+247
-21
lines changed

data/ltv_lqr_traj.json

-1
This file was deleted.

data/trajectories/skydio.npz

-6.59 KB
Binary file not shown.

data/trajectories/skydio_easy.npz

-6.59 KB
Binary file not shown.

examples/pendulum.py renamed to examples/pendulum/01_simulation_error.py

+44-20
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import mujoco
77
import numpy as np
88
from mujoco import mjx
9-
109
from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
1110
from mujoco_sysid.mjx.model import create_rollout
1211
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters
1312
import os
13+
import optax
14+
from mujoco.mjx._src.types import IntegratorType
1415

1516
# SHOULD WE MOVE THIS INTO MODULE INIT?
1617
xla_flags = os.environ.get("XLA_FLAGS", "")
@@ -33,27 +34,28 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
3334

3435

3536
rollout_trajectory = jax.jit(create_rollout(parameters_map))
36-
37+
3738

3839
# Initialize random key
3940
key = jax.random.PRNGKey(0)
4041

4142
# Load the model
42-
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
43+
MJCF_PATH = "../../data/models/pendulum/pendulum.xml"
4344
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
4445
data = mujoco.MjData(model)
45-
model.opt.integrator = 1
46+
model.opt.integrator = IntegratorType.EULER
4647

4748
# Setting up constraint solver to ensure differentiability and faster simulations
4849
model.opt.solver = 2 # 2 corresponds to Newton solver
49-
model.opt.iterations = 2
50+
model.opt.iterations = 1
5051
model.opt.ls_iterations = 10
5152

5253
mjx_model = mjx.put_model(model)
5354

5455
# Load test data
55-
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
56-
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
56+
TEST_DATA_PATH = "../../data/trajectories/pendulum/free_fall_2.csv"
57+
data_array = np.genfromtxt(
58+
TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
5759
timespan = data_array[:, 0] - data_array[0, 0]
5860
sampling = np.mean(np.diff(timespan))
5961
angle = data_array[:, 1]
@@ -79,62 +81,83 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
7981

8082
# Get default parameters from the model
8183
default_parameters = jnp.concatenate(
82-
[theta2logchol(get_dynamic_parameters(mjx_model, 1)), mjx_model.dof_damping, mjx_model.dof_frictionloss]
84+
[theta2logchol(get_dynamic_parameters(mjx_model, 1)),
85+
mjx_model.dof_damping, mjx_model.dof_frictionloss]
8386
)
8487

8588
# //////////////////////////////////////
8689
# SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION
8790

8891
# Vectorize over both initial states and control inputs
89-
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))
92+
batched_rollout = jax.jit(
93+
jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))
9094

9195
# Create a batch of initial states
9296
key, subkey = jax.random.split(key)
93-
batch_initial_states = jax.random.uniform(subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
97+
batch_initial_states = jax.random.uniform(
98+
subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
9499
# Create a batch of control input sequences
95100
key, subkey = jax.random.split(key)
96-
batch_control_inputs = jax.random.normal(subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs
101+
batch_control_inputs = jax.random.normal(
102+
subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs
97103
# Run warm up for batched rollout
98104
t1 = perf_counter()
99-
batched_trajectories = batched_rollout(default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
105+
batched_trajectories = batched_rollout(
106+
default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
100107
t2 = perf_counter()
101108
print(f"Batch simulation time: {t2 - t1} seconds")
102109

103110
# Run batched rollout on short-horizon data from pendulum
104111
interval_initial_states = true_trajectory[::HORIZON]
112+
interval_terminal_states = true_trajectory[HORIZON+1:][::HORIZON]
105113
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
106114
t1 = perf_counter()
107-
batched_states_trajectories = batched_rollout(default_parameters, mjx_model, interval_initial_states, interval_controls)
115+
batched_states_trajectories = batched_rollout(
116+
default_parameters*0.7, mjx_model, interval_initial_states, interval_controls)
108117
t2 = perf_counter()
109118
print(f"Batch simulation time: {t2 - t1} seconds")
110119

111-
batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
112-
120+
predicted_terminal_points = np.array(batched_states_trajectories)[:,-1,:]
121+
batched_states_trajectories = np.array(
122+
batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
113123
# Plotting simulation results for batched state trajectories
114124
plt.figure(figsize=(10, 5))
115125

116126
plt.subplot(2, 2, 1)
117-
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
118-
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
127+
plt.plot(timespan, angle, label="Actual Angle",
128+
color="black", linestyle="dashed", linewidth=2)
129+
plt.plot(timespan, batched_states_trajectories[:, 0],
130+
alpha=0.5, color="blue", label="Simulated Angle")
131+
plt.plot(timespan, angle, label="Actual Angle",
132+
color="black", linestyle="dashed", linewidth=2)
133+
plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,0], 'ob')
134+
plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 0], 'or')
119135
plt.ylabel("Angle (rad)")
120136
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
121137
plt.legend()
122138
plt.title("Pendulum Dynamics - Bathed State Trajectories")
123139

124140
plt.subplot(2, 2, 3)
125-
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
126-
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
141+
plt.plot(timespan, velocity, label="Actual Velocity",
142+
color="black", linestyle="dashed", linewidth=2)
143+
plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,1], 'ob')
144+
plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 1], 'or')
145+
plt.plot(timespan, batched_states_trajectories[:, 1],
146+
alpha=0.5, color="blue", label="Simulated Velocity")
127147
plt.xlabel("Time (s)")
128148
plt.ylabel("Velocity (rad/s)")
129149
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
130150
plt.legend()
131151

132152
# Add phase portrait
133153
plt.subplot(1, 2, 2)
134-
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
154+
plt.plot(angle, velocity, label="Actual",
155+
color="black", linestyle="dashed", linewidth=2)
135156
plt.plot(
136157
batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
137158
)
159+
plt.plot(predicted_terminal_points[:-1,0], predicted_terminal_points[:-1,1], 'ob')
160+
plt.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], 'or')
138161
plt.xlabel("Angle (rad)")
139162
plt.ylabel("Angular Velocity (rad/s)")
140163
plt.title("Phase Portrait")
@@ -145,3 +168,4 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
145168
plt.show()
146169
# TODO:
147170
# Optimization
171+

examples/pendulum/02_learning.py

+203
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from time import perf_counter
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import matplotlib.pyplot as plt
6+
import mujoco
7+
import numpy as np
8+
from mujoco import mjx
9+
from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
10+
from mujoco_sysid.mjx.model import create_rollout
11+
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters
12+
import os
13+
import optax
14+
from mujoco.mjx._src.types import IntegratorType
15+
16+
# SHOULD WE MOVE THIS INTO MODULE INIT?
17+
xla_flags = os.environ.get("XLA_FLAGS", "")
18+
xla_flags += " --xla_gpu_triton_gemm_any=True"
19+
os.environ["XLA_FLAGS"] = xla_flags
20+
21+
22+
@jax.jit
def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
    """Write a flat parameter vector into a copy of the model.

    The vector is split at indices 10 and 11:

    * entries 0-9 are converted with ``logchol2theta`` and passed to
      ``set_dynamic_parameters``;
    * entry 10 is the logarithm of the damping written to ``dof_damping[0]``;
    * the remaining entries start with the logarithm of the friction loss
      written to ``dof_frictionloss[0]``.

    Args:
        parameters: Flat parameter vector laid out as described above.
        model: MJX model used as the template for the update.

    Returns:
        The model returned by ``tree_replace`` with the new inertial
        parameters, damping and friction loss.
    """
    # NOTE(review): the second argument of set_dynamic_parameters is 0 here,
    # while default_parameters is read with get_dynamic_parameters(mjx_model, 1)
    # - presumably both are body indices; confirm they address the same body.
    inertia_part, damping_part, friction_part = jnp.split(parameters, [10, 11])
    updated_model = set_dynamic_parameters(model, 0, logchol2theta(inertia_part))

    # Damping and friction loss are stored as logarithms, so exponentiate
    # them before writing them back.
    replacements = {
        "dof_damping": updated_model.dof_damping.at[0].set(jnp.exp(damping_part[0])),
        "dof_frictionloss": updated_model.dof_frictionloss.at[0].set(jnp.exp(friction_part[0])),
    }
    return updated_model.tree_replace(replacements)
36+
37+
38+
# Jitted single-trajectory rollout built from the parameter mapping above.
rollout_trajectory = jax.jit(create_rollout(parameters_map))

# Initialize random key
key = jax.random.PRNGKey(0)

# Load the model
MJCF_PATH = "../../data/models/pendulum/pendulum.xml"
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)
model.opt.integrator = IntegratorType.EULER

# Setting up constraint solver to ensure differentiability and faster simulations
model.opt.solver = 2  # 2 corresponds to Newton solver
model.opt.iterations = 1
model.opt.ls_iterations = 10

# Load test data
TEST_DATA_PATH = "../../data/trajectories/pendulum/free_fall_2.csv"
data_array = np.genfromtxt(
    TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]  # time relative to the first kept sample
sampling = np.mean(np.diff(timespan))  # mean spacing between consecutive samples
angle = data_array[:, 1]
velocity = data_array[:, 2]
control = data_array[:, 3]

# Match the simulation step to the measured sampling period. This has to
# happen BEFORE mjx.put_model: the MJX model is built from `model` at that
# call, so a timestep assigned afterwards never reaches the rollouts.
model.opt.timestep = sampling

mjx_model = mjx.put_model(model)

# Cut the recording into N_INTERVALS windows of HORIZON samples each and drop
# the samples that do not fit.
HORIZON = 10
N_INTERVALS = len(timespan) // HORIZON - 1
timespan = timespan[: N_INTERVALS * HORIZON]
angle = angle[: N_INTERVALS * HORIZON]
velocity = velocity[: N_INTERVALS * HORIZON]
control = control[: N_INTERVALS * HORIZON]

# Prepare data for simulation and optimization
initial_state = jnp.array([angle[0], velocity[0]])
true_trajectory = jnp.column_stack((angle, velocity))
control_inputs = jnp.array(control)

# First sample of every window and the control sequence applied within it.
interval_true_trajectory = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)

# Get default parameters from the model
# NOTE(review): jnp.log yields -inf for a zero dof_damping / dof_frictionloss
# entry - confirm the MJCF sets both to strictly positive values.
default_parameters = jnp.concatenate(
    [
        theta2logchol(get_dynamic_parameters(mjx_model, 1)),
        jnp.log(mjx_model.dof_damping),
        jnp.log(mjx_model.dof_frictionloss),
    ]
)
89+
90+
# print()
91+
@jax.jit
def rollout_errors(parameters, states, controls):
    """Return the scalar training loss for a parameter vector.

    ``states`` is sampled every ``HORIZON`` rows to obtain one start state per
    window, ``controls`` is reshaped to ``(N_INTERVALS, HORIZON)``, and
    ``rollout_trajectory`` is vmapped over the windows with ``parameters`` and
    the module-level ``mjx_model`` shared by all of them. The last state of
    each rollout (except the final window) is compared with the recorded
    states, and a Huber penalty on the parameters is added with weight 0.05.

    Args:
        parameters: Flat parameter vector passed on to ``rollout_trajectory``.
        states: Recorded states, one row per sample.
        controls: Recorded controls, ``N_INTERVALS * HORIZON`` values in total.

    Returns:
        Mean ``optax.l2_loss`` of the terminal-state mismatch plus
        ``0.05`` times the mean ``optax.huber_loss`` of the parameters
        against zero.
    """
    window_starts = states[::HORIZON]
    # NOTE(review): targets start at row HORIZON + 1 (rows 11, 21, ... for
    # HORIZON = 10); whether this lines up with the time of the last rollout
    # state depends on create_rollout - confirm there is no off-by-one.
    window_targets = states[HORIZON + 1::HORIZON]
    window_controls = jnp.reshape(controls, (N_INTERVALS, HORIZON))

    # Parameters and model are broadcast; start states and controls are batched.
    simulate_windows = jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0))
    simulated = simulate_windows(parameters, mjx_model, window_starts, window_controls)
    final_states = simulated[:, -1, :]

    # The last window has no recorded target, hence final_states[:-1].
    terminal_error = jnp.mean(optax.l2_loss(final_states[:-1], window_targets))
    regularizer = jnp.mean(optax.huber_loss(parameters, jnp.zeros_like(parameters)))
    return terminal_error + 0.05 * regularizer
101+
102+
start_learning_rate = 1e-3
103+
optimizer = optax.adam(learning_rate = start_learning_rate)
104+
105+
106+
# Initialize parameters of the model + optimizer.
107+
params = jnp.array(0.5*default_parameters)
108+
opt_state = optimizer.init(params)
109+
val_and_grad = jax.jit(jax.value_and_grad(rollout_errors))
110+
loss_val, loss_grad = val_and_grad(params, true_trajectory, control_inputs)
111+
112+
# A simple update loop.
113+
for _ in range(100):
114+
loss_val, loss_grad = val_and_grad(params, true_trajectory, control_inputs)
115+
updates, opt_state = optimizer.update(loss_grad, opt_state)
116+
params = optax.apply_updates(params, updates)
117+
print(loss_val, params)
118+
119+
# assert jnp.allclose(params, target_params), \
120+
# 'Optimization should retrieve the target params used to generate the data.'
121+
# # //////////////////////////////////////
122+
# # SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION
123+
124+
# # Vectorize over both initial states and control inputs
125+
126+
127+
# # Create a batch of initial states
128+
# key, subkey = jax.random.split(key)
129+
# batch_initial_states = jax.random.uniform(
130+
# subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
131+
# # Create a batch of control input sequences
132+
# key, subkey = jax.random.split(key)
133+
# batch_control_inputs = jax.random.normal(
134+
# subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs
135+
# # Run warm up for batched rollout
136+
# t1 = perf_counter()
137+
# batched_trajectories = batched_rollout(
138+
# default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
139+
# t2 = perf_counter()
140+
# print(f"Batch simulation time: {t2 - t1} seconds")
141+
142+
# # Run batched rollout on short-horizon data from pendulum
143+
# interval_initial_states = true_trajectory[::HORIZON]
144+
# interval_terminal_states = true_trajectory[HORIZON+1:][::HORIZON]
145+
# interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
146+
# batched_states_trajectories = batched_rollout(
147+
# default_parameters*0.999999, mjx_model, interval_initial_states, interval_controls)
148+
# t1 = perf_counter()
149+
# t2 = perf_counter()
150+
# print(f"Batch simulation time: {t2 - t1} seconds")
151+
152+
# predicted_terminal_points = np.array(batched_states_trajectories)[:,-1,:]
153+
# batched_states_trajectories = np.array(
154+
# batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
155+
# # Plotting simulation results for batched state trajectories
156+
# plt.figure(figsize=(10, 5))
157+
158+
# plt.subplot(2, 2, 1)
159+
# plt.plot(timespan, angle, label="Actual Angle",
160+
# color="black", linestyle="dashed", linewidth=2)
161+
# plt.plot(timespan, batched_states_trajectories[:, 0],
162+
# alpha=0.5, color="blue", label="Simulated Angle")
163+
# plt.plot(timespan, angle, label="Actual Angle",
164+
# color="black", linestyle="dashed", linewidth=2)
165+
# plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,0], 'ob')
166+
# plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 0], 'or')
167+
# plt.ylabel("Angle (rad)")
168+
# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
169+
# plt.legend()
170+
# plt.title("Pendulum Dynamics - Bathed State Trajectories")
171+
172+
# plt.subplot(2, 2, 3)
173+
# plt.plot(timespan, velocity, label="Actual Velocity",
174+
# color="black", linestyle="dashed", linewidth=2)
175+
# plt.plot(timespan[HORIZON+1:][::HORIZON], predicted_terminal_points[:-1,1], 'ob')
176+
# plt.plot(timespan[HORIZON+1:][::HORIZON], interval_terminal_states[:, 1], 'or')
177+
# plt.plot(timespan, batched_states_trajectories[:, 1],
178+
# alpha=0.5, color="blue", label="Simulated Velocity")
179+
# plt.xlabel("Time (s)")
180+
# plt.ylabel("Velocity (rad/s)")
181+
# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
182+
# plt.legend()
183+
184+
# # Add phase portrait
185+
# plt.subplot(1, 2, 2)
186+
# plt.plot(angle, velocity, label="Actual",
187+
# color="black", linestyle="dashed", linewidth=2)
188+
# plt.plot(
189+
# batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
190+
# )
191+
# plt.plot(predicted_terminal_points[:-1,0], predicted_terminal_points[:-1,1], 'ob')
192+
# plt.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], 'or')
193+
# plt.xlabel("Angle (rad)")
194+
# plt.ylabel("Angular Velocity (rad/s)")
195+
# plt.title("Phase Portrait")
196+
# plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
197+
# plt.legend()
198+
199+
# plt.tight_layout()
200+
# plt.show()
201+
# # TODO:
202+
# # Optimization
203+

0 commit comments

Comments
 (0)