@@ -318,7 +318,7 @@ def test_patch_embed_2d_shape(self):
318318 in_chans = 3
319319 embed_dim = 768
320320
321- patch_embed = PatchEmbed (img_size , patch_size , in_chans , embed_dim )
321+ patch_embed = PatchEmbed (patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dim )
322322 x = mx .random .normal ((2 , in_chans , img_size , img_size ))
323323
324324 output = patch_embed (x )
@@ -340,8 +340,9 @@ def test_patch_embed_3d_shape(self):
340340
341341 num_frames = 8
342342
343- patch_embed = PatchEmbed3D (img_size , patch_size , tubelet_size , in_chans , embed_dim )
344- x = mx .random .normal ((1 , in_chans , num_frames , img_size , img_size ))
343+ patch_embed = PatchEmbed3D (patch_size = patch_size , tubelet_size = tubelet_size , in_chans = in_chans , embed_dim = embed_dim )
344+ # PatchEmbed3D expects (B, T, H, W, C) format
345+ x = mx .random .normal ((1 , num_frames , img_size , img_size , in_chans ))
345346
346347 output = patch_embed (x )
347348 mx .eval (output )
@@ -365,7 +366,6 @@ def test_1d_sincos_shape(self):
365366 length = 100
366367
367368 pos_embed = get_1d_sincos_pos_embed (embed_dim , length )
368- mx .eval (pos_embed )
369369
370370 assert pos_embed .shape == (length , embed_dim )
371371
@@ -374,8 +374,7 @@ def test_2d_sincos_shape(self):
374374 embed_dim = 768
375375 grid_size = 14
376376
377- pos_embed = get_2d_sincos_pos_embed (embed_dim , grid_size , grid_size )
378- mx .eval (pos_embed )
377+ pos_embed = get_2d_sincos_pos_embed (embed_dim , grid_size )
379378
380379 num_patches = grid_size * grid_size
381380 assert pos_embed .shape == (num_patches , embed_dim )
@@ -387,10 +386,9 @@ def test_1d_sincos_different_lengths(self):
387386 pos_embed_100 = get_1d_sincos_pos_embed (embed_dim , 100 )
388387 pos_embed_50 = get_1d_sincos_pos_embed (embed_dim , 50 )
389388
390- mx .eval (pos_embed_100 , pos_embed_50 )
391-
392- # First 50 positions should be similar
393- assert mx .allclose (pos_embed_100 [:50 ], pos_embed_50 , atol = 1e-5 )
389+ # First 50 positions should be similar (these are numpy arrays)
390+ import numpy as np
391+ assert np .allclose (pos_embed_100 [:50 ], pos_embed_50 , atol = 1e-5 )
394392
395393
396394if __name__ == "__main__" :
0 commit comments