Skip to content

Commit d8244a9

Browse files
committed
Update default colorbar outputs
1 parent 4f9a44a commit d8244a9

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

src/autocast/eval/processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def _render_rollouts(
434434
batch_idx=sample_index,
435435
fps=fps,
436436
save_path=str(filename),
437+
colorbar_mode="column",
437438
)
438439
saved_paths.append(filename)
439440
rendered_batches.add(batch_idx)

src/autocast/utils/plots.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
35
from einops import rearrange
@@ -16,7 +18,7 @@ def plot_spatiotemporal_video( # noqa: PLR0915, PLR0912
1618
cmap="viridis",
1719
save_path=None,
1820
title="Ground Truth vs Prediction",
19-
colorbar_mode="none",
21+
colorbar_mode: Literal["none", "row", "column", "all"] = "none",
2022
channel_names=None,
2123
):
2224
"""Create a video comparing ground truth and predicted spatiotemporal time series.
@@ -56,9 +58,9 @@ def plot_spatiotemporal_video( # noqa: PLR0915, PLR0912
5658
animation.FuncAnimation
5759
Animation object that can be displayed in notebooks.
5860
"""
59-
colorbar_mode = colorbar_mode.lower()
61+
colorbar_mode_str = colorbar_mode.lower()
6062
valid_modes = {"none", "row", "column", "all"}
61-
if colorbar_mode not in valid_modes:
63+
if colorbar_mode_str not in valid_modes:
6264
raise ValueError(
6365
"Invalid colorbar_mode "
6466
f"'{colorbar_mode}'. Expected one of {sorted(valid_modes)}."
@@ -85,20 +87,20 @@ def _range_from_arrays(arrays):
8587

8688
norms = [[None] * C for _ in range(n_primary_rows)]
8789

88-
if colorbar_mode == "column":
90+
if colorbar_mode_str == "column":
8991
for ch in range(C):
9092
channel_arrays = [row[:, :, :, ch] for row in primary_rows]
9193
min_val, max_val = _range_from_arrays(channel_arrays)
9294
norm = Normalize(vmin=min_val, vmax=max_val)
9395
for row_idx in range(n_primary_rows):
9496
norms[row_idx][ch] = norm
95-
elif colorbar_mode == "row":
97+
elif colorbar_mode_str == "row":
9698
for row_idx, row in enumerate(primary_rows):
9799
min_val, max_val = _range_from_arrays([row])
98100
norm = Normalize(vmin=min_val, vmax=max_val)
99101
for ch in range(C):
100102
norms[row_idx][ch] = norm
101-
elif colorbar_mode == "all":
103+
elif colorbar_mode_str == "all":
102104
min_val, max_val = _range_from_arrays(primary_rows)
103105
norm = Normalize(vmin=min_val, vmax=max_val)
104106
for row_idx in range(n_primary_rows):

0 commit comments

Comments
 (0)