Skip to content

DLESyM notebook example #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions earth2studio/lexicon/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class CDSLexicon(metaclass=LexiconType):
"tcwv": "reanalysis-era5-single-levels::total_column_water_vapour::",
"tp": "reanalysis-era5-single-levels::total_precipitation::",
"fg10m": "reanalysis-era5-single-levels::10m_wind_gust_since_previous_post_processing::",
"sst": "reanalysis-era5-single-levels::sea_surface_temperature::",
"u50": "reanalysis-era5-pressure-levels::u_component_of_wind::50",
"u100": "reanalysis-era5-pressure-levels::u_component_of_wind::100",
"u150": "reanalysis-era5-pressure-levels::u_component_of_wind::150",
Expand Down
185 changes: 161 additions & 24 deletions earth2studio/models/px/dlesym.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import torch
import xarray as xr

try:
import earth2grid
Expand All @@ -32,6 +33,7 @@
OmegaConf = None
earth2grid = None
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import check_extra_imports, handshake_coords, handshake_dim
Expand Down Expand Up @@ -85,20 +87,22 @@ class DLESyM(torch.nn.Module, AutoModelMixin, PrognosticMixin):
iterator = model.create_iterator(x, coords)

for step, (x, coords) in enumerate(iterator):
# Valid atmos and ocean predictions with their respective coordinates extracted below
atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
...
if step > 0:
# Valid atmos and ocean predictions with their respective coordinates extracted below
atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
...

Note
----
For more information about this model see:

- https://arxiv.org/abs/2409.16247
- https://arxiv.org/abs/2311.06253v2
- https://arxiv.org/abs/2409.16247
- https://arxiv.org/abs/2311.06253v2

For more information about the HEALPix grid see:
- https://github.com/NVlabs/earth2grid

- https://github.com/NVlabs/earth2grid

Parameters
----------
Expand Down Expand Up @@ -262,17 +266,17 @@ def __init__(

# Setup the variable indices for [atmos, ocean]
self.atmos_var_idx = [
list(in_coords["variable"]).index(var) for var in self.atmos_variables
list(out_coords["variable"]).index(var) for var in self.atmos_variables
]
self.ocean_var_idx = [
list(in_coords["variable"]).index(var) for var in self.ocean_variables
list(out_coords["variable"]).index(var) for var in self.ocean_variables
]
self.atmos_coupling_var_idx = [
list(in_coords["variable"]).index(var)
list(out_coords["variable"]).index(var)
for var in self.atmos_coupling_variables
]
self.ocean_coupling_var_idx = [
list(in_coords["variable"]).index(var)
list(out_coords["variable"]).index(var)
for var in self.ocean_coupling_variables
]

Expand All @@ -296,7 +300,7 @@ def input_coords(self) -> CoordSystem:
}
)

# @batch_coords()
@batch_coords()
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
"""Output coordinate system of the prognostic model

Expand Down Expand Up @@ -345,10 +349,14 @@ def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
@classmethod
def load_default_package(cls) -> Package:
"""Default DLESyM model package on NGC"""
# TODO use NGC package when ready
raise NotImplementedError(
"DLESyM NGC package not yet available, but is expected May 2025!"
package = Package(
"ngc://models/nvidia/earth-2/[email protected]",
cache_options={
"cache_storage": Package.default_cache("dlesym"),
"same_names": True,
},
)
return package

@classmethod
@check_extra_imports("dlesym", [Module, OmegaConf])
Expand Down Expand Up @@ -705,15 +713,21 @@ def retrieve_valid_ocean_outputs(
Output coordinates
"""

self._validate_output_coords(coords)

var_dim = list(coords.keys()).index("variable")
lead_dim = list(coords.keys()).index("lead_time")
out_coords = coords.copy()
out_coords["variable"] = np.array(self.ocean_variables)
out_coords["lead_time"] = np.array(
[t for t in coords["lead_time"] if t % self.ocean_output_times[0] == 0]
)

ocean_outputs = x[:, :, self.ocean_output_lt_idx, ...]
ocean_outputs = x.index_select(
dim=var_dim, index=torch.tensor(self.ocean_var_idx, device=x.device)
)
ocean_outputs = ocean_outputs.index_select(
dim=3, index=torch.tensor(self.ocean_var_idx, device=x.device)
dim=lead_dim, index=torch.tensor(self.ocean_output_lt_idx, device=x.device)
)
return ocean_outputs, out_coords

Expand All @@ -738,13 +752,39 @@ def retrieve_valid_atmos_outputs(
Output coordinates
"""

self._validate_output_coords(coords)

var_dim = list(coords.keys()).index("variable")

out_coords = coords.copy()
out_coords["variable"] = np.array(self.atmos_variables)

atmos_outputs = x[:, :, :, self.atmos_var_idx, ...]
atmos_outputs = x.index_select(
dim=var_dim, index=torch.tensor(self.atmos_var_idx, device=x.device)
)

return atmos_outputs, out_coords

def _validate_output_coords(self, coords: CoordSystem) -> None:
"""Validate the coordinates passed to the output subselection methods

Parameters
----------
coords : CoordSystem
Output coordinates to be validated

Raises
------
ValueError
If the coordinates are invalid (missing or incorrect length lead_time dim)
"""
if "lead_time" not in coords:
raise ValueError("Lead time is required in the output coordinates")
if len(coords["lead_time"]) != len(self.atmos_output_times):
raise ValueError(
f"Lead time dimension length mismatch between model and coords: expected {len(self.atmos_output_times)}, got {len(coords['lead_time'])}"
)

@torch.inference_mode()
def _forward(
self,
Expand Down Expand Up @@ -792,7 +832,7 @@ def _next_step_inputs(

return next_x, next_coords

# @batch_func()
@batch_func()
def __call__(
self,
x: torch.Tensor,
Expand All @@ -817,7 +857,7 @@ def __call__(

return self._forward(x, coords), output_coords

# @batch_func()
@batch_func()
def _default_generator(
self, x: torch.Tensor, coords: CoordSystem
) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
Expand Down Expand Up @@ -972,9 +1012,16 @@ def input_coords(self) -> CoordSystem:
"""
coords = super().input_coords()
coords = self.coords_to_ll(coords)

# Modify to use the base variables instead of the derived variables
input_variables = [
v for v in list(coords["variable"]) if v not in ["tau300-700", "ws10m"]
]
input_variables.extend(["u10m", "v10m", "z300", "z700"])
coords["variable"] = np.array(input_variables)
return coords

# @batch_coords()
@batch_coords()
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
"""Output coordinate system of the prognostic model

Expand Down Expand Up @@ -1060,7 +1107,86 @@ def coords_to_ll(self, coords: CoordSystem) -> CoordSystem:
ll_coords.move_to_end(dim)
return ll_coords

# @batch_func()
def _nan_interpolate_sst(
self, sst: torch.Tensor, coords: CoordSystem
) -> torch.Tensor:
"""Custom interpolation to fill NaNs over landmasses in SST data."""

da_sst = xr.DataArray(sst.cpu().numpy(), dims=coords.keys())
da_interp = da_sst.interpolate_na(
dim="lon", method="linear", use_coordinate=False
)

# Second pass: roll, interpolate along longitude, and unroll
roll_amount_lon = int(len(da_interp.lon) / 2)
da_double_interp = (
da_interp.roll(lon=roll_amount_lon, roll_coords=False)
.interpolate_na(dim="lon", method="linear", use_coordinate=False)
.roll(lon=len(da_interp.lon) - roll_amount_lon, roll_coords=False)
)

# Third pass do a similar roll along latitude
roll_amount_lat = int(len(da_double_interp.lat) / 2)
da_triple_interp = (
da_double_interp.roll(lat=roll_amount_lat, roll_coords=False)
.interpolate_na(dim="lat", method="linear", use_coordinate=False)
.roll(lat=len(da_double_interp.lat) - roll_amount_lat, roll_coords=False)
)

return torch.from_numpy(da_triple_interp.values).to(sst.device)

def _prepare_derived_variables(
self, x: torch.Tensor, coords: CoordSystem
) -> tuple[torch.Tensor, CoordSystem]:
"""Prepare derived variables for the DLESyM model.

This method handles the preparation of derived variables from the input tensor
and coordinates. It ensures that the derived variables are correctly computed,
and performs NaN-interpolation on the SST data.

Parameters
----------
x : torch.Tensor
Input tensor
coords : CoordSystem
Input coordinate system

Returns
-------
tuple[torch.Tensor, CoordSystem]
Output tensor and coordinate system for the derived variables
"""

prep_coords = coords.copy()

# Fetch the base variables
base_vars = list(prep_coords["variable"])
src_vars = {
v: x[..., base_vars.index(v) : base_vars.index(v) + 1, :, :]
for v in base_vars
}

# Compute the derived variables
out_vars = {
"ws10m": torch.sqrt(src_vars["u10m"] ** 2 + src_vars["v10m"] ** 2),
"tau300-700": src_vars["z300"] - src_vars["z700"],
}
out_vars.update(src_vars)

# Fill SST nans by custom interpolation
out_vars["sst"] = self._nan_interpolate_sst(out_vars["sst"], coords)

# Update the tensor with the derived variables and return
prep_coords["variable"] = np.array(self.atmos_variables + self.ocean_variables)
x_out = torch.empty(
*[v.shape[0] for v in prep_coords.values()], device=x.device
)
for i, v in enumerate(prep_coords["variable"]):
x_out[..., i : i + 1, :, :] = out_vars[v]

return x_out, prep_coords

@batch_func()
def __call__(
self, x: torch.Tensor, coords: CoordSystem
) -> tuple[torch.Tensor, CoordSystem]:
Expand All @@ -1080,18 +1206,24 @@ def __call__(
"""
output_coords = self.output_coords(coords)

x, coords = self._prepare_derived_variables(x, coords)

x = self.to_hpx(x)
x = self._forward(x, self.coords_to_hpx(coords))
x = self.to_ll(x)
return x, output_coords

# @batch_func()
@batch_func()
def _default_generator(
self, x: torch.Tensor, coords: CoordSystem
) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:

coords = coords.copy()

base_vars = coords["variable"]

x, coords = self._prepare_derived_variables(x, coords)

yield x, coords

x = self.to_hpx(x)
Expand All @@ -1101,7 +1233,12 @@ def _default_generator(
x, coords = self.front_hook(x, coords)

x = self._forward(x, self.coords_to_hpx(coords))
coords = self.output_coords(coords)

# Output coords expects the input variable set to include base variables,
# but will return the ouptut variables with the derived variables
base_coords = coords.copy()
base_coords["variable"] = base_vars
coords = self.output_coords(base_coords)

# Rear hook
x, coords = self.rear_hook(x, coords)
Expand Down
Loading