Skip to content

Commit fbac3cd

Browse files
committed
Revert changes to Encoder, don't change the API
1 parent c403c47 commit fbac3cd

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

src/model.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ def __init__( # noqa: PLR0913
3939
self.dim = dim
4040
self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
4141

42-
# Required to compile & export the model
43-
self.grid_size = 256 // 8
44-
self.num_patches = self.grid_size**2
45-
4642
self.patch_embedding = DynamicEmbedding(
4743
wave_dim=128,
4844
num_latent_tokens=128,
@@ -68,9 +64,8 @@ def add_encodings(self, patches, time, latlon, gsd):
6864
"""Add position encoding to the patches"""
6965
B, L, D = patches.shape
7066

71-
# grid_size = int(math.sqrt(L))
72-
# self.num_patches = grid_size**2
73-
grid_size = self.grid_size
67+
grid_size = int(math.sqrt(L))
68+
self.num_patches = grid_size**2
7469

7570
pos_encoding = (
7671
posemb_sincos_2d_with_gsd(
@@ -165,14 +160,14 @@ def mask_out(self, patches):
165160
masked_matrix,
166161
) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L]
167162

168-
def forward(self, cube, time, latlon, waves, gsd):
169-
# cube, time, latlon, gsd, waves = (
170-
# datacube["pixels"], # [B C H W]
171-
# datacube["time"], # [B 2]
172-
# datacube["latlon"], # [B 2]
173-
# datacube["gsd"], # 1
174-
# datacube["waves"], # [N]
175-
# ) # [B C H W]
163+
def forward(self, datacube):
164+
cube, time, latlon, gsd, waves = (
165+
datacube["pixels"], # [B C H W]
166+
datacube["time"], # [B 2]
167+
datacube["latlon"], # [B 2]
168+
datacube["gsd"], # 1
169+
datacube["waves"], # [N]
170+
) # [B C H W]
176171

177172
B, C, H, W = cube.shape
178173

0 commit comments

Comments
 (0)