@@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None):
452452 x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
453453 mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
454454 """
455-
455+
456456 B_ , nW_ , N , C = x .shape
457457 qkv = (
458458 self .qkv (x )
@@ -478,18 +478,18 @@ def forward(self, x: torch.Tensor, mask=None):
478478 attn = self .attn_drop_fn (attn )
479479
480480 x = self .apply_attention (attn , v , B_ , nW_ , N , C )
481-
481+
482482 else :
483483 if mask is not None :
484484 bias = mask .unsqueeze (1 ).unsqueeze (0 ) + earth_position_bias .unsqueeze (0 ).unsqueeze (0 )
485485 # squeeze the bias if needed in dim 2
486486 #bias = bias.squeeze(2)
487487 else :
488488 bias = earth_position_bias .unsqueeze (0 )
489-
489+
490490 # extract batch size for q,k,v
491491 nLon = self .num_lon
492- q = q .view (B_ // nLon , nLon , q .shape [1 ], q .shape [2 ], q .shape [3 ], q .shape [4 ])
492+ q = q .view (B_ // nLon , nLon , q .shape [1 ], q .shape [2 ], q .shape [3 ], q .shape [4 ])
493493 k = k .view (B_ // nLon , nLon , k .shape [1 ], k .shape [2 ], k .shape [3 ], k .shape [4 ])
494494 v = v .view (B_ // nLon , nLon , v .shape [1 ], v .shape [2 ], v .shape [3 ], v .shape [4 ])
495495 ####
@@ -736,7 +736,7 @@ class Pangu(nn.Module):
736736 - https://arxiv.org/abs/2211.02556
737737 """
738738
739- def __init__ (self ,
739+ def __init__ (self ,
740740 inp_shape = (721 ,1440 ),
741741 out_shape = (721 ,1440 ),
742742 grid_in = "equiangular" ,
@@ -773,14 +773,14 @@ def __init__(self,
773773 self .checkpointing_level = checkpointing_level
774774
775775 drop_path = np .linspace (0 , drop_path_rate , 8 ).tolist ()
776-
776+
777777 # Add static channels to surface
778778 self .num_aux = len (self .aux_channel_names )
779779 N_total_surface = self .num_aux + self .num_surface
780780
781781 # compute static permutations to extract
782782 self ._precompute_channel_groups (self .channel_names , self .aux_channel_names )
783-
783+
784784 # Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches
785785 self .patchembed2d = PatchEmbed2D (
786786 img_size = self .inp_shape ,
@@ -791,7 +791,7 @@ def __init__(self,
791791 flatten = False ,
792792 norm_layer = None ,
793793 )
794-
794+
795795 self .patchembed3d = PatchEmbed3D (
796796 img_size = (num_levels , self .inp_shape [0 ], self .inp_shape [1 ]),
797797 patch_size = patch_size ,
@@ -870,7 +870,7 @@ def __init__(self,
870870 self .patchrecovery3d = PatchRecovery3D (
871871 (num_levels , self .inp_shape [0 ], self .inp_shape [1 ]), patch_size , 2 * embed_dim , num_atmospheric
872872 )
873-
873+
874874 def _precompute_channel_groups (
875875 self ,
876876 channel_names = [],
@@ -901,7 +901,7 @@ def _precompute_channel_groups(
901901
902902 def prepare_input (self , input ):
903903 """
904- Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
904+ Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
905905 and reshaping the atmospheric variables into the required format.
906906 """
907907
@@ -932,23 +932,23 @@ def prepare_output(self, output_surface, output_atmospheric):
932932 level_dict = {level : [idx for idx , value in enumerate (self .channel_names ) if value [1 :] == level ] for level in levels }
933933 reordered_ids = [idx for level in levels for idx in level_dict [level ]]
934934 check_reorder = [f'{ level } _{ idx } ' for level in levels for idx in level_dict [level ]]
935-
935+
936936 # Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!)
937937 flattened_atmospheric = output_atmospheric .reshape (output_atmospheric .shape [0 ], - 1 , output_atmospheric .shape [3 ], output_atmospheric .shape [4 ])
938938 reordered_atmospheric = torch .cat ([torch .zeros_like (output_surface ), torch .zeros_like (flattened_atmospheric )], dim = 1 )
939939 for i in range (len (reordered_ids )):
940940 reordered_atmospheric [:, reordered_ids [i ], :, :] = flattened_atmospheric [:, i , :, :]
941-
941+
942942 # Append the surface output, this has not been reordered.
943943 if output_surface is not None :
944- _ , surf_chans , _ , _ = features .get_channel_groups (self .channel_names , self .aux_channel_names )
944+ _ , surf_chans , _ , _ , _ = features .get_channel_groups (self .channel_names , self .aux_channel_names )
945945 reordered_atmospheric [:, surf_chans , :, :] = output_surface
946946 output = reordered_atmospheric
947947 else :
948948 output = reordered_atmospheric
949949
950950 return output
951-
951+
952952 def forward (self , input ):
953953
954954 # Prep the input by splitting into surface and atmospheric variables
@@ -959,7 +959,7 @@ def forward(self, input):
959959 surface = checkpoint (self .patchembed2d , surface_aux , use_reentrant = False )
960960 atmospheric = checkpoint (self .patchembed3d , atmospheric , use_reentrant = False )
961961 else :
962- surface = self .patchembed2d (surface_aux )
962+ surface = self .patchembed2d (surface_aux )
963963 atmospheric = self .patchembed3d (atmospheric )
964964
965965 if surface .shape [1 ] == 0 :
@@ -1011,11 +1011,5 @@ def forward(self, input):
10111011 output_atmospheric = self .patchrecovery3d (output_atmospheric )
10121012
10131013 output = self .prepare_output (output_surface , output_atmospheric )
1014-
1015- return output
1016-
1017-
1018-
10191014
1020-
1021-
1015+ return output
0 commit comments