@@ -39,10 +39,6 @@ def __init__( # noqa: PLR0913
         self.dim = dim
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

-        # Required to compile & export the model
-        self.grid_size = 256 // 8
-        self.num_patches = self.grid_size**2
-
         self.patch_embedding = DynamicEmbedding(
             wave_dim=128,
             num_latent_tokens=128,
@@ -68,9 +64,8 @@ def add_encodings(self, patches, time, latlon, gsd):
6864 """Add position encoding to the patches"""
6965 B , L , D = patches .shape
7066
71- # grid_size = int(math.sqrt(L))
72- # self.num_patches = grid_size**2
73- grid_size = self .grid_size
67+ grid_size = int (math .sqrt (L ))
68+ self .num_patches = grid_size ** 2
7469
7570 pos_encoding = (
7671 posemb_sincos_2d_with_gsd (
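Note on the two hunks above: the hardcoded 32x32 grid (256 // 8) existed only to satisfy compile/export and pinned the encoder to 256-pixel inputs; `add_encodings` now derives the grid from the token count, assuming a square patch grid. A minimal sketch of the arithmetic (the patch size of 8 and the input sizes are illustrative values, not taken from this file):

    import math

    L = (256 // 8) ** 2            # 1024 tokens from a 256x256 input
    grid_size = int(math.sqrt(L))  # 32, matching the old hardcoded value
    assert grid_size**2 == L       # holds only for square patch grids

    L = (512 // 8) ** 2            # 4096 tokens from a 512x512 input
    grid_size = int(math.sqrt(L))  # 64, with no code change needed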
@@ -165,14 +160,14 @@ def mask_out(self, patches):
             masked_matrix,
         )  # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L]

-    def forward(self, cube, time, latlon, waves, gsd):
-        # cube, time, latlon, gsd, waves = (
-        #     datacube["pixels"],  # [B C H W]
-        #     datacube["time"],  # [B 2]
-        #     datacube["latlon"],  # [B 2]
-        #     datacube["gsd"],  # 1
-        #     datacube["waves"],  # [N]
-        # )  # [B C H W]
+    def forward(self, datacube):
+        cube, time, latlon, gsd, waves = (
+            datacube["pixels"],  # [B C H W]
+            datacube["time"],  # [B 2]
+            datacube["latlon"],  # [B 2]
+            datacube["gsd"],  # 1
+            datacube["waves"],  # [N]
+        )  # [B C H W]

         B, C, H, W = cube.shape

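With this hunk, `forward` takes a single `datacube` dict rather than five positional tensors (note the old positional order put waves before gsd). A hedged calling sketch, assuming the shapes annotated in the comments above; `model` and the concrete values are hypothetical:

    import torch

    B, C, H, W = 2, 4, 256, 256  # illustrative batch/band/pixel sizes
    datacube = {
        "pixels": torch.randn(B, C, H, W),  # [B C H W]
        "time": torch.zeros(B, 2),          # [B 2]
        "latlon": torch.zeros(B, 2),        # [B 2]
        "gsd": torch.tensor(10.0),          # scalar ground sample distance
        "waves": torch.zeros(C),            # [N] one wavelength per band
    }
    out = model(datacube)  # replaces model(cube, time, latlon, waves, gsd)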