20
20
21
21
import numpy as np
22
22
import torch
23
+ import xarray as xr
23
24
24
25
try :
25
26
import earth2grid
32
33
OmegaConf = None
33
34
earth2grid = None
34
35
from earth2studio .models .auto import AutoModelMixin , Package
36
+ from earth2studio .models .batch import batch_coords , batch_func
35
37
from earth2studio .models .px .base import PrognosticModel
36
38
from earth2studio .models .px .utils import PrognosticMixin
37
39
from earth2studio .utils import check_extra_imports , handshake_coords , handshake_dim
@@ -85,20 +87,22 @@ class DLESyM(torch.nn.Module, AutoModelMixin, PrognosticMixin):
85
87
iterator = model.create_iterator(x, coords)
86
88
87
89
for step, (x, coords) in enumerate(iterator):
88
- # Valid atmos and ocean predictions with their respective coordinates extracted below
89
- atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
90
- ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
91
- ...
90
+ if step > 0:
91
+ # Valid atmos and ocean predictions with their respective coordinates extracted below
92
+ atmos_outputs, atmos_coords = model.retrieve_valid_atmos_outputs(x, coords)
93
+ ocean_outputs, ocean_coords = model.retrieve_valid_ocean_outputs(x, coords)
94
+ ...
92
95
93
96
Note
94
97
----
95
98
For more information about this model see:
96
99
97
- - https://arxiv.org/abs/2409.16247
98
- - https://arxiv.org/abs/2311.06253v2
100
+ - https://arxiv.org/abs/2409.16247
101
+ - https://arxiv.org/abs/2311.06253v2
99
102
100
103
For more information about the HEALPix grid see:
101
- - https://github.com/NVlabs/earth2grid
104
+
105
+ - https://github.com/NVlabs/earth2grid
102
106
103
107
Parameters
104
108
----------
@@ -262,17 +266,17 @@ def __init__(
262
266
263
267
# Setup the variable indices for [atmos, ocean]
264
268
self .atmos_var_idx = [
265
- list (in_coords ["variable" ]).index (var ) for var in self .atmos_variables
269
+ list (out_coords ["variable" ]).index (var ) for var in self .atmos_variables
266
270
]
267
271
self .ocean_var_idx = [
268
- list (in_coords ["variable" ]).index (var ) for var in self .ocean_variables
272
+ list (out_coords ["variable" ]).index (var ) for var in self .ocean_variables
269
273
]
270
274
self .atmos_coupling_var_idx = [
271
- list (in_coords ["variable" ]).index (var )
275
+ list (out_coords ["variable" ]).index (var )
272
276
for var in self .atmos_coupling_variables
273
277
]
274
278
self .ocean_coupling_var_idx = [
275
- list (in_coords ["variable" ]).index (var )
279
+ list (out_coords ["variable" ]).index (var )
276
280
for var in self .ocean_coupling_variables
277
281
]
278
282
@@ -296,7 +300,7 @@ def input_coords(self) -> CoordSystem:
296
300
}
297
301
)
298
302
299
- # @batch_coords()
303
+ @batch_coords ()
300
304
def output_coords (self , input_coords : CoordSystem ) -> CoordSystem :
301
305
"""Output coordinate system of the prognostic model
302
306
@@ -345,10 +349,14 @@ def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
345
349
@classmethod
346
350
def load_default_package (cls ) -> Package :
347
351
"""Default DLESyM model package on NGC"""
348
- # TODO use NGC package when ready
349
- raise NotImplementedError (
350
- "DLESyM NGC package not yet available, but is expected May 2025!"
352
+ package = Package (
353
+ "ngc://models/nvidia/earth-2/[email protected] " ,
354
+ cache_options = {
355
+ "cache_storage" : Package .default_cache ("dlesym" ),
356
+ "same_names" : True ,
357
+ },
351
358
)
359
+ return package
352
360
353
361
@classmethod
354
362
@check_extra_imports ("dlesym" , [Module , OmegaConf ])
@@ -705,15 +713,21 @@ def retrieve_valid_ocean_outputs(
705
713
Output coordinates
706
714
"""
707
715
716
+ self ._validate_output_coords (coords )
717
+
718
+ var_dim = list (coords .keys ()).index ("variable" )
719
+ lead_dim = list (coords .keys ()).index ("lead_time" )
708
720
out_coords = coords .copy ()
709
721
out_coords ["variable" ] = np .array (self .ocean_variables )
710
722
out_coords ["lead_time" ] = np .array (
711
723
[t for t in coords ["lead_time" ] if t % self .ocean_output_times [0 ] == 0 ]
712
724
)
713
725
714
- ocean_outputs = x [:, :, self .ocean_output_lt_idx , ...]
726
+ ocean_outputs = x .index_select (
727
+ dim = var_dim , index = torch .tensor (self .ocean_var_idx , device = x .device )
728
+ )
715
729
ocean_outputs = ocean_outputs .index_select (
716
- dim = 3 , index = torch .tensor (self .ocean_var_idx , device = x .device )
730
+ dim = lead_dim , index = torch .tensor (self .ocean_output_lt_idx , device = x .device )
717
731
)
718
732
return ocean_outputs , out_coords
719
733
@@ -738,13 +752,39 @@ def retrieve_valid_atmos_outputs(
738
752
Output coordinates
739
753
"""
740
754
755
+ self ._validate_output_coords (coords )
756
+
757
+ var_dim = list (coords .keys ()).index ("variable" )
758
+
741
759
out_coords = coords .copy ()
742
760
out_coords ["variable" ] = np .array (self .atmos_variables )
743
761
744
- atmos_outputs = x [:, :, :, self .atmos_var_idx , ...]
762
+ atmos_outputs = x .index_select (
763
+ dim = var_dim , index = torch .tensor (self .atmos_var_idx , device = x .device )
764
+ )
745
765
746
766
return atmos_outputs , out_coords
747
767
768
def _validate_output_coords(self, coords: CoordSystem) -> None:
    """Validate the coordinates passed to the output subselection methods

    Parameters
    ----------
    coords : CoordSystem
        Output coordinates to be validated

    Raises
    ------
    ValueError
        If the coordinates are invalid (missing lead_time dim, lead_time dim
        whose length differs from the model's atmos output times, or missing
        variable dim)
    """
    if "lead_time" not in coords:
        raise ValueError("Lead time is required in the output coordinates")

    expected = len(self.atmos_output_times)
    actual = len(coords["lead_time"])
    if actual != expected:
        raise ValueError(
            "Lead time dimension length mismatch between model and coords: "
            f"expected {expected}, got {actual}"
        )

    # The subselection methods look up the position of the "variable" dim right
    # after validation; fail here with a clear message instead of an opaque
    # list.index error. Checked last so existing failures keep their messages.
    if "variable" not in coords:
        raise ValueError("Variable is required in the output coordinates")
787
+
748
788
@torch .inference_mode ()
749
789
def _forward (
750
790
self ,
@@ -792,7 +832,7 @@ def _next_step_inputs(
792
832
793
833
return next_x , next_coords
794
834
795
- # @batch_func()
835
+ @batch_func ()
796
836
def __call__ (
797
837
self ,
798
838
x : torch .Tensor ,
@@ -817,7 +857,7 @@ def __call__(
817
857
818
858
return self ._forward (x , coords ), output_coords
819
859
820
- # @batch_func()
860
+ @batch_func ()
821
861
def _default_generator (
822
862
self , x : torch .Tensor , coords : CoordSystem
823
863
) -> Generator [tuple [torch .Tensor , CoordSystem ], None , None ]:
@@ -972,9 +1012,16 @@ def input_coords(self) -> CoordSystem:
972
1012
"""
973
1013
coords = super ().input_coords ()
974
1014
coords = self .coords_to_ll (coords )
1015
+
1016
+ # Modify to use the base variables instead of the derived variables
1017
+ input_variables = [
1018
+ v for v in list (coords ["variable" ]) if v not in ["tau300-700" , "ws10m" ]
1019
+ ]
1020
+ input_variables .extend (["u10m" , "v10m" , "z300" , "z700" ])
1021
+ coords ["variable" ] = np .array (input_variables )
975
1022
return coords
976
1023
977
- # @batch_coords()
1024
+ @batch_coords ()
978
1025
def output_coords (self , input_coords : CoordSystem ) -> CoordSystem :
979
1026
"""Output coordinate system of the prognostic model
980
1027
@@ -1060,7 +1107,86 @@ def coords_to_ll(self, coords: CoordSystem) -> CoordSystem:
1060
1107
ll_coords .move_to_end (dim )
1061
1108
return ll_coords
1062
1109
1063
- # @batch_func()
1110
def _nan_interpolate_sst(
    self, sst: torch.Tensor, coords: CoordSystem
) -> torch.Tensor:
    """Custom interpolation to fill NaNs over landmasses in SST data.

    NaNs are filled by linear interpolation in three passes: along ``lon``,
    along ``lon`` again after rolling the data by half its length, and along
    ``lat`` after rolling by half its length.

    Parameters
    ----------
    sst : torch.Tensor
        SST tensor, one dimension per key of ``coords``
    coords : CoordSystem
        Coordinate system whose keys name the dimensions of ``sst``; must
        contain ``lat`` and ``lon``

    Returns
    -------
    torch.Tensor
        New tensor on the same device as ``sst`` with NaNs filled
    """
    # Fast path: nothing to fill, so skip the host round trip and the three
    # interpolation passes. Clone so the result never aliases the input.
    if not torch.isnan(sst).any():
        return sst.clone()

    # Only the dimension names are needed; positional interpolation is used
    # throughout (use_coordinate=False), so no coordinate values are attached.
    da_sst = xr.DataArray(sst.cpu().numpy(), dims=list(coords))
    da_interp = da_sst.interpolate_na(
        dim="lon", method="linear", use_coordinate=False
    )

    # Second pass: roll, interpolate along longitude, and unroll
    n_lon = len(da_interp.lon)
    roll_amount_lon = n_lon // 2
    da_double_interp = (
        da_interp.roll(lon=roll_amount_lon, roll_coords=False)
        .interpolate_na(dim="lon", method="linear", use_coordinate=False)
        .roll(lon=n_lon - roll_amount_lon, roll_coords=False)
    )

    # Third pass do a similar roll along latitude
    n_lat = len(da_double_interp.lat)
    roll_amount_lat = n_lat // 2
    da_triple_interp = (
        da_double_interp.roll(lat=roll_amount_lat, roll_coords=False)
        .interpolate_na(dim="lat", method="linear", use_coordinate=False)
        .roll(lat=n_lat - roll_amount_lat, roll_coords=False)
    )

    return torch.from_numpy(da_triple_interp.values).to(sst.device)
1137
+
1138
def _prepare_derived_variables(
    self, x: torch.Tensor, coords: CoordSystem
) -> tuple[torch.Tensor, CoordSystem]:
    """Prepare derived variables for the DLESyM model.

    This method handles the preparation of derived variables from the input tensor
    and coordinates. It ensures that the derived variables are correctly computed,
    and performs NaN-interpolation on the SST data.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The indexing below treats the third-to-last dimension as
        the variable dimension.
    coords : CoordSystem
        Input coordinate system; its ``variable`` entry must include ``u10m``,
        ``v10m``, ``z300``, ``z700`` and ``sst``

    Returns
    -------
    tuple[torch.Tensor, CoordSystem]
        Output tensor and coordinate system for the derived variables

    Raises
    ------
    KeyError
        If a required base variable is missing from ``coords["variable"]``
    """

    prep_coords = coords.copy()

    # Fetch the base variables: one size-1 slice along the variable dim each.
    # If a name is repeated, the first occurrence is used.
    first_index: dict = {}
    for idx, name in enumerate(prep_coords["variable"]):
        first_index.setdefault(name, idx)
    src_vars = {
        name: x[..., idx : idx + 1, :, :] for name, idx in first_index.items()
    }

    # Compute the derived variables
    out_vars = {
        "ws10m": torch.sqrt(src_vars["u10m"] ** 2 + src_vars["v10m"] ** 2),
        "tau300-700": src_vars["z300"] - src_vars["z700"],
    }
    # Inputs take precedence over a derived variable of the same name
    out_vars.update(src_vars)

    # Fill SST nans by custom interpolation
    out_vars["sst"] = self._nan_interpolate_sst(out_vars["sst"], coords)

    # Update the tensor with the derived variables and return
    prep_coords["variable"] = np.array(self.atmos_variables + self.ocean_variables)
    # Allocate with the input dtype so the precision of x is preserved rather
    # than silently cast to the global default dtype on assignment.
    x_out = torch.empty(
        *[v.shape[0] for v in prep_coords.values()],
        dtype=x.dtype,
        device=x.device,
    )
    for i, v in enumerate(prep_coords["variable"]):
        x_out[..., i : i + 1, :, :] = out_vars[v]

    return x_out, prep_coords
1188
+
1189
+ @batch_func ()
1064
1190
def __call__ (
1065
1191
self , x : torch .Tensor , coords : CoordSystem
1066
1192
) -> tuple [torch .Tensor , CoordSystem ]:
@@ -1080,18 +1206,24 @@ def __call__(
1080
1206
"""
1081
1207
output_coords = self .output_coords (coords )
1082
1208
1209
+ x , coords = self ._prepare_derived_variables (x , coords )
1210
+
1083
1211
x = self .to_hpx (x )
1084
1212
x = self ._forward (x , self .coords_to_hpx (coords ))
1085
1213
x = self .to_ll (x )
1086
1214
return x , output_coords
1087
1215
1088
- # @batch_func()
1216
+ @batch_func ()
1089
1217
def _default_generator (
1090
1218
self , x : torch .Tensor , coords : CoordSystem
1091
1219
) -> Generator [tuple [torch .Tensor , CoordSystem ], None , None ]:
1092
1220
1093
1221
coords = coords .copy ()
1094
1222
1223
+ base_vars = coords ["variable" ]
1224
+
1225
+ x , coords = self ._prepare_derived_variables (x , coords )
1226
+
1095
1227
yield x , coords
1096
1228
1097
1229
x = self .to_hpx (x )
@@ -1101,7 +1233,12 @@ def _default_generator(
1101
1233
x , coords = self .front_hook (x , coords )
1102
1234
1103
1235
x = self ._forward (x , self .coords_to_hpx (coords ))
1104
- coords = self .output_coords (coords )
1236
+
1237
+ # Output coords expects the input variable set to include base variables,
1238
+ # but will return the output variables with the derived variables
1239
+ base_coords = coords .copy ()
1240
+ base_coords ["variable" ] = base_vars
1241
+ coords = self .output_coords (base_coords )
1105
1242
1106
1243
# Rear hook
1107
1244
x , coords = self .rear_hook (x , coords )
0 commit comments