Skip to content

Commit f68a4af

Browse files
committed
TST: implement testing of show or save plots.
1 parent dcde26a commit f68a4af

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

tests/unit/test_plots.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from unittest.mock import patch
1+
import os
2+
from unittest.mock import MagicMock, patch
23

34
import matplotlib.pyplot as plt
5+
import pytest
46

57
from rocketpy.plots.compare import Compare
8+
from rocketpy.plots.plot_helpers import show_or_save_fig, show_or_save_plot
69

710

811
@patch("matplotlib.pyplot.show")
@@ -38,3 +41,51 @@ def test_compare(mock_show, flight_calisto): # pylint: disable=unused-argument
3841
)
3942

4043
assert isinstance(fig, plt.Figure)
44+
45+
46+
@patch("matplotlib.pyplot.show")
47+
@pytest.mark.parametrize("filename", [None, "test.png"])
48+
def test_show_or_save_plot(mock_show, filename):
49+
"""This test is to check if the show_or_save_plot function is
50+
working properly.
51+
52+
Parameters
53+
----------
54+
mock_show :
55+
Mocks the matplotlib.pyplot.show() function to avoid showing
56+
the plots.
57+
filename : str
58+
Name of the file to save the plot. If None, the plot will be
59+
shown instead.
60+
"""
61+
plt.subplots()
62+
show_or_save_plot(filename)
63+
64+
if filename is None:
65+
mock_show.assert_called_once()
66+
else:
67+
assert os.path.exists(filename)
68+
os.remove(filename)
69+
70+
71+
@pytest.mark.parametrize("filename", [None, "test.png"])
72+
def test_show_or_save_fig(filename):
73+
"""This test is to check if the show_or_save_fig function is
74+
working properly.
75+
76+
Parameters
77+
----------
78+
filename : str
79+
Name of the file to save the plot. If None, the plot will be
80+
shown instead.
81+
"""
82+
fig, _ = plt.subplots()
83+
84+
fig.show = MagicMock()
85+
show_or_save_fig(fig, filename)
86+
87+
if filename is None:
88+
fig.show.assert_called_once()
89+
else:
90+
assert os.path.exists(filename)
91+
os.remove(filename)

0 commit comments

Comments
 (0)