1111import optax
1212
1313
14- # Initialize random key
15- key = jax .random .PRNGKey (0 )
16-
17- # Load the model
18- MJCF_PATH = "../data/models/pendulum/pendulum.xml"
19- model = mujoco .MjModel .from_xml_path (MJCF_PATH )
20- data = mujoco .MjData (model )
21- model .opt .integrator = 1
22-
23- # Setting up constraint solver to ensure differentiability and faster simulations
24- model .opt .solver = 2 # 2 corresponds to Newton solver
25- model .opt .iterations = 2
26- model .opt .ls_iterations = 10
27-
28- mjx_model = mjx .put_model (model )
29-
30- # Load test data
31- TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
32- data_array = np .genfromtxt (TEST_DATA_PATH , delimiter = "," , skip_header = 100 , skip_footer = 2500 )
33- timespan = data_array [:, 0 ] - data_array [0 , 0 ]
34- sampling = np .mean (np .diff (timespan ))
35- angle = data_array [:, 1 ]
36- velocity = data_array [:, 2 ]
37- control = data_array [:, 3 ]
38-
39- model .opt .timestep = sampling
40-
41-
4214@jax .jit
4315def parameters_map (parameters : jnp .ndarray , model : mjx .Model ) -> mjx .Model :
4416 """Map new parameters to the model."""
@@ -76,19 +48,119 @@ def step_fn(state, control):
7648 return states
7749
7850
51+ # Initialize random key
52+ key = jax .random .PRNGKey (0 )
53+
54+ # Load the model
55+ MJCF_PATH = "../data/models/pendulum/pendulum.xml"
56+ model = mujoco .MjModel .from_xml_path (MJCF_PATH )
57+ data = mujoco .MjData (model )
58+ model .opt .integrator = 1
59+
60+ # Setting up constraint solver to ensure differentiability and faster simulations
61+ model .opt .solver = 2 # 2 corresponds to Newton solver
62+ model .opt .iterations = 2
63+ model .opt .ls_iterations = 10
64+
65+ mjx_model = mjx .put_model (model )
66+
67+ # Load test data
68+ TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
69+ data_array = np .genfromtxt (TEST_DATA_PATH , delimiter = "," , skip_header = 100 , skip_footer = 2500 )
70+ timespan = data_array [:, 0 ] - data_array [0 , 0 ]
71+ sampling = np .mean (np .diff (timespan ))
72+ angle = data_array [:, 1 ]
73+ velocity = data_array [:, 2 ]
74+ control = data_array [:, 3 ]
75+
76+ model .opt .timestep = sampling
77+
78+ HORIZON = 100
79+ N_INTERVALS = len (timespan ) // HORIZON - 1
80+ timespan = timespan [: N_INTERVALS * HORIZON ]
81+ angle = angle [: N_INTERVALS * HORIZON ]
82+ velocity = velocity [: N_INTERVALS * HORIZON ]
83+ control = control [: N_INTERVALS * HORIZON ]
84+
7985# Prepare data for simulation and optimization
8086initial_state = jnp .array ([angle [0 ], velocity [0 ]])
8187true_trajectory = jnp .column_stack ((angle , velocity ))
8288control_inputs = jnp .array (control )
8389
90+ interval_true_trajectory = true_trajectory [::HORIZON ]
91+ interval_controls = control_inputs .reshape (N_INTERVALS , HORIZON )
92+
8493# Get default parameters from the model
8594default_parameters = jnp .concatenate (
8695 [theta2logchol (get_dynamic_parameters (mjx_model , 1 )), mjx_model .dof_damping , mjx_model .dof_frictionloss ]
8796)
8897
89- # Simulation with XML parameters
90- xml_trajectory = rollout_trajectory (default_parameters , mjx_model , initial_state , control_inputs )
98+ # //////////////////////////////////////
99+ # SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION
100+
101+ # Vectorize over both initial states and control inputs
102+ batched_rollout = jax .jit (jax .vmap (rollout_trajectory , in_axes = (None , None , 0 , 0 )))
103+
104+ # Create a batch of initial states
105+ key , subkey = jax .random .split (key )
106+ batch_initial_states = jax .random .uniform (subkey , (N_INTERVALS , 2 ), minval = - 0.1 , maxval = 0.1 ) + initial_state
107+ # Create a batch of control input sequences
108+ key , subkey = jax .random .split (key )
109+ batch_control_inputs = jax .random .normal (subkey , (N_INTERVALS , HORIZON )) * 0.1 # + control_inputs
110+ # Run warm up for batched rollout
111+ t1 = perf_counter ()
112+ batched_trajectories = batched_rollout (default_parameters , mjx_model , batch_initial_states , batch_control_inputs )
113+ t2 = perf_counter ()
114+ print (f"Batch simulation time: { t2 - t1 } seconds" )
115+
116+ # Run batched rollout on shor horizon data from pendulum
117+ interval_initial_states = true_trajectory [::HORIZON ]
118+ interval_controls = control_inputs .reshape (N_INTERVALS , HORIZON )
119+ t1 = perf_counter ()
120+ batched_states_trajectories = batched_rollout (
121+ default_parameters * 0.7 , mjx_model , interval_initial_states , interval_controls
122+ )
123+ t2 = perf_counter ()
124+ print (f"Batch simulation time: { t2 - t1 } seconds" )
125+
126+ batched_states_trajectories = np .array (batched_states_trajectories ).reshape (N_INTERVALS * HORIZON , 2 )
91127
128+ # Plotting simulation results for batсhed state trajectories
129+ plt .figure (figsize = (10 , 5 ))
130+
131+ plt .subplot (2 , 2 , 1 )
132+ plt .plot (timespan , angle , label = "Actual Angle" , color = "black" , linestyle = "dashed" , linewidth = 2 )
133+ plt .plot (timespan , batched_states_trajectories [:, 0 ], alpha = 0.5 , color = "blue" , label = "Simulated Angle" )
134+ plt .ylabel ("Angle (rad)" )
135+ plt .grid (color = "black" , linestyle = "--" , linewidth = 1.0 , alpha = 0.4 )
136+ plt .legend ()
137+ plt .title ("Pendulum Dynamics - Bathed State Trajectories" )
138+
139+ plt .subplot (2 , 2 , 3 )
140+ plt .plot (timespan , velocity , label = "Actual Velocity" , color = "black" , linestyle = "dashed" , linewidth = 2 )
141+ plt .plot (timespan , batched_states_trajectories [:, 1 ], alpha = 0.5 , color = "blue" , label = "Simulated Velocity" )
142+ plt .xlabel ("Time (s)" )
143+ plt .ylabel ("Velocity (rad/s)" )
144+ plt .grid (color = "black" , linestyle = "--" , linewidth = 1.0 , alpha = 0.4 )
145+ plt .legend ()
146+
147+ # Add phase portrait
148+ plt .subplot (1 , 2 , 2 )
149+ plt .plot (angle , velocity , label = "Actual" , color = "black" , linestyle = "dashed" , linewidth = 2 )
150+ plt .plot (
151+ batched_states_trajectories [:, 0 ], batched_states_trajectories [:, 1 ], alpha = 0.5 , color = "blue" , label = "Simulated"
152+ )
153+ plt .xlabel ("Angle (rad)" )
154+ plt .ylabel ("Angular Velocity (rad/s)" )
155+ plt .title ("Phase Portrait" )
156+ plt .grid (color = "black" , linestyle = "--" , linewidth = 1.0 , alpha = 0.4 )
157+ plt .legend ()
158+
159+ plt .tight_layout ()
160+ plt .show ()
161+
162+ # //////////////////////////////////////////////////
163+ # PARAMETRIC BATCHES
92164# Create a batch of 200 randomized parameters
93165num_batches = 200
94166key , subkey1 , subkey2 , subkey3 = jax .random .split (key , 4 )
@@ -115,13 +187,16 @@ def step_fn(state, control):
115187batch_parameters = batch_parameters .at [:, - 2 ].set (randomized_damping )
116188batch_parameters = batch_parameters .at [:, - 1 ].set (randomized_dry_friction )
117189
190+
118191# Define a batched version of rollout_trajectory using vmap
119- batched_rollout = jax .jit (jax .vmap (rollout_trajectory , in_axes = (0 , None , None , None )))
192+ batched_parameters_rollout = jax .jit (jax .vmap (rollout_trajectory , in_axes = (0 , None , None , None )))
120193
194+ # Simulation with XML parameters
195+ xml_trajectory = rollout_trajectory (default_parameters , mjx_model , initial_state , control_inputs )
121196
122197# Simulate trajectories with randomized parameters using vmap
123198t1 = perf_counter ()
124- randomized_trajectories = batched_rollout (batch_parameters , mjx_model , initial_state , control_inputs )
199+ randomized_trajectories = batched_parameters_rollout (batch_parameters , mjx_model , initial_state , control_inputs )
125200t2 = perf_counter ()
126201
127202print (f"Simulation with randomized parameters using vmap took { t2 - t1 :.2f} seconds." )
@@ -187,10 +262,14 @@ def step_fn(state, control):
187262
188263# Simulate trajectories with randomized parameters using vmap
189264t1 = perf_counter ()
190- randomized_trajectories = batched_rollout (batch_parameters , mjx_model , initial_state , control_inputs )
265+ randomized_trajectories = batched_parameters_rollout (batch_parameters , mjx_model , initial_state , control_inputs )
191266t2 = perf_counter ()
192267print (f"Simulation with randomized parameters using vmap took { t2 - t1 :.2f} seconds." )
193268
269+
270+ # TODO: OPTIMIZATION
271+
272+
194273# Optimization
195274
196275# # Error function
0 commit comments