|
| 1 | +# Copyright The Levanter Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Custom-VJP down/gather implementation for local Grug MoE.""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import jax |
| 9 | +import jax.numpy as jnp |
| 10 | +from haliax.nn.ragged_dot import ragged_dot |
| 11 | +from levanter.grug.grug_moe import _gather_sum_reference |
| 12 | + |
| 13 | + |
| 14 | +def _custom_vjp_down_bwd( |
| 15 | + dout: jax.Array, |
| 16 | + h_interleaved: jax.Array, |
| 17 | + w_down: jax.Array, |
| 18 | + combine_weights: jax.Array, |
| 19 | + token_ids_sort: jax.Array, |
| 20 | + sorted_assignment_ids: jax.Array, |
| 21 | + expert_frequency_offset: jax.Array, |
| 22 | + dispatch_output: jax.Array, |
| 23 | +) -> tuple[jax.Array, jax.Array, jax.Array]: |
| 24 | + group_sizes = jnp.diff(expert_frequency_offset) |
| 25 | + assignments = h_interleaved.shape[0] |
| 26 | + sorted_scores = combine_weights.reshape(assignments)[sorted_assignment_ids].astype(jnp.float32) |
| 27 | + dout_sorted = dout[token_ids_sort] |
| 28 | + |
| 29 | + def activation_forward(h: jax.Array) -> jax.Array: |
| 30 | + gate = h[:, 0::2] |
| 31 | + up = h[:, 1::2] |
| 32 | + return jax.nn.silu(gate) * up |
| 33 | + |
| 34 | + hidden, activation_pullback = jax.vjp(activation_forward, h_interleaved) |
| 35 | + weighted_dout = (dout_sorted.astype(jnp.float32) * sorted_scores[:, None]).astype(dispatch_output.dtype) |
| 36 | + _, down_pullback = jax.vjp(lambda h, w: ragged_dot(h, w, group_sizes), hidden, w_down) |
| 37 | + d_hidden, d_w_down = down_pullback(weighted_dout) |
| 38 | + (d_h_interleaved,) = activation_pullback(d_hidden) |
| 39 | + d_scores_sorted = jnp.sum(dout_sorted.astype(jnp.float32) * dispatch_output.astype(jnp.float32), axis=-1) |
| 40 | + d_scores = jnp.zeros_like(sorted_scores).at[sorted_assignment_ids].set(d_scores_sorted) |
| 41 | + return d_h_interleaved, d_scores, d_w_down |
| 42 | + |
| 43 | + |
| 44 | +@jax.custom_vjp |
| 45 | +def custom_vjp_interleaved_down_gather_sum( |
| 46 | + w13_out_interleaved: jax.Array, |
| 47 | + combine_weights: jax.Array, |
| 48 | + w_down: jax.Array, |
| 49 | + token_ids_sort: jax.Array, |
| 50 | + sorted_assignment_ids: jax.Array, |
| 51 | + dispatch_positions: jax.Array, |
| 52 | + group_sizes: jax.Array, |
| 53 | +) -> jax.Array: |
| 54 | + out, _ = _custom_vjp_interleaved_down_gather_sum_forward( |
| 55 | + w13_out_interleaved, |
| 56 | + combine_weights, |
| 57 | + w_down, |
| 58 | + token_ids_sort, |
| 59 | + sorted_assignment_ids, |
| 60 | + dispatch_positions, |
| 61 | + group_sizes, |
| 62 | + ) |
| 63 | + return out |
| 64 | + |
| 65 | + |
| 66 | +def _custom_vjp_interleaved_down_gather_sum_forward( |
| 67 | + w13_out_interleaved: jax.Array, |
| 68 | + combine_weights: jax.Array, |
| 69 | + w_down: jax.Array, |
| 70 | + token_ids_sort: jax.Array, |
| 71 | + sorted_assignment_ids: jax.Array, |
| 72 | + dispatch_positions: jax.Array, |
| 73 | + group_sizes: jax.Array, |
| 74 | +) -> tuple[jax.Array, tuple[jax.Array, ...]]: |
| 75 | + del sorted_assignment_ids |
| 76 | + hidden = jax.nn.silu(w13_out_interleaved[:, 0::2]) * w13_out_interleaved[:, 1::2] |
| 77 | + dispatch_output = ragged_dot(hidden, w_down, group_sizes) |
| 78 | + out = _gather_sum_reference(dispatch_output, dispatch_positions, combine_weights) |
| 79 | + expert_frequency_offset = jnp.concatenate( |
| 80 | + [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(group_sizes, dtype=jnp.int32)] |
| 81 | + ) |
| 82 | + return out, ( |
| 83 | + w13_out_interleaved, |
| 84 | + combine_weights, |
| 85 | + w_down, |
| 86 | + token_ids_sort, |
| 87 | + expert_frequency_offset, |
| 88 | + dispatch_output, |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +def _custom_vjp_interleaved_down_gather_sum_fwd( |
| 93 | + w13_out_interleaved: jax.Array, |
| 94 | + combine_weights: jax.Array, |
| 95 | + w_down: jax.Array, |
| 96 | + token_ids_sort: jax.Array, |
| 97 | + sorted_assignment_ids: jax.Array, |
| 98 | + dispatch_positions: jax.Array, |
| 99 | + group_sizes: jax.Array, |
| 100 | +) -> tuple[jax.Array, tuple[jax.Array, ...]]: |
| 101 | + out, residuals = _custom_vjp_interleaved_down_gather_sum_forward( |
| 102 | + w13_out_interleaved, |
| 103 | + combine_weights, |
| 104 | + w_down, |
| 105 | + token_ids_sort, |
| 106 | + sorted_assignment_ids, |
| 107 | + dispatch_positions, |
| 108 | + group_sizes, |
| 109 | + ) |
| 110 | + return out, (*residuals, sorted_assignment_ids) |
| 111 | + |
| 112 | + |
| 113 | +def _custom_vjp_interleaved_down_gather_sum_bwd( |
| 114 | + residuals: tuple[jax.Array, ...], |
| 115 | + dout: jax.Array, |
| 116 | +) -> tuple[jax.Array, jax.Array, jax.Array, None, None, None, None]: |
| 117 | + ( |
| 118 | + w13_out_interleaved, |
| 119 | + combine_weights, |
| 120 | + w_down, |
| 121 | + token_ids_sort, |
| 122 | + expert_frequency_offset, |
| 123 | + dispatch_output, |
| 124 | + sorted_assignment_ids, |
| 125 | + ) = residuals |
| 126 | + d_h_interleaved, d_scores_flat, d_w_down = _custom_vjp_down_bwd( |
| 127 | + dout, |
| 128 | + w13_out_interleaved, |
| 129 | + w_down, |
| 130 | + combine_weights, |
| 131 | + token_ids_sort, |
| 132 | + sorted_assignment_ids, |
| 133 | + expert_frequency_offset, |
| 134 | + dispatch_output, |
| 135 | + ) |
| 136 | + d_combine_weights = d_scores_flat.reshape(combine_weights.shape).astype(combine_weights.dtype) |
| 137 | + return d_h_interleaved, d_combine_weights, d_w_down.astype(w_down.dtype), None, None, None, None |
| 138 | + |
| 139 | + |
| 140 | +custom_vjp_interleaved_down_gather_sum.defvjp( |
| 141 | + _custom_vjp_interleaved_down_gather_sum_fwd, |
| 142 | + _custom_vjp_interleaved_down_gather_sum_bwd, |
| 143 | +) |
0 commit comments