Skip to content

Commit a40f16b

Browse files
Bump torch-harmonics to 0.8.0 and update constraint on imageio (#49)
This PR ports the changes made in ai2cm/full-model#2286 to our public repository. Changes: - Bumps the version of `torch-harmonics` to `0.8.0`. - Constrains the version of `imageio` to be greater than or equal to `2.28.1`. Per wandb/wandb#9887 and an additional fix, this means videos logged to WandB will be displayed properly again. - Modifies code in `sht_fix.py` to account for the fact that quadrature functions in `torch-harmonics` return `torch.tensor` objects instead of NumPy arrays as of version `0.7.6` (NVIDIA/torch-harmonics#66). - Converts video data from single-channel grayscale to three-channel RGB to work around an `imageio` error, and explicitly specifies the video format as `"gif"` to silence a warning. Relevant `wandb` and `imageio` issues and PRs: - imageio/imageio#904 - wandb/wandb#9790 Resolves #48
1 parent 1beed63 commit a40f16b

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

fme/ace/aggregator/inference/video.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,13 @@ def _make_video(
509509
video_data = np.minimum(video_data, 255)
510510
video_data = np.maximum(video_data, 0)
511511
video_data[np.isnan(video_data)] = 0
512+
# Convert from single-channel grayscale to three-channel RGB to work
513+
# around imageio error
514+
video_data = video_data.repeat(3, axis=1)
512515
caption += f"; vmin={data_min:.4g}, vmax={data_max:.4g}"
513516
return wandb.Video(
514517
np.flip(video_data, axis=-2),
515518
caption=caption,
516519
fps=4,
520+
format="gif",
517521
)

fme/sht_fix.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4343
#
4444

45-
import numpy as np
4645
import torch
4746
import torch.nn as nn
4847
import torch.fft
@@ -100,15 +99,14 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
10099
raise(ValueError("Unknown quadrature mode"))
101100

102101
# apply cosine transform and flip them
103-
tq = np.flip(np.arccos(cost))
102+
tq = torch.flip(torch.arccos(cost), dims=(0,))
104103

105104
# determine the dimensions
106105
self.mmax = mmax or self.nlon // 2 + 1
107106

108107
# combine quadrature weights with the legendre weights
109-
weights = torch.from_numpy(w)
110108
pct = torch.as_tensor(_precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase))
111-
weights = torch.einsum('mlk,k->mlk', pct, weights)
109+
weights = torch.einsum('mlk,k->mlk', pct, w)
112110

113111
# remember quadrature weights
114112
self.weights = weights.float().to(get_device())
@@ -181,7 +179,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
181179
raise(ValueError("Unknown quadrature mode"))
182180

183181
# apply cosine transform and flip them
184-
t = np.flip(np.arccos(cost))
182+
t = torch.flip(torch.arccos(cost), dims=(0,))
185183

186184
# determine the dimensions
187185
self.mmax = mmax or self.nlon // 2 + 1

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ dacite
33
gcsfs
44
h5netcdf
55
h5py
6-
imageio<=2.27.0
6+
imageio>=2.28.1
77
matplotlib
88
moviepy
99
netcdf4
@@ -14,7 +14,7 @@ plotly
1414
s3fs
1515
tensorly
1616
tensorly-torch
17-
torch-harmonics==0.7.4 # pinned since we use private API
17+
torch-harmonics==0.8.0 # pinned since we use private API
1818
torch>=2
1919
wandb[media]>=0.19.0
2020
xarray

0 commit comments

Comments
 (0)