@@ -344,6 +344,96 @@ def __init__(self):
     np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)
 
 
+class Llama4MultiModalProjectorTest(unittest.TestCase):
+  """Tests for the Llama4 multi-modal projector implementation."""
+
+  def __copy_weights(self, pt_model, params):
+    """Copy weights from PyTorch model to JAX model.
+
+    Args:
+      pt_model: PyTorch Llama4MultiModalProjector model
+      params: JAX model parameters
+    """
+    # Create new params with copied weights
+    updated_params = jax.tree_util.tree_map(lambda x: x, params)
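+    # PyTorch nn.Linear stores its weight as [out_features, in_features], while
+    # a Flax Dense kernel is [in_features, out_features], hence the .T below.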
+    updated_params["params"]["vit_multi_modal_projector"]["kernel"] = to_jax(pt_model.linear_1.weight).T
+    return updated_params
+
+  def test_multi_modal_projector(self):
+    """Compares the PyTorch reference projector with the JAX implementation."""
+    # Test parameters
+    # following config https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
+    batch_size = 10
+    num_patches = 24 * 24  # 336/14 = 24 patches per side
+    pixel_shuffle_ratio = 0.5
+    vision_output_dim = 4096
+    hidden_size = 5120
+
+    # PyTorch implementation
+    class VisionConfig:
+
+      def __init__(self):
+        self.vision_output_dim = vision_output_dim
+
+    class TextConfig:
+
+      def __init__(self):
+        self.hidden_size = hidden_size
+
+    class Config:
+
+      def __init__(self):
+        self.vision_config = VisionConfig()
+        self.text_config = TextConfig()
+
+    class Llama4MultiModalProjector(nn.Module):
+      """Original PyTorch implementation of the Llama4 multi-modal projector."""
+
+      def __init__(self, config):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=False,
+        )
+
+      def forward(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        return hidden_states
+
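+    # Note: pixel shuffle reduces the image-token count by pixel_shuffle_ratio**2
+    # (0.25 here), so int() must wrap the whole product: 10 * 576 * 0.25 = 1440 rows.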
+    # Create random input tensor
+    # Shape: [int(batch_size * num_patches * pixel_shuffle_ratio**2), vision_output_dim]
+    inputs_pt = torch.randn(int(batch_size * num_patches * pixel_shuffle_ratio**2), vision_output_dim)
+
+    # Initialize PyTorch model
+    pt_model = Llama4MultiModalProjector(Config())
+    pt_model.eval()
+    pt_output = pt_model(inputs_pt)
+
+    # JAX implementation
+    class JaxConfig:
+
+      def __init__(self):
+        self.emb_dim = hidden_size
+        self.dtype_mm = jnp.float32
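+
+    # For reference, llama4.Llama4MultiModalProjector is expected to reduce to a
+    # single bias-free flax.linen Dense layer named "vit_multi_modal_projector"
+    # mapping its input to config.emb_dim (an assumption implied by
+    # __copy_weights above; a sketch, not a copy of the module):
+    #
+    #   class Llama4MultiModalProjector(linen.Module):
+    #     config: Any
+    #
+    #     @linen.compact
+    #     def __call__(self, x):
+    #       return linen.Dense(self.config.emb_dim, use_bias=False,
+    #                          dtype=self.config.dtype_mm,
+    #                          name="vit_multi_modal_projector")(x)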
+
+    # Initialize JAX model
+    jax_model = llama4.Llama4MultiModalProjector(JaxConfig())
+    params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))
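+    # init() only builds the parameter structure; its random kernel is replaced
+    # with the PyTorch weights below.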
+
+    # Copy weights from PyTorch to JAX
+    pt_params = self.__copy_weights(pt_model, params)
+
+    # Run JAX forward pass with updated params
+    jax_output = jax_model.apply(pt_params, to_jax(inputs_pt))
+
+    # Compare shapes
+    self.assertEqual(pt_output.shape, jax_output.shape)
+
+    # Compare outputs; atol is loose to absorb float32 differences across frameworks
+    np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
   """
   Pytorch implementation from HuggingFace:
@@ -544,6 +634,8 @@ def test_vision_attention(self):
         is_nope_layer=False,
         use_bias_in_projections=True,
         is_vision=True,
+        use_qk_norm=False,
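+        # hidden_size_for_vit // num_attention_heads_for_vit is the per-head
+        # dimension, so this passes the usual 1/sqrt(head_dim) attention scale.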
+        query_pre_attn_scalar=1 / math.sqrt(self.cfg.hidden_size_for_vit // self.cfg.num_attention_heads_for_vit),
     )
 
     lnx = to_jax(hidden_states_pt)
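
The .T in __copy_weights above reflects the different weight layouts of the two frameworks; a minimal standalone check of that equivalence (illustrative only, with made-up shapes):

    import numpy as np
    import torch

    # PyTorch nn.Linear computes x @ weight.T with weight shaped [out, in];
    # a Flax Dense computes x @ kernel with kernel shaped [in, out],
    # so kernel == weight.T.
    linear = torch.nn.Linear(4, 3, bias=False)
    x = torch.randn(2, 4)
    flax_kernel = linear.weight.detach().numpy().T  # shape [4, 3]
    np.testing.assert_allclose(
        linear(x).detach().numpy(), x.numpy() @ flax_kernel, rtol=1e-5, atol=1e-6
    )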