#!/usr/bin/env python3
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
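"""Numerical parity check between the "ring" and "ragged_all_to_all" MoE
implementations of levanter's grug ``moe_mlp``, run under an explicit
(data, expert, model) mesh. Compares losses, outputs, gradients, and
dropped-token counts for a normal and a capacity-overflow routing case,
and prints the results as JSON from process 0."""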
from __future__ import annotations

import json

import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P

from iris.runtime.jax_init import initialize_jax
from levanter.grug.grug_moe import moe_mlp
from levanter.utils.activation import ActivationFunctionEnum


def _make_ep_mesh() -> Mesh:
    """Build a (data, expert, model) mesh with two-way expert parallelism."""
    devices = jax.devices()
    if len(devices) < 2 or len(devices) % 2 != 0:
        raise RuntimeError(f"Need an even number of devices >= 2, got {len(devices)}")
    # Fix expert=2 and model=1; the remaining devices form the data axis.
    mesh_devices = np.array(devices).reshape(len(devices) // 2, 2, 1)
    return Mesh(
        mesh_devices,
        axis_names=("data", "expert", "model"),
        axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit),
    )


def _make_inputs(
    *,
    key: jax.Array,
    tokens: int,
    hidden_dim: int,
    intermediate_dim: int,
    num_experts: int,
    topk: int,
    overflow: bool,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Sample activations, routing decisions, and expert weights.

    With ``overflow=True``, every token is routed to expert 0 with uniform
    combine weights, which forces capacity overflow (and dropped tokens)
    under a finite capacity factor.
    """
    k_x, k_sel, k_logits, k_w13, k_w2 = jax.random.split(key, 5)
    x = jax.random.normal(k_x, (tokens, hidden_dim), dtype=jnp.float32)
    if overflow:
        selected_experts = jnp.zeros((tokens, topk), dtype=jnp.int32)
        combine_weights = jnp.full((tokens, topk), 1.0 / topk, dtype=jnp.float32)
    else:
        selected_experts = jax.random.randint(k_sel, (tokens, topk), 0, num_experts, dtype=jnp.int32)
        combine_logits = jax.random.normal(k_logits, (tokens, topk), dtype=jnp.float32)
        combine_weights = jax.nn.softmax(combine_logits, axis=-1)
    w_up_gate = jax.random.normal(k_w13, (num_experts, hidden_dim, 2 * intermediate_dim), dtype=jnp.float32)
    w_down = jax.random.normal(k_w2, (num_experts, intermediate_dim, hidden_dim), dtype=jnp.float32)
    return x, selected_experts, combine_weights, w_up_gate, w_down


def _tree_diff_stats(a, b) -> dict[str, float]:
    """Elementwise difference statistics between two pytrees, with ``b`` as reference."""
    leaves_a = jax.tree.leaves(a)
    leaves_b = jax.tree.leaves(b)
    max_abs = 0.0
    max_rel = 0.0
    l2_sq = 0.0
    ref_l2_sq = 0.0
    for xa, xb in zip(leaves_a, leaves_b, strict=True):
        da = np.asarray(xa)
        db = np.asarray(xb)
        diff = np.abs(da - db)
        max_abs = max(max_abs, float(diff.max(initial=0.0)))
        denom = np.maximum(np.abs(db), 1e-12)
        max_rel = max(max_rel, float((diff / denom).max(initial=0.0)))
        l2_sq += float(np.sum((da - db) ** 2))
        ref_l2_sq += float(np.sum(db**2))
    return {
        "max_abs": max_abs,
        "max_rel": max_rel,
        "l2": l2_sq**0.5,
        "ref_l2": ref_l2_sq**0.5,
        "rel_l2": (l2_sq**0.5) / max(ref_l2_sq**0.5, 1e-12),
    }


def _host_array(x: jax.Array) -> np.ndarray:
    """Gather a (possibly sharded) array onto every host as a plain numpy array."""
    if jax.process_count() > 1 and getattr(x, "ndim", 0) > 0:
        x = multihost_utils.process_allgather(x, tiled=True)
    return np.asarray(x)


def _host_scalar(x: jax.Array) -> float:
    return float(np.asarray(x))


def _run_case(mesh: Mesh, *, overflow: bool) -> dict[str, object]:
    hidden_dim = 128
    intermediate_dim = 256
    num_experts = 8
    topk = 4
    # Scale the token count with the device count so every shard stays non-empty.
    tokens = max(len(jax.devices()) * 16, 64)

    with jax.set_mesh(mesh):
        x, selected_experts, combine_weights, w_up_gate, w_down = _make_inputs(
            key=jax.random.key(17 if overflow else 7),
            tokens=tokens,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            num_experts=num_experts,
            topk=topk,
            overflow=overflow,
        )

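        # Token-major arrays are sharded across the flattened ("data", "expert")
        # device axes; the per-expert weight tensors are sharded along "expert" only.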
        batch_sharding = NamedSharding(mesh, P(("data", "expert"), None))
        expert_sharding = NamedSharding(mesh, P("expert", None, None))
        x = jax.sharding.reshard(x, batch_sharding)
        selected_experts = jax.sharding.reshard(selected_experts, batch_sharding)
        combine_weights = jax.sharding.reshard(combine_weights, batch_sharding)
        w_up_gate = jax.sharding.reshard(w_up_gate, expert_sharding)
        w_down = jax.sharding.reshard(w_down, expert_sharding)

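        # Each implementation runs under jit with value_and_grad, differentiating
        # a scalar mean-square loss w.r.t. the activations and both expert weight
        # tensors (argnums 0, 3, 4).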
        def run_impl(implementation: str):
            def loss_and_drop(
                x_arg,
                selected_experts_arg,
                combine_weights_arg,
                w_up_gate_arg,
                w_down_arg,
            ):
                out, dropped = moe_mlp(
                    x_arg,
                    selected_experts_arg,
                    combine_weights_arg,
                    w_up_gate_arg,
                    w_down_arg,
                    activation=ActivationFunctionEnum.silu,
                    implementation=implementation,
                    mesh=None,  # rely on the ambient mesh set above
                    report_capacity_overflow=True,
                    capacity_factor=1.0,
                )
                loss = jnp.mean(out.astype(jnp.float32) ** 2)
                return loss, (out, dropped)

            fn = jax.jit(jax.value_and_grad(loss_and_drop, has_aux=True, argnums=(0, 3, 4)))
            (loss, (out, dropped)), grads = fn(x, selected_experts, combine_weights, w_up_gate, w_down)
            return loss, out, dropped, grads

        ring_loss, ring_out, ring_dropped, ring_grads = run_impl("ring")
        ragged_loss, ragged_out, ragged_dropped, ragged_grads = run_impl("ragged_all_to_all")

    # Pull everything to the host so the two runs can be compared and serialized.
    ring_loss = _host_scalar(ring_loss)
    ragged_loss = _host_scalar(ragged_loss)
    ring_out_np = _host_array(ring_out)
    ragged_out_np = _host_array(ragged_out)
    ring_grad_x = _host_array(ring_grads[0])
    ragged_grad_x = _host_array(ragged_grads[0])
    ring_grad_w_up_gate = _host_array(ring_grads[1])
    ragged_grad_w_up_gate = _host_array(ragged_grads[1])
    ring_grad_w_down = _host_array(ring_grads[2])
    ragged_grad_w_down = _host_array(ragged_grads[2])

    return {
        "overflow": overflow,
        "tokens": tokens,
        "num_devices": len(jax.devices()),
        "num_processes": jax.process_count(),
        "ring_loss": ring_loss,
        "ragged_loss": ragged_loss,
        "loss_delta": ring_loss - ragged_loss,
        "ring_dropped": int(np.asarray(ring_dropped)),
        "ragged_dropped": int(np.asarray(ragged_dropped)),
        "output_diff": _tree_diff_stats(ring_out_np, ragged_out_np),
        "grad_x_diff": _tree_diff_stats(ring_grad_x, ragged_grad_x),
        "grad_w_up_gate_diff": _tree_diff_stats(ring_grad_w_up_gate, ragged_grad_w_up_gate),
        "grad_w_down_diff": _tree_diff_stats(ring_grad_w_down, ragged_grad_w_down),
    }


def main() -> None:
    initialize_jax()
    mesh = _make_ep_mesh()
    normal = _run_case(mesh, overflow=False)
    overflow = _run_case(mesh, overflow=True)
    if jax.process_index() == 0:
        print(json.dumps({"normal": normal, "overflow": overflow}, indent=2, sort_keys=True), flush=True)


if __name__ == "__main__":
    main()