Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
spatial_unit="um",
variable_limits="fixed",
n_t_linear=100,
x_axis="Time",
):
solutions = self.preprocess_solutions(solutions)

Expand Down Expand Up @@ -214,6 +215,33 @@ def t_sample(sol):
self.min_t = min_t / time_scaling_factor
self.max_t = max_t / time_scaling_factor

# set x_axis
self.x_axis = x_axis

if x_axis == "Discharge capacity [A.h]":
# Use discharge capacity as x-axis
discharge_capacities = [
solution["Discharge capacity [A.h]"].entries for solution in solutions
]
self.x_values = discharge_capacities

self.x_axis_min = min(dc[0] for dc in discharge_capacities)
self.x_axis_max = max(dc[-1] for dc in discharge_capacities)
self.x_scaling_factor = 1
self.x_unit = "A.h"

elif x_axis == "Time":
self.x_values = ts_seconds

self.x_axis_min = self.min_t
self.x_axis_max = self.max_t
self.x_scaling_factor = self.time_scaling_factor

self.x_unit = self.time_unit
else:
msg = "Invalid value for `x_axis`."
raise ValueError(msg)

# Prepare dictionary of variables
# output_variables is a list of strings or lists, e.g.
# ["var 1", ["variable 2", "var 3"]]
Expand Down Expand Up @@ -413,8 +441,8 @@ def reset_axis(self):
self.axis_limits = {}
for key, variable_lists in self.variables.items():
if variable_lists[0][0].dimensions == 0:
x_min = self.min_t
x_max = self.max_t
x_min = self.x_axis_min
x_max = self.x_axis_max
elif variable_lists[0][0].dimensions == 1:
x_min = self.first_spatial_variable[key][0]
x_max = self.first_spatial_variable[key][-1]
Expand All @@ -436,7 +464,7 @@ def reset_axis(self):

# Get min and max variable values
if self.variable_limits[key] == "fixed":
# fixed variable limits: calculate "globlal" min and max
# fixed variable limits: calculate "global" min and max
spatial_vars = self.spatial_variable_dict[key]
var_min = np.min(
[
Expand Down Expand Up @@ -520,7 +548,11 @@ def plot(self, t, dynamic=False):
# Set labels for the first subplot only (avoid repetition)
if variable_lists[0][0].dimensions == 0:
# 0D plot: plot as a function of time, indicating time t with a line
ax.set_xlabel(f"Time [{self.time_unit}]")
if self.x_axis == "Time":
ax.set_xlabel(f"Time [{self.time_unit}]")
if self.x_axis == "Discharge capacity [A.h]":
ax.set_xlabel("Discharge capacity [A.h]")

for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
Expand All @@ -530,10 +562,10 @@ def plot(self, t, dynamic=False):
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
full_t = self.ts_seconds[i]
full_val = self.x_values[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
variable(full_t),
full_val / self.x_scaling_factor,
variable(full_val),
color=self.colors[i],
linestyle=linestyle,
)
Expand Down Expand Up @@ -667,13 +699,13 @@ def plot(self, t, dynamic=False):

def dynamic_plot(self, show_plot=True, step=None):
"""
Generate a dynamic plot with a slider to control the time.
Generate a dynamic plot with a slider to control the x-axis.

Parameters
----------
step : float, optional
For notebook mode, size of steps to allow in the slider. Defaults to 1/100th
of the total time.
of the total range (time or discharge capacity).
show_plot : bool, optional
Whether to show the plots. Default is True. Set to False if you want to
only display the plot after plt.show() has been called.
Expand All @@ -682,29 +714,42 @@ def dynamic_plot(self, show_plot=True, step=None):
if pybamm.is_notebook(): # pragma: no cover
import ipywidgets as widgets

step = step or self.max_t / 100
step = step or (self.x_axis_max - self.x_axis_min) / 100
widgets.interact(
lambda t: self.plot(t, dynamic=False),
t=widgets.FloatSlider(
min=self.min_t, max=self.max_t, step=step, value=self.min_t
min=self.x_axis_min,
max=self.x_axis_max,
step=step,
value=self.x_axis_min,
),
continuous_update=False,
)
else:
plt = import_optional_dependency("matplotlib.pyplot")
Slider = import_optional_dependency("matplotlib.widgets", "Slider")

# create an initial plot at time self.min_t
self.plot(self.min_t, dynamic=True)
# Set initial x-axis values and slider
self.plot(self.x_axis_min, dynamic=True)

# Set x-axis label correctly
if self.x_axis == "Time":
ax_label = f"Time [{self.time_unit}]"
elif self.x_axis == "Discharge capacity [A.h]":
ax_label = "Discharge capacity [A.h]"
else:
ax_label = self.x_axis # Use the string directly if unknown

ax_min, ax_max, val_init = self.x_axis_min, self.x_axis_max, self.x_axis_min

axcolor = "lightgoldenrodyellow"
ax_slider = plt.axes([0.315, 0.02, 0.37, 0.03], facecolor=axcolor)
self.slider = Slider(
ax_slider,
f"Time [{self.time_unit}]",
self.min_t,
self.max_t,
valinit=self.min_t,
ax_label,
ax_min,
ax_max,
valinit=val_init,
color="#1f77b4",
)
self.slider.on_changed(self.slider_update)
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,43 @@ def test_simple_ode_model(self, solver):

pybamm.close_plots()

def test_invalid_x_axis(self):
model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(model)
solution = sim.solve([0, 3600])

with pytest.raises(ValueError, match="Invalid value for `x_axis`."):
pybamm.QuickPlot([solution], x_axis="Invalid axis")

def test_plot_with_discharge_capacity(self):
model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model")
a = pybamm.Variable("a", domain=[])
model.rhs = {a: pybamm.Scalar(0.2)}
model.initial_conditions = {a: pybamm.Scalar(0)}
model.variables = {"a": a, "Discharge capacity [A.h]": a * 2}

t_eval = np.linspace(0, 2, 100)
solution = pybamm.CasadiSolver().solve(model, t_eval)

quick_plot = pybamm.QuickPlot(
solution,
["a"],
x_axis="Discharge capacity [A.h]",
)
quick_plot.plot(0)

# Test discharge capacity values
np.testing.assert_allclose(
quick_plot.plots[("a",)][0][0].get_xdata(),
solution["Discharge capacity [A.h]"].data,
)

# Test x-axis label
x_label = quick_plot.fig.axes[0].get_xlabel()
assert x_label == "Discharge capacity [A.h]", (
f"Expected 'Discharge capacity [A.h]', got '{x_label}'"
)

def test_plot_with_different_models(self):
model = pybamm.BaseModel()
a = pybamm.Variable("a")
Expand Down