Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 77 additions & 72 deletions lib/levanter/src/levanter/grug/_moe/ep_deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,34 @@ def _pack_deepep_local_assignments(
local_experts: int,
num_recv_tokens: Int[Array, ""],
) -> DeepEPLocalAssignments:
max_recv_tokens, topk = recv_topk_idx.shape
total_assignments = max_recv_tokens * topk

recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk)
expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32)
recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens
local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts)
local_mask_flat = local_mask.reshape(-1)
local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts)
local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1]
total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32)

flat_positions = jnp.arange(total_assignments, dtype=jnp.int32)
order_key = local_bucket * total_assignments + flat_positions
max_order_key = (local_experts + 1) * total_assignments
selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1)
_, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments)

recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0)
x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0)
assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype(
recv_x.dtype
)
valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid
x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0)
assignment_weights = jnp.where(valid_sorted, assignment_weights, 0)
return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes)
with jax.named_scope("deepep_pack_local_assignments"):
max_recv_tokens, topk = recv_topk_idx.shape
total_assignments = max_recv_tokens * topk

recv_token_indices = jnp.repeat(jnp.arange(max_recv_tokens, dtype=jnp.int32), topk)
expert_flat = recv_topk_idx.reshape(-1).astype(jnp.int32)
recv_valid = jnp.arange(max_recv_tokens, dtype=jnp.int32) < num_recv_tokens
local_mask = recv_valid[:, None] & (recv_topk_idx >= 0) & (recv_topk_idx < local_experts)
local_mask_flat = local_mask.reshape(-1)
local_bucket = jnp.where(local_mask_flat, expert_flat, local_experts)
local_group_sizes = jnp.bincount(local_bucket, length=local_experts + 1).astype(jnp.int32)[:-1]
total_valid = jnp.sum(local_group_sizes, dtype=jnp.int32)

flat_positions = jnp.arange(total_assignments, dtype=jnp.int32)
order_key = local_bucket * total_assignments + flat_positions
max_order_key = (local_experts + 1) * total_assignments
selection_key = jnp.where(local_mask_flat, max_order_key - order_key, -1)
_, sorted_assignment_indices = jax.lax.top_k(selection_key, total_assignments)

recv_token_indices = jnp.take(recv_token_indices, sorted_assignment_indices, axis=0)
x_dispatch = jnp.take(recv_x, recv_token_indices, axis=0)
assignment_weights = jnp.take(recv_topk_weights.reshape(-1), sorted_assignment_indices, axis=0).astype(
recv_x.dtype
)
valid_sorted = jnp.arange(total_assignments, dtype=jnp.int32) < total_valid
x_dispatch = jnp.where(valid_sorted[:, None], x_dispatch, 0)
assignment_weights = jnp.where(valid_sorted, assignment_weights, 0)
return DeepEPLocalAssignments(x_dispatch, assignment_weights, recv_token_indices, local_group_sizes)


def _collapse_deepep_local_assignments(
Expand All @@ -85,14 +86,15 @@ def _collapse_deepep_local_assignments(
recv_capacity: int,
num_recv_tokens: Int[Array, ""],
) -> Float[Array, "TR D"]:
recv_out = jax.ops.segment_sum(
out_dispatch * assignment_weights[:, None],
recv_token_indices,
num_segments=recv_capacity,
indices_are_sorted=False,
)
recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens
return jnp.where(recv_valid[:, None], recv_out, 0)
with jax.named_scope("deepep_collapse_local_assignments"):
recv_out = jax.ops.segment_sum(
out_dispatch * assignment_weights[:, None],
recv_token_indices,
num_segments=recv_capacity,
indices_are_sorted=False,
)
recv_valid = jnp.arange(recv_capacity, dtype=jnp.int32) < num_recv_tokens
return jnp.where(recv_valid[:, None], recv_out, 0)


def _moe_mlp_ep_deepep_local(
Expand Down Expand Up @@ -120,32 +122,34 @@ def _moe_mlp_ep_deepep_local(
max_recv_tokens = x_local.shape[0] * ep_size

with jax.named_scope("dispatch"):
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout(
selected_experts_local,
num_ranks=ep_size,
num_experts=num_experts,
)
(
recv_x,
recv_topk_idx,
recv_topk_weights,
recv_src_idx,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
send_head,
_local_expert_counts,
num_recv_tokens,
) = deepep_dispatch_intranode(
x_local,
selected_experts_local,
combine_weights_local,
num_tokens_per_rank,
num_tokens_per_expert,
is_token_in_rank,
num_experts=num_experts,
max_recv_tokens=max_recv_tokens,
)
with jax.named_scope("deepep_layout"):
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank = deepep_get_dispatch_layout(
selected_experts_local,
num_ranks=ep_size,
num_experts=num_experts,
)
with jax.named_scope("deepep_dispatch_transport"):
(
recv_x,
recv_topk_idx,
recv_topk_weights,
recv_src_idx,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
send_head,
_local_expert_counts,
num_recv_tokens,
) = deepep_dispatch_intranode(
x_local,
selected_experts_local,
combine_weights_local,
num_tokens_per_rank,
num_tokens_per_expert,
is_token_in_rank,
num_experts=num_experts,
max_recv_tokens=max_recv_tokens,
)
num_recv_tokens_scalar = jnp.squeeze(num_recv_tokens, axis=0)
local_assignments = _pack_deepep_local_assignments(
recv_x,
Expand Down Expand Up @@ -175,16 +179,17 @@ def _moe_mlp_ep_deepep_local(
recv_capacity=recv_x.shape[0],
num_recv_tokens=num_recv_tokens_scalar,
)
out_local, _ = deepep_combine_intranode(
recv_out,
recv_topk_weights,
recv_src_idx,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
send_head,
num_recv_tokens,
is_token_in_rank,
)
with jax.named_scope("deepep_combine_transport"):
out_local, _ = deepep_combine_intranode(
recv_out,
recv_topk_weights,
recv_src_idx,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
send_head,
num_recv_tokens,
is_token_in_rank,
)
dropped_total = jnp.array(0, dtype=jnp.int32)
return out_local.astype(x_local.dtype), dropped_total
Loading