Skip to content

Commit ebefda8

Browse files
committed
minor restructure and docstrings
1 parent 35d6ab0 commit ebefda8

16 files changed

+676
-335
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

_draft/pendulum.py

+341
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import jax.typing as jpt
4+
import mujoco
5+
from mujoco import mjx
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
9+
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters
10+
from time import perf_counter
11+
import optax
12+
13+
14+
@jax.jit
def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
    """Return ``model`` updated from a flat parameter vector.

    The vector is split at indices 10 and 11: the first ten entries are handed
    to ``logchol2theta``, the next entry is the damping value, and whatever
    remains is the friction-loss part (only its first element is used).

    Args:
        parameters: Flat parameter vector laid out as described above.
        model: MJX model whose parameters are replaced.

    Returns:
        The model after ``set_dynamic_parameters`` was applied with index 1
        (presumably the body id -- confirm against ``mujoco_sysid``) and with
        element 0 of ``dof_damping`` and ``dof_frictionloss`` overwritten.
    """
    chol_part, damping_part, friction_part = jnp.split(parameters, [10, 11])

    # Convert the log-Cholesky block before writing it into the model.
    theta = logchol2theta(chol_part)
    updated = set_dynamic_parameters(model, 1, theta)

    new_damping = updated.dof_damping.at[0].set(damping_part[0])
    new_frictionloss = updated.dof_frictionloss.at[0].set(friction_part[0])
    return updated.tree_replace({"dof_damping": new_damping, "dof_frictionloss": new_frictionloss})
26+
27+
28+
@jax.jit
def parametric_step(parameters: jnp.ndarray, model: mjx.Model, state: jnp.ndarray, control: jnp.ndarray) -> jnp.ndarray:
    """Advance ``state`` by one ``mjx.step`` using a model built from ``parameters``.

    Args:
        parameters: Flat parameter vector forwarded to ``parameters_map``.
        model: MJX model that the parameters are mapped onto.
        state: Concatenated state; the first ``nq`` entries are used as
            ``qpos`` and the rest as ``qvel``.
        control: Value written to ``ctrl`` before stepping.

    Returns:
        ``qpos`` and ``qvel`` after the step, concatenated into one vector.
    """
    updated_model = parameters_map(parameters, model)
    nq = updated_model.nq
    qpos, qvel = state[:nq], state[nq:]

    # Fresh data is created on every call; only qpos, qvel and ctrl are overridden.
    sim_data = mjx.make_data(updated_model)
    sim_data = sim_data.replace(qpos=qpos, qvel=qvel, ctrl=control)

    stepped = mjx.step(updated_model, sim_data)
    return jnp.concatenate([stepped.qpos, stepped.qvel])
35+
36+
37+
@jax.jit
def rollout_trajectory(
    parameters: jnp.ndarray, model: mjx.Model, initial_state: jnp.ndarray, control_inputs: jnp.ndarray
) -> jnp.ndarray:
    """Simulate a trajectory by scanning ``parametric_step`` over the controls.

    Args:
        parameters: Flat parameter vector forwarded to ``parametric_step``.
        model: MJX model forwarded to ``parametric_step``.
        initial_state: State the scan starts from.
        control_inputs: Control sequence; one step is taken per leading entry.

    Returns:
        The stacked states produced after each step (the initial state itself
        is not part of the output).
    """

    def advance(carry, ctrl):
        nxt = parametric_step(parameters, model, carry, ctrl)
        # scan expects (next carry, per-step output); both are the new state.
        return nxt, nxt

    _final_state, trajectory = jax.lax.scan(advance, initial_state, control_inputs)
    return trajectory
49+
50+
51+
# Initialize random key
key = jax.random.PRNGKey(0)

# Load the model
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)
model.opt.integrator = 1  # 1 corresponds to the RK4 integrator (mjINT_RK4)

# Setting up constraint solver to ensure differentiability and faster simulations
model.opt.solver = 2  # 2 corresponds to Newton solver
model.opt.iterations = 2
model.opt.ls_iterations = 10

# Load test data (the first 100 and the last 2500 rows of the file are skipped)
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
# Column 0 is time (shifted to start at zero); columns 1-3 are angle, velocity and control
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
velocity = data_array[:, 2]
control = data_array[:, 3]

# Use the mean sampling period of the recording as the simulation timestep.
# This has to be set BEFORE the model is converted with mjx.put_model:
# the MJX model is created from the MuJoCo model at conversion time, so a
# timestep assigned afterwards would never reach the MJX simulation.
model.opt.timestep = sampling

mjx_model = mjx.put_model(model)

# Cut the recording into N_INTERVALS windows of HORIZON samples each.
# The "- 1" drops one more full window in addition to the trailing partial one.
HORIZON = 100
N_INTERVALS = len(timespan) // HORIZON - 1
timespan = timespan[: N_INTERVALS * HORIZON]
angle = angle[: N_INTERVALS * HORIZON]
velocity = velocity[: N_INTERVALS * HORIZON]
control = control[: N_INTERVALS * HORIZON]

# Prepare data for simulation and optimization
initial_state = jnp.array([angle[0], velocity[0]])
true_trajectory = jnp.column_stack((angle, velocity))
control_inputs = jnp.array(control)

# Measured state at the start of every window and the controls grouped per window
interval_true_trajectory = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)

# Get default parameters from the model:
# log-Cholesky inertial parameters of index 1, followed by damping and friction loss
default_parameters = jnp.concatenate(
    [theta2logchol(get_dynamic_parameters(mjx_model, 1)), mjx_model.dof_damping, mjx_model.dof_frictionloss]
)
97+
98+
# //////////////////////////////////////
# SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION

# Vectorize over both initial states and control inputs
# (parameters and model are shared across the batch)
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))

# Create a batch of initial states: uniform noise in [-0.1, 0.1) around the measured initial state
key, subkey = jax.random.split(key)
batch_initial_states = jax.random.uniform(subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
# Create a batch of control input sequences
key, subkey = jax.random.split(key)
batch_control_inputs = jax.random.normal(subkey, (N_INTERVALS, HORIZON)) * 0.1  # + control_inputs
# Run warm up for batched rollout.
# JAX dispatches work asynchronously, so block until the result is ready;
# otherwise the timer would stop before the computation has finished.
t1 = perf_counter()
batched_trajectories = batched_rollout(default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
batched_trajectories.block_until_ready()
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

# Run batched rollout on short horizon data from pendulum:
# every window starts from the measured state and replays the measured controls
interval_initial_states = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
t1 = perf_counter()
# All default parameters are scaled by 0.7 for this run
batched_states_trajectories = batched_rollout(
    default_parameters * 0.7, mjx_model, interval_initial_states, interval_controls
)
batched_states_trajectories.block_until_ready()
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

# Stitch the per-window rollouts back into one (N_INTERVALS * HORIZON, 2) array for plotting
batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
127+
128+
# Plotting simulation results for batched state trajectories
plt.figure(figsize=(10, 5))

# Top-left panel: measured vs simulated angle over time
plt.subplot(2, 2, 1)
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
plt.ylabel("Angle (rad)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()
plt.title("Pendulum Dynamics - Batched State Trajectories")

# Bottom-left panel: measured vs simulated velocity over time
plt.subplot(2, 2, 3)
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
plt.xlabel("Time (s)")
plt.ylabel("Velocity (rad/s)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

# Add phase portrait (right half of the figure): velocity against angle
plt.subplot(1, 2, 2)
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
plt.plot(
    batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
)
plt.xlabel("Angle (rad)")
plt.ylabel("Angular Velocity (rad/s)")
plt.title("Phase Portrait")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

plt.tight_layout()
plt.show()
161+
162+
# //////////////////////////////////////////////////
# PARAMETRIC BATCHES


def _uniform_around(rng_key, nominal, low_scale, high_scale, count):
    """Draw ``count`` uniform samples between ``nominal * low_scale`` and ``nominal * high_scale``.

    The two scaled bounds are sorted before sampling: ``jax.random.uniform``
    expects ``minval`` to be the lower bound, and for a negative ``nominal``
    the products would otherwise come out in reversed order.
    """
    bound_a = nominal * low_scale
    bound_b = nominal * high_scale
    return jax.random.uniform(
        rng_key, (count,), minval=jnp.minimum(bound_a, bound_b), maxval=jnp.maximum(bound_a, bound_b)
    )


# Create a batch of 200 randomized parameters
num_batches = 200
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)

# Nominal values: first log-Cholesky entry and the last two entries (damping, friction loss)
default_log_cholesky_first = default_parameters[0]
default_damping = default_parameters[-2]
default_dry_friction = default_parameters[-1]

randomized_log_cholesky_first = _uniform_around(subkey1, default_log_cholesky_first, 0.8, 1.1, num_batches)

randomized_damping = _uniform_around(subkey2, default_damping, 0.9, 1.5, num_batches)

randomized_dry_friction = _uniform_around(subkey3, default_dry_friction, 0.9, 1.5, num_batches)

# Create a batch of parameters with randomized first log-Cholesky parameter, damping, and dry frictions
batch_parameters = jnp.tile(default_parameters, (num_batches, 1))
batch_parameters = batch_parameters.at[:, 0].set(randomized_log_cholesky_first)
batch_parameters = batch_parameters.at[:, -2].set(randomized_damping)
batch_parameters = batch_parameters.at[:, -1].set(randomized_dry_friction)


# Define a batched version of rollout_trajectory using vmap
# (only the parameter vector is batched; model, initial state and controls are shared)
batched_parameters_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))

# Simulation with XML parameters
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)

# Simulate trajectories with randomized parameters using vmap.
# JAX dispatches work asynchronously, so block until the result is ready;
# otherwise the timer would stop before the computation has finished.
t1 = perf_counter()
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
randomized_trajectories.block_until_ready()
t2 = perf_counter()

print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
203+
# Plotting simulation results (XML vs Randomized)
fig = plt.figure(figsize=(10, 5))

# Top-left panel: angle over time for measurement, every randomized model and the XML model
ax_angle = fig.add_subplot(2, 2, 1)
ax_angle.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
for sample in randomized_trajectories:
    ax_angle.plot(timespan, sample[:, 0], alpha=0.02, color="blue")
ax_angle.plot(timespan, xml_trajectory[:, 0], label="XML Model Angle", color="red", linewidth=2)
ax_angle.set_ylabel("Angle (rad)")
ax_angle.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
ax_angle.grid(True)
ax_angle.legend()
ax_angle.set_title("Pendulum Dynamics - Randomized Parameters")

# Bottom-left panel: velocity over time
ax_velocity = fig.add_subplot(2, 2, 3)
ax_velocity.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
for sample in randomized_trajectories:
    ax_velocity.plot(timespan, sample[:, 1], alpha=0.02, color="blue")
ax_velocity.plot(timespan, xml_trajectory[:, 1], label="XML Model Velocity", color="red", linewidth=2)
ax_velocity.set_xlabel("Time (s)")
ax_velocity.set_ylabel("Velocity (rad/s)")
ax_velocity.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
ax_velocity.grid(True)
ax_velocity.legend()

# Right half of the figure: phase portrait (velocity against angle)
ax_phase = fig.add_subplot(1, 2, 2)
ax_phase.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
for sample in randomized_trajectories:
    ax_phase.plot(sample[:, 0], sample[:, 1], alpha=0.02, color="blue")
ax_phase.plot(xml_trajectory[:, 0], xml_trajectory[:, 1], label="XML Model", color="red", linewidth=2)
ax_phase.set_xlabel("Angle (rad)")
ax_phase.set_ylabel("Angular Velocity (rad/s)")
ax_phase.set_title("Phase Portrait")
ax_phase.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
ax_phase.grid(True)
ax_phase.legend()

fig.tight_layout()
plt.show()
243+
244+
245+
# Second, timed run of the batched parametric rollout.
# `batch_parameters` built earlier in this script is reused on purpose: drawing
# again from the same subkey1/subkey2/subkey3 would reproduce exactly the same
# samples, so regenerating the batch here adds nothing.

# Simulate trajectories with randomized parameters using vmap.
# JAX dispatches work asynchronously, so block until the result is ready;
# otherwise the timer would stop before the computation has finished.
t1 = perf_counter()
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
randomized_trajectories.block_until_ready()
t2 = perf_counter()
print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
268+
269+
270+
# TODO: OPTIMIZATION
271+
272+
273+
# Optimization
274+
275+
# # Error function
276+
# def trajectory_error(parameters: jnp.ndarray, model: mjx.Model, initial_state: jnp.ndarray, control_inputs: jnp.ndarray, true_trajectory: jnp.ndarray) -> jnp.ndarray:
277+
# predicted_trajectory = rollout_trajectory(parameters, model, initial_state, control_inputs)
278+
# return jnp.mean(jnp.square(predicted_trajectory - true_trajectory))
279+
280+
# # Optimization
281+
# @jax.jit
282+
# def update_step(parameters, opt_state, model, initial_state, control_inputs, true_trajectory):
283+
# loss, grads = jax.value_and_grad(trajectory_error)(parameters, model, initial_state, control_inputs, true_trajectory)
284+
# updates, opt_state = optimizer.update(grads, opt_state, parameters)
285+
# parameters = optax.apply_updates(parameters, updates)
286+
# return parameters, opt_state, loss
287+
288+
# # Initial parameters for optimization (using randomized parameters as starting point)
289+
# initial_parameters = randomized_parameters
290+
291+
# # Optimization setup
292+
# learning_rate = 1e-3
293+
# optimizer = optax.adam(learning_rate)
294+
# opt_state = optimizer.init(initial_parameters)
295+
296+
# # Optimization loop
297+
# num_iterations = 1000
298+
# for i in range(num_iterations):
299+
# initial_parameters, opt_state, loss = update_step(initial_parameters, opt_state, mjx_model, initial_state, control_inputs, true_trajectory)
300+
# if i % 100 == 0:
301+
# print(f"Iteration {i}, Loss: {loss}")
302+
303+
# # Final simulation with learned parameters
304+
# final_trajectory = rollout_trajectory(initial_parameters, mjx_model, initial_state, control_inputs)
305+
306+
# # Plotting optimization results
307+
# plt.figure(figsize=(12, 9))
308+
309+
# plt.subplot(2, 1, 1)
310+
# plt.plot(timespan, angle, label='Actual Angle')
311+
# plt.plot(timespan, xml_trajectory[:, 0], label='XML Model Angle', linestyle='dashed')
312+
# plt.plot(timespan, randomized_trajectory[:, 0], label='Initial Randomized Angle', linestyle='dotted')
313+
# plt.plot(timespan, final_trajectory[:, 0], label='Learned Model Angle', linestyle='dashdot')
314+
# plt.ylabel('Angle (rad)')
315+
# plt.legend()
316+
# plt.title('Pendulum Dynamics Comparison - Optimization')
317+
318+
# plt.subplot(2, 1, 2)
319+
# plt.plot(timespan, velocity, label='Actual Velocity')
320+
# plt.plot(timespan, xml_trajectory[:, 1], label='XML Model Velocity', linestyle='dashed')
321+
# plt.plot(timespan, randomized_trajectory[:, 1], label='Initial Randomized Velocity', linestyle='dotted')
322+
# plt.plot(timespan, final_trajectory[:, 1], label='Learned Model Velocity', linestyle='dashdot')
323+
# plt.xlabel('Time (s)')
324+
# plt.ylabel('Velocity (rad/s)')
325+
# plt.legend()
326+
327+
# plt.tight_layout()
328+
# plt.show()
329+
330+
# # Print learned parameters
331+
# learned_log_cholesky, learned_damping, learned_friction_loss = jnp.split(initial_parameters, [10, 11])
332+
# learned_theta = logchol2theta(learned_log_cholesky)
333+
334+
# print("\nParameters comparison:")
335+
# print("Parameter\t\tXML\t\tRandomized\tLearned")
336+
# print(f"Inertia:\t\t{get_dynamic_parameters(mjx_model, 1)[0]:.6f}\t{logchol2theta(randomized_parameters[:10])[0]:.6f}\t{learned_theta[0]:.6f}")
337+
# print(f"Damping:\t\t{mjx_model.dof_damping[0]:.6f}\t{randomized_parameters[10]:.6f}\t{learned_damping[0]:.6f}")
338+
# print(f"Friction loss:\t{mjx_model.dof_frictionloss[0]:.6f}\t{randomized_parameters[11]:.6f}\t{learned_friction_loss[0]:.6f}")
339+
340+
# TODO: Save the learned parameters to a new XML file
341+
# This would require additional code to modify the XML file with the learned parameters

examples/pendulum_ref.py renamed to _draft/pendulum_ref.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
from mujoco_sysid.mjx.convert import logchol2theta, theta2logchol
1111
from mujoco_sysid.mjx.model import create_rollout
1212
from mujoco_sysid.mjx.parameters import get_dynamic_parameters, set_dynamic_parameters
13+
import os
14+
15+
# SHOULD WE MOVE THIS INTO THE MODULE INIT?
16+
xla_flags = os.environ.get("XLA_FLAGS", "")
17+
xla_flags += " --xla_gpu_triton_gemm_any=True"
18+
os.environ["XLA_FLAGS"] = xla_flags
1319

1420

1521
@jax.jit
@@ -56,7 +62,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
5662

5763
model.opt.timestep = sampling
5864

59-
HORIZON = 100
65+
HORIZON = 50
6066
N_INTERVALS = len(timespan) // HORIZON - 1
6167
timespan = timespan[: N_INTERVALS * HORIZON]
6268
angle = angle[: N_INTERVALS * HORIZON]
@@ -98,9 +104,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
98104
interval_initial_states = true_trajectory[::HORIZON]
99105
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
100106
t1 = perf_counter()
101-
batched_states_trajectories = batched_rollout(
102-
default_parameters * 0.7, mjx_model, interval_initial_states, interval_controls
103-
)
107+
batched_states_trajectories = batched_rollout(default_parameters, mjx_model, interval_initial_states, interval_controls)
104108
t2 = perf_counter()
105109
print(f"Batch simulation time: {t2 - t1} seconds")
106110

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

examples/TODO

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Identification
2+
☐ Pendulum dynamics estimation
3+
☐ Panda load parameter adaptation
4+
☐ GO2 walking + load + friction + damping
5+
6+
Batch simulation and sensitivity
7+
☐ Which parameters are important, jacobian wrt parameters
8+
☐ Simulate Quadrotor, add parameter disturbance for batch and check sensitivity
9+

0 commit comments

Comments
 (0)