Skip to content

Commit 01b7121

Browse files
committed
cleaned learning example, added animation
1 parent f37bc4e commit 01b7121

10 files changed

+355
-137
lines changed

examples/pendulum/02_simulation_error.py renamed to _draft/02_simulation_error.py

+15-40
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import optax
1414
from mujoco.mjx._src.types import IntegratorType
15+
from _plotting_utils import plot_simulation_errors
1516

1617
# SHOULD WE MOVE THIS IN TO MODULE INIT?
1718
xla_flags = os.environ.get("XLA_FLAGS", "")
@@ -40,7 +41,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
4041
key = jax.random.PRNGKey(0)
4142

4243
# Load the model
43-
MJCF_PATH = "models/pendulum.xml"
44+
MJCF_PATH = "models/pendulum_estimated.xml"
4445
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
4546
data = mujoco.MjData(model)
4647
model.opt.integrator = IntegratorType.EULER
@@ -112,45 +113,19 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
112113

113114
predicted_terminal_points = np.array(batched_states_trajectories)[:, -1, :]
114115
batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)
115-
# Plotting simulation results for batсhed state trajectories
116-
plt.figure(figsize=(10, 5))
117-
118-
plt.subplot(2, 2, 1)
119-
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
120-
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
121-
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
122-
plt.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 0], "ob")
123-
plt.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 0], "or")
124-
plt.ylabel("Angle (rad)")
125-
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
126-
plt.legend()
127-
plt.title("Pendulum Dynamics - Bathed State Trajectories")
128-
129-
plt.subplot(2, 2, 3)
130-
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
131-
plt.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 1], "ob")
132-
plt.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 1], "or")
133-
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
134-
plt.xlabel("Time (s)")
135-
plt.ylabel("Velocity (rad/s)")
136-
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
137-
plt.legend()
138-
139-
# Add phase portrait
140-
plt.subplot(1, 2, 2)
141-
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
142-
plt.plot(
143-
batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
116+
117+
# Replace the plotting section with:
118+
plot_simulation_errors(
119+
timespan,
120+
angle,
121+
velocity,
122+
batched_states_trajectories,
123+
predicted_terminal_points,
124+
interval_terminal_states,
125+
HORIZON,
126+
save_path="plots/simulation_error.png",
127+
show=True
144128
)
145-
plt.plot(predicted_terminal_points[:-1, 0], predicted_terminal_points[:-1, 1], "ob")
146-
plt.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], "or")
147-
plt.xlabel("Angle (rad)")
148-
plt.ylabel("Angular Velocity (rad/s)")
149-
plt.title("Phase Portrait")
150-
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
151-
plt.legend()
152-
153-
plt.tight_layout()
154-
plt.show()
129+
155130
# TODO:
156131
# Optimization

examples/pendulum/TODO

+7-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1-
☐ Jupyter notebook with pendulum learning
1+
☐ Jupyter notebook with pendulum learning
2+
☐ Parametric model
3+
☐ Simulation error
4+
☐ Residuals
5+
☐ Learning
6+
☐ Media animation
7+

examples/pendulum/_plotting_utils.py

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import os
4+
from typing import List, Tuple, Union
5+
6+
def setup_plot(nrows: int = 2, ncols: int = 2, figsize: Tuple[int, int] = (10, 10)) -> Tuple[plt.Figure, Union[plt.Axes, List[plt.Axes]]]:
7+
"""
8+
Set up a matplotlib figure with the specified number of rows and columns.
9+
10+
Args:
11+
nrows (int): Number of rows in the subplot grid.
12+
ncols (int): Number of columns in the subplot grid.
13+
figsize (Tuple[int, int]): Figure size in inches (width, height).
14+
15+
Returns:
16+
Tuple[plt.Figure, Union[plt.Axes, List[plt.Axes]]]: Figure and axes objects.
17+
"""
18+
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
19+
if nrows * ncols == 1:
20+
axes = [axes]
21+
elif nrows == 1 or ncols == 1:
22+
axes = axes.flatten()
23+
return fig, axes
24+
25+
def plot_state(ax: plt.Axes, timespan: np.ndarray, actual: np.ndarray, simulated: np.ndarray, label: str, color: str = 'blue', alpha: float = 0.5) -> None:
26+
"""
27+
Plot actual and simulated state data on the given axes.
28+
29+
Args:
30+
ax (plt.Axes): The matplotlib axes to plot on.
31+
timespan (np.ndarray): Array of time points.
32+
actual (np.ndarray): Array of actual state values.
33+
simulated (np.ndarray): Array of simulated state values.
34+
label (str): Label for the state (e.g., "Angle" or "Velocity").
35+
color (str): Color for the simulated data plot.
36+
alpha (float): Alpha value for the simulated data plot.
37+
"""
38+
ax.plot(timespan, actual, label=f"Actual {label}", color="black", linestyle="dashed", linewidth=2)
39+
ax.plot(timespan, simulated, alpha=alpha, color=color, label=f"Simulated {label}")
40+
ax.set_ylabel(f"{label} (rad{'/' if label == 'Velocity' else ''}s)")
41+
ax.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
42+
ax.legend()
43+
44+
def plot_phase_portrait(ax: plt.Axes, angle: np.ndarray, velocity: np.ndarray, simulated_angle: np.ndarray, simulated_velocity: np.ndarray, color: str = 'blue', alpha: float = 0.5) -> None:
45+
"""
46+
Plot the phase portrait of actual and simulated data.
47+
48+
Args:
49+
ax (plt.Axes): The matplotlib axes to plot on.
50+
angle (np.ndarray): Array of actual angle values.
51+
velocity (np.ndarray): Array of actual velocity values.
52+
simulated_angle (np.ndarray): Array of simulated angle values.
53+
simulated_velocity (np.ndarray): Array of simulated velocity values.
54+
color (str): Color for the simulated data plot.
55+
alpha (float): Alpha value for the simulated data plot.
56+
"""
57+
ax.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
58+
ax.plot(simulated_angle, simulated_velocity, alpha=alpha, color=color, label="Simulated")
59+
ax.set_xlabel("Angle (rad)")
60+
ax.set_ylabel("Angular Velocity (rad/s)")
61+
ax.set_title("Phase Portrait")
62+
ax.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
63+
ax.legend()
64+
65+
def plot_simulation_errors(timespan: np.ndarray, angle: np.ndarray, velocity: np.ndarray, batched_states_trajectories: np.ndarray, predicted_terminal_points: np.ndarray, interval_terminal_states: np.ndarray, HORIZON: int, save_path: str = None, show: bool = False, title: str = "Simulation Errors", iteration: int = None) -> np.ndarray:
66+
"""
67+
Plot simulation errors for the pendulum system and return the frame as an image.
68+
69+
Args:
70+
timespan (np.ndarray): Array of time points.
71+
angle (np.ndarray): Array of actual angle values.
72+
velocity (np.ndarray): Array of actual velocity values.
73+
batched_states_trajectories (np.ndarray): Array of simulated state trajectories.
74+
predicted_terminal_points (np.ndarray): Array of predicted terminal points.
75+
interval_terminal_states (np.ndarray): Array of actual terminal states at intervals.
76+
HORIZON (int): Number of time steps in each interval.
77+
save_path (str): Path to save the plot. If None, the plot is not saved.
78+
show (bool): Whether to display the plot.
79+
title (str): Title for the plot.
80+
iteration (int, optional): Current iteration number for animation frames.
81+
82+
Returns:
83+
np.ndarray: Image array representing the current frame.
84+
"""
85+
fig = plt.figure(figsize=(12, 6))
86+
gs = fig.add_gridspec(2, 2)
87+
88+
ax1 = fig.add_subplot(gs[0, 0])
89+
ax2 = fig.add_subplot(gs[1, 0])
90+
ax3 = fig.add_subplot(gs[:, 1])
91+
92+
plot_state(ax1, timespan, angle, batched_states_trajectories[:, 0], "Angle")
93+
ax1.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 0], "ob", label="Predicted")
94+
ax1.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 0], "or", label="Actual")
95+
if iteration is not None:
96+
ax1.set_title(f"{title} (Iteration {iteration})")
97+
else:
98+
ax1.set_title(title)
99+
ax1.legend(loc='upper right')
100+
101+
plot_state(ax2, timespan, velocity, batched_states_trajectories[:, 1], "Velocity")
102+
ax2.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 1], "ob", label="Predicted")
103+
ax2.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 1], "or", label="Actual")
104+
ax2.set_xlabel("Time (s)")
105+
ax2.legend(loc='upper right')
106+
107+
plot_phase_portrait(ax3, angle, velocity, batched_states_trajectories[:, 0], batched_states_trajectories[:, 1])
108+
ax3.plot(predicted_terminal_points[:-1, 0], predicted_terminal_points[:-1, 1], "ob", label="Predicted")
109+
ax3.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], "or", label="Actual")
110+
ax3.legend(loc='upper right')
111+
112+
plt.tight_layout()
113+
114+
if save_path:
115+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
116+
plt.savefig(save_path, dpi=300)
117+
118+
if show:
119+
plt.show()
120+
121+
# Convert plot to image array
122+
fig.canvas.draw()
123+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
124+
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
125+
126+
plt.close(fig)
127+
128+
return image
129+
130+
def create_animation_frame(timespan: np.ndarray, true_trajectory: np.ndarray, current_rollout: np.ndarray, iteration: int) -> np.ndarray:
131+
"""
132+
Create a single frame for the animation of the learning process.
133+
134+
Args:
135+
timespan (np.ndarray): Array of time points.
136+
true_trajectory (np.ndarray): Array of actual state values.
137+
current_rollout (np.ndarray): Array of current simulated state values.
138+
iteration (int): Current iteration number.
139+
140+
Returns:
141+
np.ndarray: Image array representing the current frame.
142+
"""
143+
fig = plt.figure(figsize=(12, 6)) # Reduced height from 10 to 5
144+
gs = fig.add_gridspec(2, 2)
145+
146+
ax1 = fig.add_subplot(gs[0, 0])
147+
ax2 = fig.add_subplot(gs[1, 0])
148+
ax3 = fig.add_subplot(gs[:, 1])
149+
150+
plot_state(ax1, timespan, true_trajectory[:, 0], current_rollout[:, 0], "Angle", color="red")
151+
ax1.set_title(f"Iteration {iteration}")
152+
153+
plot_state(ax2, timespan, true_trajectory[:, 1], current_rollout[:, 1], "Velocity", color="red")
154+
ax2.set_xlabel("Time (s)")
155+
156+
plot_phase_portrait(ax3, true_trajectory[:, 0], true_trajectory[:, 1], current_rollout[:, 0], current_rollout[:, 1], color="red")
157+
ax3.set_title("Phase Portrait")
158+
159+
plt.tight_layout()
160+
161+
fig.canvas.draw()
162+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
163+
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
164+
165+
plt.close(fig)
166+
167+
return image
168+
169+
def plot_full_simulation(timespan: np.ndarray, angle: np.ndarray, velocity: np.ndarray, old_rollout: np.ndarray, new_rollout: np.ndarray, save_path: str = "plots/learning_results.png", show: bool = True) -> None:
170+
"""
171+
Plot full simulation results for the pendulum system.
172+
173+
Args:
174+
timespan (np.ndarray): Array of time points.
175+
angle (np.ndarray): Array of actual angle values.
176+
velocity (np.ndarray): Array of actual velocity values.
177+
old_rollout (np.ndarray): Array of simulated states using the old model.
178+
new_rollout (np.ndarray): Array of simulated states using the new model.
179+
save_path (str): Path to save the plot.
180+
show (bool): Whether to display the plot.
181+
"""
182+
fig = plt.figure(figsize=(12, 6)) # Reduced height from 10 to 5
183+
gs = fig.add_gridspec(2, 2)
184+
185+
ax1 = fig.add_subplot(gs[0, 0])
186+
ax2 = fig.add_subplot(gs[1, 0])
187+
ax3 = fig.add_subplot(gs[:, 1])
188+
189+
plot_state(ax1, timespan, angle, old_rollout[:, 0], "Angle", color="blue", alpha=0.3)
190+
ax1.plot(timespan, new_rollout[:, 0], color="red", label="Optimized Model")
191+
192+
plot_state(ax2, timespan, velocity, old_rollout[:, 1], "Velocity", color="blue", alpha=0.3)
193+
ax2.plot(timespan, new_rollout[:, 1], color="red", label="Optimized Model")
194+
ax2.set_xlabel("Time (s)")
195+
196+
plot_phase_portrait(ax3, angle, velocity, old_rollout[:, 0], old_rollout[:, 1], color="blue", alpha=0.3)
197+
ax3.plot(new_rollout[:, 0], new_rollout[:, 1], color="red", label="Optimized Model")
198+
199+
plt.tight_layout()
200+
if save_path:
201+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
202+
plt.savefig(save_path, dpi=300)
203+
if show:
204+
plt.show()
205+
else:
206+
plt.close(fig)

0 commit comments

Comments
 (0)