Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions experiments/grug/moe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ z-loss only). The architecture choices are hardcoded in
others use half. Specifically, layer `i` uses the long mask iff `i % 4 == 3`.
- **Fp32 router path**: router logits cast to fp32 before top-k, softmax, and
QB statistics.
- **Expert parallelism**: `ragged_all_to_all` or ring-based via
`levanter.grug.grug_moe.moe_mlp` (default: ring). Default capacity factor 1.0.
- **Expert parallelism**: ring, plain-XLA `assigned_token`, or DeepEP-backed
assigned-token transport via `levanter.grug.grug_moe.moe_mlp` (default: ring).
Default capacity factor 1.0.

## Scaling heuristic

Expand Down
25 changes: 22 additions & 3 deletions experiments/grug/moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,7 @@ def __call__(
hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec)
hidden = self.embed_gated_norm(self.embed_norm(hidden))

segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None
short_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window // 2, segment_ids=segment_ids)
long_mask = AttentionMask(is_causal=True, sliding_window=cfg.sliding_window, segment_ids=segment_ids)
short_mask, long_mask = _model_sliding_attention_masks(mask, cfg)

moe_router_stats: list[dict[str, jax.Array]] = []
for i, block in enumerate(self.blocks):
Expand Down Expand Up @@ -583,6 +581,27 @@ def next_token_loss(
return loss


def _model_sliding_attention_masks(
mask: AttentionMask | jax.Array,
cfg: GrugModelConfig,
) -> tuple[AttentionMask, AttentionMask]:
segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None
thd_segment_metadata = mask.thd_segment_metadata if isinstance(mask, AttentionMask) else None
short_mask = AttentionMask(
is_causal=True,
sliding_window=cfg.sliding_window // 2,
segment_ids=segment_ids,
thd_segment_metadata=thd_segment_metadata,
)
long_mask = AttentionMask(
is_causal=True,
sliding_window=cfg.sliding_window,
segment_ids=segment_ids,
thd_segment_metadata=thd_segment_metadata,
)
return short_mask, long_mask


def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]:
return std * random.truncated_normal(key, -3, 3, shape)

Expand Down
25 changes: 25 additions & 0 deletions experiments/grug/moe/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

import jax.numpy as jnp
from levanter.grug.attention import AttentionMask, ThdSegmentMetadata

from experiments.grug.moe.model import GrugModelConfig, _model_sliding_attention_masks


def test_model_sliding_attention_masks_preserve_thd_metadata():
cfg = GrugModelConfig(vocab_size=16, hidden_dim=8, num_layers=2, num_heads=2, num_kv_heads=1, sliding_window=8)
metadata = ThdSegmentMetadata(
segment_lengths=jnp.array([[4, 4]], dtype=jnp.int32),
num_segments=jnp.array([2], dtype=jnp.int32),
)
mask = AttentionMask(is_causal=True, thd_segment_metadata=metadata)

short_mask, long_mask = _model_sliding_attention_masks(mask, cfg)

assert short_mask.sliding_window == 4
assert long_mask.sliding_window == 8
assert short_mask.thd_segment_metadata is metadata
assert long_mask.thd_segment_metadata is metadata
assert short_mask.segment_ids is None
assert long_mask.segment_ids is None
1 change: 1 addition & 0 deletions lib/levanter/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,6 @@ filterwarnings = [
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"torch: mark tests that use Torch (deselect with '-m \"not torch\"')",
"timeout: override the default per-test timeout",
]
asyncio_default_fixture_loop_scope = "function"
Loading
Loading