Skip to content

Commit db49a71

Browse files
committed
some changes to the model and losses
1 parent e95c3ea commit db49a71

File tree

12 files changed

+407
-152
lines changed

12 files changed

+407
-152
lines changed

makani/models/networks/fourcastnet3.py

Lines changed: 183 additions & 72 deletions
Large diffs are not rendered by default.

makani/models/networks/pangu.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None):
452452
x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
453453
mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
454454
"""
455-
455+
456456
B_, nW_, N, C = x.shape
457457
qkv = (
458458
self.qkv(x)
@@ -478,18 +478,18 @@ def forward(self, x: torch.Tensor, mask=None):
478478
attn = self.attn_drop_fn(attn)
479479

480480
x = self.apply_attention(attn, v, B_, nW_, N, C)
481-
481+
482482
else:
483483
if mask is not None:
484484
bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0)
485485
# squeeze the bias if needed in dim 2
486486
#bias = bias.squeeze(2)
487487
else:
488488
bias = earth_position_bias.unsqueeze(0)
489-
489+
490490
# extract batch size for q,k,v
491491
nLon = self.num_lon
492-
q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
492+
q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
493493
k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4])
494494
v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4])
495495
####
@@ -736,7 +736,7 @@ class Pangu(nn.Module):
736736
- https://arxiv.org/abs/2211.02556
737737
"""
738738

739-
def __init__(self,
739+
def __init__(self,
740740
inp_shape=(721,1440),
741741
out_shape=(721,1440),
742742
grid_in="equiangular",
@@ -773,14 +773,14 @@ def __init__(self,
773773
self.checkpointing_level = checkpointing_level
774774

775775
drop_path = np.linspace(0, drop_path_rate, 8).tolist()
776-
776+
777777
# Add static channels to surface
778778
self.num_aux = len(self.aux_channel_names)
779779
N_total_surface = self.num_aux + self.num_surface
780780

781781
# compute static permutations to extract
782782
self._precompute_channel_groups(self.channel_names, self.aux_channel_names)
783-
783+
784784
# Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches
785785
self.patchembed2d = PatchEmbed2D(
786786
img_size=self.inp_shape,
@@ -791,7 +791,7 @@ def __init__(self,
791791
flatten=False,
792792
norm_layer=None,
793793
)
794-
794+
795795
self.patchembed3d = PatchEmbed3D(
796796
img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]),
797797
patch_size=patch_size,
@@ -870,7 +870,7 @@ def __init__(self,
870870
self.patchrecovery3d = PatchRecovery3D(
871871
(num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric
872872
)
873-
873+
874874
def _precompute_channel_groups(
875875
self,
876876
channel_names=[],
@@ -901,7 +901,7 @@ def _precompute_channel_groups(
901901

902902
def prepare_input(self, input):
903903
"""
904-
Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
904+
Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
905905
and reshaping the atmospheric variables into the required format.
906906
"""
907907

@@ -932,23 +932,23 @@ def prepare_output(self, output_surface, output_atmospheric):
932932
level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels}
933933
reordered_ids = [idx for level in levels for idx in level_dict[level]]
934934
check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]]
935-
935+
936936
# Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!)
937937
flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4])
938938
reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1)
939939
for i in range(len(reordered_ids)):
940940
reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :]
941-
941+
942942
# Append the surface output, this has not been reordered.
943943
if output_surface is not None:
944-
_, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names)
944+
_, surf_chans, _, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names)
945945
reordered_atmospheric[:, surf_chans, :, :] = output_surface
946946
output = reordered_atmospheric
947947
else:
948948
output = reordered_atmospheric
949949

950950
return output
951-
951+
952952
def forward(self, input):
953953

954954
# Prep the input by splitting into surface and atmospheric variables
@@ -959,7 +959,7 @@ def forward(self, input):
959959
surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False)
960960
atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False)
961961
else:
962-
surface = self.patchembed2d(surface_aux)
962+
surface = self.patchembed2d(surface_aux)
963963
atmospheric = self.patchembed3d(atmospheric)
964964

965965
if surface.shape[1] == 0:
@@ -1011,11 +1011,5 @@ def forward(self, input):
10111011
output_atmospheric = self.patchrecovery3d(output_atmospheric)
10121012

10131013
output = self.prepare_output(output_surface, output_atmospheric)
1014-
1015-
return output
1016-
1017-
1018-
10191014

1020-
1021-
1015+
return output

makani/models/networks/pangu_onnx.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper):
3838
channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects
3939
onnx_file: Path to the ONNX file containing the model
4040
'''
41-
def __init__(self,
41+
def __init__(self,
4242
channel_names=[],
4343
aux_channel_names=[],
4444
onnx_file=None,
@@ -58,7 +58,7 @@ def _precompute_channel_groups(
5858
group the channels appropriately into atmospheric pressure levels and surface variables
5959
"""
6060

61-
atmo_chans, surf_chans, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names)
61+
atmo_chans, surf_chans, _, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names)
6262

6363
# compute how many channel groups will be kept internally
6464
self.n_atmo_groups = len(pressure_lvls)
@@ -78,12 +78,12 @@ def prepare_input(self, input):
7878
B,V,Lat,Long=input.shape
7979

8080
if B>1:
81-
raise NotImplementedError("Not implemented yet for batch size greater than 1")
81+
raise NotImplementedError("Not implemented yet for batch size greater than 1")
8282

8383
input=input.squeeze(0)
8484
surface_aux_inp=input[self.surf_channels]
8585
atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0)
86-
86+
8787
return surface_aux_inp, atmospheric_inp
8888

8989
def prepare_output(self, output_surface, output_atmospheric):
@@ -99,15 +99,15 @@ def prepare_output(self, output_surface, output_atmospheric):
9999

100100
return output.unsqueeze(0)
101101

102-
102+
103103
def forward(self, input):
104-
104+
105105
surface, atmospheric = self.prepare_input(input)
106106

107107

108108
output,output_surface=self.onnx_session_run({'input':atmospheric,'input_surface':surface})
109109

110110
output = self.prepare_output(output_surface, output)
111111

112-
112+
113113
return output

makani/models/stepper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ def _forward_eval(self, inp, update_state=True, replace_state=True):
153153

154154
return y
155155

156-
def forward(self, inp, replace_state=True):
156+
def forward(self, inp, update_state=True, replace_state=True):
157157
# decide which routine to call
158158
if self.training:
159-
y = self._forward_train(inp, update_state=True, replace_state=replace_state)
159+
y = self._forward_train(inp, update_state=update_state, replace_state=replace_state)
160160
else:
161-
y = self._forward_eval(inp, update_state=True, replace_state=replace_state)
161+
y = self._forward_eval(inp, update_state=update_state, replace_state=replace_state)
162162

163163
return y

makani/utils/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,11 +632,11 @@ def get_optimizer(self, model, params):
632632
if params.optimizer_type == "Adam":
633633
if self.log_to_screen:
634634
self.logger.info("using Adam optimizer")
635-
optimizer = optim.Adam(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True)
635+
optimizer = optim.Adam(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True)
636636
elif params.optimizer_type == "AdamW":
637637
if self.log_to_screen:
638638
self.logger.info("using AdamW optimizer")
639-
optimizer = optim.AdamW(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True)
639+
optimizer = optim.AdamW(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True)
640640
elif params.optimizer_type == "SGD":
641641
if self.log_to_screen:
642642
self.logger.info("using SGD optimizer")

makani/utils/features.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,15 @@ def get_wind_channels(channel_names):
9797

9898
def get_channel_groups(channel_names, aux_channel_names=[]):
9999
"""
100-
Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups
100+
Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups.
101+
The resulting numbering does NOT respect history.
101102
"""
102103

103104
atmo_groups = OrderedDict()
104105
atmo_chans = []
105106
surf_chans = []
106-
aux_chans = []
107+
dyn_aux_chans = []
108+
stat_aux_chans = []
107109

108110
# parse channel names and group variables by pressure level/surface variables
109111
for idx, chn in enumerate(channel_names):
@@ -127,6 +129,10 @@ def get_channel_groups(channel_names, aux_channel_names=[]):
127129
atmo_chans += idx
128130

129131
# append the auxiliary variable to the surface channels
130-
aux_chans = [idx + len(channel_names) for idx in range(len(aux_channel_names))]
132+
for idx, chn in enumerate(aux_channel_names):
133+
if chn in ["xoro", "xlsml", "xlsms"]:
134+
stat_aux_chans.append(idx + len(channel_names))
135+
else:
136+
dyn_aux_chans.append(idx + len(channel_names))
131137

132-
return atmo_chans, surf_chans, aux_chans, atmo_groups.keys()
138+
return atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, atmo_groups.keys()

makani/utils/loss.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights
3333

3434
from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss
35-
from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss
35+
from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss
3636
from .losses import EnsembleNLLLoss, EnsembleMMDLoss
3737
from .losses import DriftRegularization, HydrostaticBalanceLoss
3838

@@ -119,8 +119,6 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps
119119
)
120120

121121
# append to dict and compile before:
122-
# TODO: fix the compile issue
123-
# self.loss_fn[loss_type] = torch.compile(loss_fn)
124122
self.loss_fn.append(loss_fn)
125123

126124
# determine channel weighting
@@ -140,7 +138,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps
140138
# get channel weights either directly or through the compute routine
141139
if isinstance(channel_weight_type, List):
142140
chw = torch.tensor(channel_weight_type, dtype=torch.float32)
143-
chw = chw * time_diff_scale
141+
if time_diff_scale is not None:
142+
chw = chw * time_diff_scale
144143
assert chw.shape[1] == loss_fn.n_channels
145144
else:
146145
chw = loss_fn.compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale)
@@ -228,6 +227,8 @@ def _parse_loss_type(self, loss_type: str):
228227
loss_handle = EnsembleNLLLoss
229228
elif "ensemble_mmd" in loss_type:
230229
loss_handle = EnsembleMMDLoss
230+
elif "energy_score" in loss_type:
231+
loss_handle = partial(EnergyScoreLoss)
231232
elif "drift_regularization" in loss_type:
232233
loss_handle = DriftRegularization
233234
else:
@@ -333,19 +334,23 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens
333334
loss_vals.append(lfn(prd, tar, wgt))
334335
all_losses = torch.cat(loss_vals, dim=-1)
335336

337+
# print(all_losses)
338+
336339
if self.training and self.track_running_stats:
337340
self._update_running_stats(all_losses.clone())
338341

339342
# process channel weights
340343
chw = self.channel_weights
341344
if self.uncertainty_weighting and self.training:
342345
var, _ = self.get_running_stats()
346+
if self.num_batches_tracked.item() <= 100:
347+
var = torch.ones_like(var)
343348
chw = chw / (torch.sqrt(2 * var) + self.eps)
344349
elif self.balanced_weighting and self.training:
345350
_, mean = self.get_running_stats()
346351
if self.num_batches_tracked.item() <= 100:
347352
mean = torch.ones_like(mean)
348-
chw = chw / mean
353+
chw = chw / (mean + self.eps)
349354

350355
if self.randomized_loss_weights:
351356
rmask = torch.zeros_like(chw)

makani/utils/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .amse_loss import SpectralAMSELoss
2020
from .hydrostatic_loss import HydrostaticBalanceLoss
2121
from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss
22+
from .crps_loss import EnergyScoreLoss
2223
from .mmd_loss import EnsembleMMDLoss
2324
from .likelihood_loss import EnsembleNLLLoss
2425
from .drift_regularization import DriftRegularization

makani/utils/losses/base_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
4545
elif channel_weight_type == "auto":
4646

4747
for c, chn in enumerate(channel_names):
48-
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
48+
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
4949
channel_weights[c] = 0.1
5050
elif chn in ["t2m", "2d"]:
5151
channel_weights[c] = 1.0
@@ -58,7 +58,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
5858
elif channel_weight_type == "new auto":
5959

6060
for c, chn in enumerate(channel_names):
61-
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
61+
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
6262
channel_weights[c] = 0.1
6363
elif chn in ["t2m", "2d"]:
6464
channel_weights[c] = 2.0
@@ -71,7 +71,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
7171
elif channel_weight_type == "new auto 2":
7272

7373
for c, chn in enumerate(channel_names):
74-
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
74+
if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
7575
channel_weights[c] = 0.1
7676
elif chn in ["t2m", "2d"]:
7777
channel_weights[c] = 2.0

0 commit comments

Comments
 (0)