diff --git a/fme/ace/aggregator/inference/video.py b/fme/ace/aggregator/inference/video.py index 1c4803ae..0e1f9af6 100644 --- a/fme/ace/aggregator/inference/video.py +++ b/fme/ace/aggregator/inference/video.py @@ -509,9 +509,13 @@ def _make_video( video_data = np.minimum(video_data, 255) video_data = np.maximum(video_data, 0) video_data[np.isnan(video_data)] = 0 + # Convert from single-channel grayscale to three-channel RGB to work + # around imageio error + video_data = video_data.repeat(3, axis=1) caption += f"; vmin={data_min:.4g}, vmax={data_max:.4g}" return wandb.Video( np.flip(video_data, axis=-2), caption=caption, fps=4, + format="gif", ) diff --git a/fme/sht_fix.py b/fme/sht_fix.py index ba2484b1..d1696e4a 100644 --- a/fme/sht_fix.py +++ b/fme/sht_fix.py @@ -42,7 +42,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -import numpy as np import torch import torch.nn as nn import torch.fft @@ -100,15 +99,14 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho raise(ValueError("Unknown quadrature mode")) # apply cosine transform and flip them - tq = np.flip(np.arccos(cost)) + tq = torch.flip(torch.arccos(cost), dims=(0,)) # determine the dimensions self.mmax = mmax or self.nlon // 2 + 1 # combine quadrature weights with the legendre weights - weights = torch.from_numpy(w) pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)) - weights = torch.einsum('mlk,k->mlk', pct, weights) + weights = torch.einsum('mlk,k->mlk', pct, w) # remember quadrature weights self.weights = weights.float().to(get_device()) @@ -181,7 +179,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho raise(ValueError("Unknown quadrature mode")) # apply cosine transform and flip them - t = np.flip(np.arccos(cost)) + t = torch.flip(torch.arccos(cost), dims=(0,)) # determine the dimensions self.mmax = mmax or self.nlon // 2 + 1 diff --git a/requirements.txt b/requirements.txt index 03f317c5..52dd58ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ dacite gcsfs h5netcdf h5py -imageio<=2.27.0 +imageio>=2.28.1 matplotlib moviepy netcdf4 @@ -14,7 +14,7 @@ plotly s3fs tensorly tensorly-torch -torch-harmonics==0.7.4 # pinned since we use private API +torch-harmonics==0.8.0 # pinned since we use private API torch>=2 wandb[media]>=0.19.0 xarray