Skip to content

Commit dea8ecf

Browse files
committed
Fix component test failures
- Fix PatchEmbed tests to use correct constructor arguments: - PatchEmbed(patch_size, in_chans, embed_dim) instead of (img_size, ...) - PatchEmbed3D expects (B, T, H, W, C) format input - Fix get_2d_sincos_pos_embed test: function takes (embed_dim, grid_size) - Fix allclose test: pos_embs returns numpy arrays, use np.allclose
1 parent 7ea9c2f commit dea8ecf

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

tests/test_component_outputs.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def test_patch_embed_2d_shape(self):
318318
in_chans = 3
319319
embed_dim = 768
320320

321-
patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
321+
patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
322322
x = mx.random.normal((2, in_chans, img_size, img_size))
323323

324324
output = patch_embed(x)
@@ -340,8 +340,9 @@ def test_patch_embed_3d_shape(self):
340340

341341
num_frames = 8
342342

343-
patch_embed = PatchEmbed3D(img_size, patch_size, tubelet_size, in_chans, embed_dim)
344-
x = mx.random.normal((1, in_chans, num_frames, img_size, img_size))
343+
patch_embed = PatchEmbed3D(patch_size=patch_size, tubelet_size=tubelet_size, in_chans=in_chans, embed_dim=embed_dim)
344+
# PatchEmbed3D expects (B, T, H, W, C) format
345+
x = mx.random.normal((1, num_frames, img_size, img_size, in_chans))
345346

346347
output = patch_embed(x)
347348
mx.eval(output)
@@ -365,7 +366,6 @@ def test_1d_sincos_shape(self):
365366
length = 100
366367

367368
pos_embed = get_1d_sincos_pos_embed(embed_dim, length)
368-
mx.eval(pos_embed)
369369

370370
assert pos_embed.shape == (length, embed_dim)
371371

@@ -374,8 +374,7 @@ def test_2d_sincos_shape(self):
374374
embed_dim = 768
375375
grid_size = 14
376376

377-
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, grid_size)
378-
mx.eval(pos_embed)
377+
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
379378

380379
num_patches = grid_size * grid_size
381380
assert pos_embed.shape == (num_patches, embed_dim)
@@ -387,10 +386,9 @@ def test_1d_sincos_different_lengths(self):
387386
pos_embed_100 = get_1d_sincos_pos_embed(embed_dim, 100)
388387
pos_embed_50 = get_1d_sincos_pos_embed(embed_dim, 50)
389388

390-
mx.eval(pos_embed_100, pos_embed_50)
391-
392-
# First 50 positions should be similar
393-
assert mx.allclose(pos_embed_100[:50], pos_embed_50, atol=1e-5)
389+
# First 50 positions should be similar (these are numpy arrays)
390+
import numpy as np
391+
assert np.allclose(pos_embed_100[:50], pos_embed_50, atol=1e-5)
394392

395393

396394
if __name__ == "__main__":

0 commit comments

Comments
 (0)