Skip to content

Commit

Permalink
Resolved earthlab#796
Browse files Browse the repository at this point in the history
* Added logic to earthpy.plot.plot_bands to validate dimensions of
  ax array, if provided, when raster array is multiband.
  Raises a ValueError if the provided axes array is smaller
  than the number of bands.
* Added a test for the success and failure states.
  • Loading branch information
ahasha committed Nov 18, 2022
1 parent 33751d6 commit 40da774
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
22 changes: 19 additions & 3 deletions earthpy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def plot_bands(
Specify the vmin to scale imshow() plots.
vmax : Int (Optional)
Specify the vmax to scale imshow() plots.
ax : object(s) (optional)
The axes object(s) where the ax element should be plotted.
alpha : float (optional)
The alpha value for the plot. This will help adjust the transparency
of the plot to the desired level.
Expand Down Expand Up @@ -248,8 +250,21 @@ def plot_bands(
total_layers = arr.shape[0]

# Plot all bands
fig, axs = plt.subplots(plot_rows, cols, figsize=figsize)
axs_ravel = axs.ravel()
if ax is None:
fig, axs = plt.subplots(plot_rows, cols, figsize=figsize)
axs_ravel = axs.ravel()
show = True
else:
if not isinstance(ax, np.ndarray) or len(ax.ravel()) < arr.shape[0]:
raise ValueError(
"plot_bands expects the ax keyword argument "
"to be a numpy.ndarray with number of elements "
"greater than or equal to the number of array raster layers."
)
axs = ax
axs_ravel = ax.ravel()


for ax, i in zip(axs_ravel, range(total_layers)):
band = i + 1

Expand Down Expand Up @@ -280,7 +295,8 @@ def plot_bands(
ax.set_axis_off()
ax.set(xticks=[], yticks=[])
plt.tight_layout()
plt.show()
if show:
plt.show()
return axs

elif arr.ndim == 2 or arr.shape[0] == 1:
Expand Down
13 changes: 13 additions & 0 deletions earthpy/tests/test_plot_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ def test_multi_panel_single_band(one_band_3dims):
assert all_axes[1].get_title() == title2


def test_ax_argument_multi_band(image_array_3bands):
"""Test that ax keyword argument is used for multi band arr."""
f, axs = plt.subplots(3, 1)
axs2 = ep.plot_bands(image_array_3bands, ax=axs)

assert np.all(axs == axs2)

f, axs = plt.subplots(1, 2)
with pytest.raises(ValueError, match=r"number of elements"):
axs3 = ep.plot_bands(image_array_3bands, ax=axs)



def test_alpha(image_array_2bands):
"""Test that the alpha param returns a plot with the correct alpha."""
alpha_val = 0.5
Expand Down

0 comments on commit 40da774

Please sign in to comment.