1+ from typing import Literal
2+
13import matplotlib .pyplot as plt
24import numpy as np
35from einops import rearrange
@@ -16,7 +18,7 @@ def plot_spatiotemporal_video( # noqa: PLR0915, PLR0912
1618 cmap = "viridis" ,
1719 save_path = None ,
1820 title = "Ground Truth vs Prediction" ,
19- colorbar_mode = "none" ,
21+ colorbar_mode : Literal [ "none" , "row" , "column" , "all" ] = "none" ,
2022 channel_names = None ,
2123):
2224 """Create a video comparing ground truth and predicted spatiotemporal time series.
@@ -56,9 +58,9 @@ def plot_spatiotemporal_video( # noqa: PLR0915, PLR0912
5658 animation.FuncAnimation
5759 Animation object that can be displayed in notebooks.
5860 """
59- colorbar_mode = colorbar_mode .lower ()
61+ colorbar_mode_str = colorbar_mode .lower ()
6062 valid_modes = {"none" , "row" , "column" , "all" }
61- if colorbar_mode not in valid_modes :
63+ if colorbar_mode_str not in valid_modes :
6264 raise ValueError (
6365 "Invalid colorbar_mode "
6466 f"'{ colorbar_mode } '. Expected one of { sorted (valid_modes )} ."
@@ -85,20 +87,20 @@ def _range_from_arrays(arrays):
8587
8688 norms = [[None ] * C for _ in range (n_primary_rows )]
8789
88- if colorbar_mode == "column" :
90+ if colorbar_mode_str == "column" :
8991 for ch in range (C ):
9092 channel_arrays = [row [:, :, :, ch ] for row in primary_rows ]
9193 min_val , max_val = _range_from_arrays (channel_arrays )
9294 norm = Normalize (vmin = min_val , vmax = max_val )
9395 for row_idx in range (n_primary_rows ):
9496 norms [row_idx ][ch ] = norm
95- elif colorbar_mode == "row" :
97+ elif colorbar_mode_str == "row" :
9698 for row_idx , row in enumerate (primary_rows ):
9799 min_val , max_val = _range_from_arrays ([row ])
98100 norm = Normalize (vmin = min_val , vmax = max_val )
99101 for ch in range (C ):
100102 norms [row_idx ][ch ] = norm
101- elif colorbar_mode == "all" :
103+ elif colorbar_mode_str == "all" :
102104 min_val , max_val = _range_from_arrays (primary_rows )
103105 norm = Normalize (vmin = min_val , vmax = max_val )
104106 for row_idx in range (n_primary_rows ):
0 commit comments