Skip to content

Commit 43a55a4

Browse files
authored
DLESyM notebook example (#300)
* draft dlesym example * update interpolation * DLESyM example complete, add NGC package * drop duplicate sst in arco * format fixes
1 parent 1e0788e commit 43a55a4

File tree

6 files changed

+3488
-3009
lines changed

6 files changed

+3488
-3009
lines changed

earth2studio/lexicon/cds.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class CDSLexicon(metaclass=LexiconType):
4545
"tcwv": "reanalysis-era5-single-levels::total_column_water_vapour::",
4646
"tp": "reanalysis-era5-single-levels::total_precipitation::",
4747
"fg10m": "reanalysis-era5-single-levels::10m_wind_gust_since_previous_post_processing::",
48+
"sst": "reanalysis-era5-single-levels::sea_surface_temperature::",
4849
"u50": "reanalysis-era5-pressure-levels::u_component_of_wind::50",
4950
"u100": "reanalysis-era5-pressure-levels::u_component_of_wind::100",
5051
"u150": "reanalysis-era5-pressure-levels::u_component_of_wind::150",

earth2studio/models/px/dlesym.py

Lines changed: 161 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
import torch
23+
import xarray as xr
2324

2425
try:
2526
import earth2grid
@@ -32,6 +33,7 @@
3233
OmegaConf = None
3334
earth2grid = None
3435
from earth2studio.models.auto import AutoModelMixin, Package
36+
from earth2studio.models.batch import batch_coords, batch_func
3537
from earth2studio.models.px.base import PrognosticModel
3638
from earth2studio.models.px.utils import PrognosticMixin
3739
from earth2studio.utils import check_extra_imports, handshake_coords, handshake_dim
@@ -85,20 +87,22 @@ class DLESyM(torch.nn.Module, AutoModelMixin, PrognosticMixin):
8587
iterator = model.create_iterator(x, coords)
8688
8789
for step, (x, coords) in enumerate(iterator):
88-
# Valid atmos and ocean predictions with their respective coordinates extracted below
89-
atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
90-
ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
91-
...
90+
if step > 0:
91+
# Valid atmos and ocean predictions with their respective coordinates extracted below
92+
atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
93+
ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
94+
...
9295
9396
Note
9497
----
9598
For more information about this model see:
9699
97-
- https://arxiv.org/abs/2409.16247
98-
- https://arxiv.org/abs/2311.06253v2
100+
- https://arxiv.org/abs/2409.16247
101+
- https://arxiv.org/abs/2311.06253v2
99102
100103
For more information about the HEALPix grid see:
101-
- https://github.com/NVlabs/earth2grid
104+
105+
- https://github.com/NVlabs/earth2grid
102106
103107
Parameters
104108
----------
@@ -262,17 +266,17 @@ def __init__(
262266

263267
# Setup the variable indices for [atmos, ocean]
264268
self.atmos_var_idx = [
265-
list(in_coords["variable"]).index(var) for var in self.atmos_variables
269+
list(out_coords["variable"]).index(var) for var in self.atmos_variables
266270
]
267271
self.ocean_var_idx = [
268-
list(in_coords["variable"]).index(var) for var in self.ocean_variables
272+
list(out_coords["variable"]).index(var) for var in self.ocean_variables
269273
]
270274
self.atmos_coupling_var_idx = [
271-
list(in_coords["variable"]).index(var)
275+
list(out_coords["variable"]).index(var)
272276
for var in self.atmos_coupling_variables
273277
]
274278
self.ocean_coupling_var_idx = [
275-
list(in_coords["variable"]).index(var)
279+
list(out_coords["variable"]).index(var)
276280
for var in self.ocean_coupling_variables
277281
]
278282

@@ -296,7 +300,7 @@ def input_coords(self) -> CoordSystem:
296300
}
297301
)
298302

299-
# @batch_coords()
303+
@batch_coords()
300304
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
301305
"""Output coordinate system of the prognostic model
302306
@@ -345,10 +349,14 @@ def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
345349
@classmethod
346350
def load_default_package(cls) -> Package:
347351
"""Default DLESyM model package on NGC"""
348-
# TODO use NGC package when ready
349-
raise NotImplementedError(
350-
"DLESyM NGC package not yet available, but is expected May 2025!"
352+
package = Package(
353+
"ngc://models/nvidia/earth-2/[email protected]",
354+
cache_options={
355+
"cache_storage": Package.default_cache("dlesym"),
356+
"same_names": True,
357+
},
351358
)
359+
return package
352360

353361
@classmethod
354362
@check_extra_imports("dlesym", [Module, OmegaConf])
@@ -705,15 +713,21 @@ def retrieve_valid_ocean_outputs(
705713
Output coordinates
706714
"""
707715

716+
self._validate_output_coords(coords)
717+
718+
var_dim = list(coords.keys()).index("variable")
719+
lead_dim = list(coords.keys()).index("lead_time")
708720
out_coords = coords.copy()
709721
out_coords["variable"] = np.array(self.ocean_variables)
710722
out_coords["lead_time"] = np.array(
711723
[t for t in coords["lead_time"] if t % self.ocean_output_times[0] == 0]
712724
)
713725

714-
ocean_outputs = x[:, :, self.ocean_output_lt_idx, ...]
726+
ocean_outputs = x.index_select(
727+
dim=var_dim, index=torch.tensor(self.ocean_var_idx, device=x.device)
728+
)
715729
ocean_outputs = ocean_outputs.index_select(
716-
dim=3, index=torch.tensor(self.ocean_var_idx, device=x.device)
730+
dim=lead_dim, index=torch.tensor(self.ocean_output_lt_idx, device=x.device)
717731
)
718732
return ocean_outputs, out_coords
719733

@@ -738,13 +752,39 @@ def retrieve_valid_atmos_outputs(
738752
Output coordinates
739753
"""
740754

755+
self._validate_output_coords(coords)
756+
757+
var_dim = list(coords.keys()).index("variable")
758+
741759
out_coords = coords.copy()
742760
out_coords["variable"] = np.array(self.atmos_variables)
743761

744-
atmos_outputs = x[:, :, :, self.atmos_var_idx, ...]
762+
atmos_outputs = x.index_select(
763+
dim=var_dim, index=torch.tensor(self.atmos_var_idx, device=x.device)
764+
)
745765

746766
return atmos_outputs, out_coords
747767

768+
def _validate_output_coords(self, coords: CoordSystem) -> None:
    """Check that coordinates handed to the output sub-selection methods are usable

    The coordinates must carry a ``lead_time`` entry whose length equals the
    number of atmosphere output times held by the model
    (``self.atmos_output_times``).

    Parameters
    ----------
    coords : CoordSystem
        Output coordinates to be validated

    Raises
    ------
    ValueError
        If ``lead_time`` is absent, or its length differs from
        ``len(self.atmos_output_times)``
    """
    if "lead_time" not in coords:
        raise ValueError("Lead time is required in the output coordinates")

    expected_len = len(self.atmos_output_times)
    actual_len = len(coords["lead_time"])
    if actual_len == expected_len:
        return

    raise ValueError(
        f"Lead time dimension length mismatch between model and coords: expected {expected_len}, got {actual_len}"
    )
787+
748788
@torch.inference_mode()
749789
def _forward(
750790
self,
@@ -792,7 +832,7 @@ def _next_step_inputs(
792832

793833
return next_x, next_coords
794834

795-
# @batch_func()
835+
@batch_func()
796836
def __call__(
797837
self,
798838
x: torch.Tensor,
@@ -817,7 +857,7 @@ def __call__(
817857

818858
return self._forward(x, coords), output_coords
819859

820-
# @batch_func()
860+
@batch_func()
821861
def _default_generator(
822862
self, x: torch.Tensor, coords: CoordSystem
823863
) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
@@ -972,9 +1012,16 @@ def input_coords(self) -> CoordSystem:
9721012
"""
9731013
coords = super().input_coords()
9741014
coords = self.coords_to_ll(coords)
1015+
1016+
# Modify to use the base variables instead of the derived variables
1017+
input_variables = [
1018+
v for v in list(coords["variable"]) if v not in ["tau300-700", "ws10m"]
1019+
]
1020+
input_variables.extend(["u10m", "v10m", "z300", "z700"])
1021+
coords["variable"] = np.array(input_variables)
9751022
return coords
9761023

977-
# @batch_coords()
1024+
@batch_coords()
9781025
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
9791026
"""Output coordinate system of the prognostic model
9801027
@@ -1060,7 +1107,86 @@ def coords_to_ll(self, coords: CoordSystem) -> CoordSystem:
10601107
ll_coords.move_to_end(dim)
10611108
return ll_coords
10621109

1063-
# @batch_func()
1110+
def _nan_interpolate_sst(
    self, sst: torch.Tensor, coords: CoordSystem
) -> torch.Tensor:
    """Fill NaNs (e.g. over landmasses) in SST data by repeated linear interpolation

    The tensor is copied to host memory, wrapped in an ``xarray.DataArray`` whose
    dimension names are the keys of ``coords``, and interpolated in three passes:

    1. linear interpolation along ``lon``;
    2. the same along ``lon`` after rolling the data by half the ``lon`` length,
       then rolling back;
    3. the same roll / interpolate / roll-back sequence along ``lat``.

    Interpolation is index-based (``use_coordinate=False``) and coordinates are
    never rolled (``roll_coords=False``).

    Parameters
    ----------
    sst : torch.Tensor
        SST tensor whose dimensions are named by the keys of ``coords``
    coords : CoordSystem
        Coordinate system; only its keys (dimension names) are used here, and
        they must include ``lon`` and ``lat``

    Returns
    -------
    torch.Tensor
        Interpolated SST data, placed on the same device as ``sst``
    """

    def _roll_fill_unroll(field: xr.DataArray, dim: str) -> xr.DataArray:
        # Shift by half of the axis, interpolate, then shift by the remainder of
        # the axis length so the data returns to its original position.
        # NOTE(review): for ``lat`` the roll treats the axis as wrap-around just
        # like ``lon`` - confirm that is intended.
        size = len(field[dim])
        half = size // 2
        return (
            field.roll({dim: half}, roll_coords=False)
            .interpolate_na(dim=dim, method="linear", use_coordinate=False)
            .roll({dim: size - half}, roll_coords=False)
        )

    # First pass: plain interpolation along longitude
    filled = xr.DataArray(sst.cpu().numpy(), dims=coords.keys()).interpolate_na(
        dim="lon", method="linear", use_coordinate=False
    )

    # Second and third passes: rolled interpolation along longitude, then latitude
    for axis in ("lon", "lat"):
        filled = _roll_fill_unroll(filled, axis)

    return torch.from_numpy(filled.values).to(sst.device)
1137+
1138+
def _prepare_derived_variables(
    self, x: torch.Tensor, coords: CoordSystem
) -> tuple[torch.Tensor, CoordSystem]:
    """Prepare derived variables for the DLESyM model.

    Builds the model input tensor from a tensor of base variables: it computes
    the derived fields ``ws10m`` (from ``u10m`` and ``v10m``) and ``tau300-700``
    (``z300 - z700``), NaN-interpolates the SST field, and assembles the result
    in the order ``self.atmos_variables + self.ocean_variables``.

    The variable axis is taken to be the third-from-last dimension of ``x``
    (the last two being the spatial dimensions).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor holding the base variables
    coords : CoordSystem
        Input coordinate system; ``coords["variable"]`` names the entries along
        the variable axis of ``x``

    Returns
    -------
    tuple[torch.Tensor, CoordSystem]
        Output tensor and coordinate system for the derived variables

    Raises
    ------
    KeyError
        If a base variable needed for the derived fields (``u10m``, ``v10m``,
        ``z300``, ``z700``, ``sst``) or any model variable is not available
    """

    prep_coords = coords.copy()

    # Map each base variable name to its slice along the variable axis. A single
    # enumerate pass avoids repeated list.index scans; setdefault keeps the first
    # occurrence should a name appear more than once.
    base_vars = list(prep_coords["variable"])
    src_vars: dict[str, torch.Tensor] = {}
    for idx, name in enumerate(base_vars):
        src_vars.setdefault(name, x[..., idx : idx + 1, :, :])

    # Fail early with an explicit message instead of a bare KeyError raised from
    # the middle of the arithmetic below
    required = ("u10m", "v10m", "z300", "z700", "sst")
    missing = [name for name in required if name not in src_vars]
    if missing:
        raise KeyError(
            f"Missing base variables required to prepare DLESyM inputs: {missing}"
        )

    # Compute the derived variables; supplied variables take precedence over
    # derived ones of the same name
    out_vars = {
        "ws10m": torch.sqrt(src_vars["u10m"] ** 2 + src_vars["v10m"] ** 2),
        "tau300-700": src_vars["z300"] - src_vars["z700"],
    }
    out_vars.update(src_vars)

    # Fill SST nans by custom interpolation
    out_vars["sst"] = self._nan_interpolate_sst(out_vars["sst"], coords)

    model_vars = self.atmos_variables + self.ocean_variables
    unavailable = [name for name in model_vars if name not in out_vars]
    if unavailable:
        raise KeyError(
            f"Model variables not available from the provided inputs: {unavailable}"
        )

    # Update the tensor with the derived variables and return
    prep_coords["variable"] = np.array(model_vars)
    # NOTE(review): no dtype is passed, so the output uses torch's default dtype
    # rather than x.dtype - confirm this is intended before changing it.
    x_out = torch.empty(
        *[v.shape[0] for v in prep_coords.values()], device=x.device
    )
    for i, name in enumerate(model_vars):
        x_out[..., i : i + 1, :, :] = out_vars[name]

    return x_out, prep_coords
1188+
1189+
@batch_func()
10641190
def __call__(
10651191
self, x: torch.Tensor, coords: CoordSystem
10661192
) -> tuple[torch.Tensor, CoordSystem]:
@@ -1080,18 +1206,24 @@ def __call__(
10801206
"""
10811207
output_coords = self.output_coords(coords)
10821208

1209+
x, coords = self._prepare_derived_variables(x, coords)
1210+
10831211
x = self.to_hpx(x)
10841212
x = self._forward(x, self.coords_to_hpx(coords))
10851213
x = self.to_ll(x)
10861214
return x, output_coords
10871215

1088-
# @batch_func()
1216+
@batch_func()
10891217
def _default_generator(
10901218
self, x: torch.Tensor, coords: CoordSystem
10911219
) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
10921220

10931221
coords = coords.copy()
10941222

1223+
base_vars = coords["variable"]
1224+
1225+
x, coords = self._prepare_derived_variables(x, coords)
1226+
10951227
yield x, coords
10961228

10971229
x = self.to_hpx(x)
@@ -1101,7 +1233,12 @@ def _default_generator(
11011233
x, coords = self.front_hook(x, coords)
11021234

11031235
x = self._forward(x, self.coords_to_hpx(coords))
1104-
coords = self.output_coords(coords)
1236+
1237+
# Output coords expects the input variable set to include base variables,
1238+
# but will return the output variables with the derived variables
1239+
base_coords = coords.copy()
1240+
base_coords["variable"] = base_vars
1241+
coords = self.output_coords(base_coords)
11051242

11061243
# Rear hook
11071244
x, coords = self.rear_hook(x, coords)

0 commit comments

Comments
 (0)