
Commit fef815f

Author: maxtext authors

Merge pull request #1733 from AI-Hypercomputer:zhaoyuec-add-llama4-multimodal-projector

PiperOrigin-RevId: 765268765

2 parents: c7868c9 + ff64c45

File tree: 2 files changed, +127 −0 lines

MaxText/layers/llama4.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -234,6 +234,39 @@ def __call__(self, encoded_patches: Array, deterministic: bool = False) -> Array
     return result
 
 
+class Llama4MultiModalProjector(nn.Module):
+  """Implementation of Llama4MultiModalProjector for the Llama4 multimodal model.
+
+  This module projects vision features to the text hidden dimension.
+
+  Attributes:
+    config: Config containing model parameters
+  """
+
+  config: Config
+
+  def setup(self):
+    cfg = self.config
+    self.linear = linears.DenseGeneral(
+        features=cfg.emb_dim,
+        dtype=cfg.dtype_mm,
+        name="vit_multi_modal_projector",
+        use_bias=False,
+    )
+
+  def __call__(self, image_features: Array) -> Array:
+    """Project image features to the text hidden dimension.
+
+    Args:
+      image_features: Input tensor of shape [batch_size*num_patches*(pixel_shuffle_ratio**2), vision_output_dim]
+
+    Returns:
+      Tensor of shape [batch_size*num_patches*(pixel_shuffle_ratio**2), emb_dim]
+    """
+    hidden_states = self.linear(image_features)
+    return hidden_states
+
+
 def determine_is_nope_layer(layer_id: int, nope_layer_interval: int) -> bool:
   """
   Determines whether the given layer at `layer_id` should use RoPE or not (NoPE).
```
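The projector is a single bias-free `DenseGeneral` mapping flattened vision tokens from `vision_output_dim` to the text model's `emb_dim`. As a rough sketch of how the module can be driven on its own (the `StubConfig` class and the dimensions below are illustrative stand-ins for the real MaxText config, and the import assumes a MaxText checkout is on the path):

```python
import jax
import jax.numpy as jnp

from MaxText.layers import llama4  # assumes MaxText is on PYTHONPATH


class StubConfig:
  """Illustrative stand-in exposing only the fields the projector reads."""
  emb_dim = 5120          # text hidden size (Scout-like, for illustration)
  dtype_mm = jnp.float32


projector = llama4.Llama4MultiModalProjector(StubConfig())

# 1440 flattened vision tokens of width 4096 (illustrative sizes).
image_features = jnp.ones((1440, 4096), dtype=jnp.float32)

params = projector.init(jax.random.PRNGKey(0), image_features)
out = projector.apply(params, image_features)
print(out.shape)  # (1440, 5120)
```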
```diff
@@ -553,6 +586,8 @@ def __call__(
         is_nope_layer=False,
         use_bias_in_projections=True,
         is_vision=True,
+        use_qk_norm=False,
+        query_pre_attn_scalar=1 / math.sqrt(self.config.hidden_size_for_vit // self.config.num_attention_heads_for_vit),
     )
 
     hidden_states = attention_layer(
```
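The two new arguments make the vision attention path explicit: query/key normalization stays off, and queries are pre-scaled by `1/sqrt(head_dim)`, the standard dot-product attention scaling. A quick sanity check of the arithmetic, using illustrative ViT dimensions that are not read from the commit:

```python
import math

# Illustrative ViT dimensions; the real values come from the model config.
hidden_size_for_vit = 1408
num_attention_heads_for_vit = 16

head_dim = hidden_size_for_vit // num_attention_heads_for_vit  # 88
query_pre_attn_scalar = 1 / math.sqrt(head_dim)
print(round(query_pre_attn_scalar, 4))  # 0.1066
```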

MaxText/tests/check_llama4_layers.py

Lines changed: 92 additions & 0 deletions
```diff
@@ -344,6 +344,96 @@ def __init__(self):
     np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)
 
 
+class Llama4MultiModalProjectorTest(unittest.TestCase):
+  """Test for the Llama4 Multi Modal Projector implementation."""
+
+  def __copy_weights(self, pt_model, params):
+    """Copy weights from the PyTorch model to the JAX model.
+
+    Args:
+      pt_model: PyTorch Llama4MultiModalProjector model
+      params: JAX model parameters
+    """
+    # Create new params with copied weights
+    updated_params = jax.tree_util.tree_map(lambda x: x, params)
+    updated_params["params"]["vit_multi_modal_projector"]["kernel"] = to_jax(pt_model.linear_1.weight).T
+    return updated_params
+
+  def test_multi_modal_projector(self):
+    """Test for the Llama4 Multi Modal Projector implementation."""
+    # Test parameters
+    # following config https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
+    batch_size = 10
+    num_patches = 24 * 24  # 336/14 = 24 patches per side
+    pixel_shuffle_ratio = 0.5
+    vision_output_dim = 4096
+    hidden_size = 5120
+
+    # PyTorch implementation
+    class VisionConfig:
+
+      def __init__(self):
+        self.vision_output_dim = vision_output_dim
+
+    class TextConfig:
+
+      def __init__(self):
+        self.hidden_size = hidden_size
+
+    class Config:
+
+      def __init__(self):
+        self.vision_config = VisionConfig()
+        self.text_config = TextConfig()
+
+    class Llama4MultiModalProjector(nn.Module):
+      """Llama4 Multi Modal Projector PyTorch original implementation."""
+
+      def __init__(self, config):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=False,
+        )
+
+      def forward(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        return hidden_states
+
+    # Create random input tensor
+    # Shape: [batch_size*num_patches*(pixel_shuffle_ratio**2), vision_output_dim]
+    inputs_pt = torch.randn(int(batch_size * num_patches * pixel_shuffle_ratio**2), vision_output_dim)
+
+    # Initialize PyTorch model
+    pt_model = Llama4MultiModalProjector(Config())
+    pt_model.eval()
+    pt_output = pt_model(inputs_pt)
+
+    # JAX implementation
+    class JaxConfig:
+
+      def __init__(self):
+        self.emb_dim = hidden_size
+        self.dtype_mm = jnp.float32
+
+    # Initialize JAX model
+    jax_model = llama4.Llama4MultiModalProjector(JaxConfig())
+    params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))
+
+    # Copy weights from PyTorch to JAX
+    pt_params = self.__copy_weights(pt_model, params)
+
+    # Run JAX forward pass with updated params
+    jax_output = jax_model.apply(pt_params, to_jax(inputs_pt))
+
+    # Compare shapes
+    self.assertEqual(pt_output.shape, jax_output.shape)
+
+    # Compare outputs with reasonable tolerances
+    np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
   """
   Pytorch implementation from HuggingFace:
```
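One detail worth noting in `__copy_weights`: PyTorch's `nn.Linear` stores its weight as `[out_features, in_features]`, while a Flax `DenseGeneral` kernel is laid out `[in_features, out_features]`, which is why the copied weight is transposed. With the test's Scout-like sizes, the input carries 10 × 576 × 0.25 = 1440 flattened vision tokens. A minimal sketch of the layout conversion, using the test's illustrative shapes:

```python
import torch

pt_linear = torch.nn.Linear(4096, 5120, bias=False)
print(pt_linear.weight.shape)  # torch.Size([5120, 4096]) -- [out, in]

# Flax DenseGeneral expects the kernel as [in, out], hence the transpose.
jax_kernel = pt_linear.weight.detach().numpy().T
print(jax_kernel.shape)  # (4096, 5120)
```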
```diff
@@ -544,6 +634,8 @@ def test_vision_attention(self):
         is_nope_layer=False,
         use_bias_in_projections=True,
         is_vision=True,
+        use_qk_norm=False,
+        query_pre_attn_scalar=1 / math.sqrt(self.cfg.hidden_size_for_vit // self.cfg.num_attention_heads_for_vit),
     )
 
     lnx = to_jax(hidden_states_pt)
```
