Commit d679358 ("simpler")
1 parent 51f30e8

1 file changed: experiments/grug/moe/model.py
Lines changed: 19 additions & 65 deletions
@@ -263,9 +263,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP":
     def __call__(
         self,
         x: Float[Array, "B S D"],
-        *,
-        return_router_stats: bool = False,
-    ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
+    ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
         b, s, _ = x.shape
         x_flat = rearrange(x, "b s d -> (b s) d")
         router_logits = jnp.einsum("td,de->te", x_flat, reshard(self.router, P(None, None)))
@@ -293,11 +291,7 @@ def __call__(
         )
         out = routed + shared_out

-        if return_router_stats:
-            assert router_stats is not None
-            return out, router_stats
-
-        return out
+        return out, router_stats

     def _has_shared(self):
         if self.shared_w_up_gate is not None or self.shared_w_down is not None:
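With the `return_router_stats` flag gone, `MoEMLP.__call__` has a single return type: every call yields the `(output, router_stats)` pair, and callers discard the stats when they don't need them. A minimal runnable sketch of that contract, with a hypothetical toy layer standing in for the MoE (the stat value is a placeholder):

```python
import jax.numpy as jnp

# Toy stand-in for the flag-free contract: always return (out, stats).
def toy_layer(x):
    stats = {"routing_entropy": jnp.zeros(())}  # placeholder stat
    return x * 2.0, stats

x = jnp.ones((2, 3, 4))           # (B, S, D)
out, router_stats = toy_layer(x)  # callers always unpack the pair
out, _ = toy_layer(x)             # ...and drop the stats when unneeded
```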
@@ -327,17 +321,11 @@ def __call__(
         self,
         x: Float[Array, "B S D"],
         mask: AttentionMask | jax.Array,
-        *,
-        return_router_stats: bool = False,
-    ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
+    ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
         x = x + self.attn(self.rms_attn(x), mask)
-        if return_router_stats:
-            mlp_out, router_stats = self.mlp(self.rms_mlp(x), return_router_stats=True)
-            x = x + mlp_out
-            return x, router_stats
-
-        x = x + self.mlp(self.rms_mlp(x))
-        return x
+        mlp_out, router_stats = self.mlp(self.rms_mlp(x))
+        x = x + mlp_out
+        return x, router_stats


 class Transformer(eqx.Module):
@@ -368,29 +356,22 @@ def __call__(
         self,
         token_ids: Int[Array, "B S"],
         mask: AttentionMask | jax.Array | None = None,
-        *,
-        return_router_stats: bool = False,
-    ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
+    ) -> tuple[Float[Array, "B S D"], dict[str, jax.Array]]:
         if mask is None:
             mask = AttentionMask.causal()

         batch_spec = _batch_spec()
         hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec)
-        if return_router_stats:
-            all_router_stats: list[dict[str, jax.Array]] = []
-            for block in self.blocks:
-                hidden, router_stats = eqx.filter_checkpoint(block)(hidden, mask, return_router_stats=True)
-                all_router_stats.append(router_stats)
-
-            router_metrics = {
-                "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0),
-                "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
-            }
-            return self.final_norm(hidden), router_metrics
-
+        all_router_stats: list[dict[str, jax.Array]] = []
         for block in self.blocks:
-            hidden = eqx.filter_checkpoint(block)(hidden, mask)
-        return self.final_norm(hidden)
+            hidden, router_stats = eqx.filter_checkpoint(block)(hidden, mask)
+            all_router_stats.append(router_stats)
+
+        router_metrics = {
+            "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0),
+            "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
+        }
+        return self.final_norm(hidden), router_metrics

     @named_call
     def logits(
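`Transformer.__call__` now collects the per-layer stats on every forward pass and stacks them along a leading layer axis. A minimal runnable sketch of that stacking pattern, using synthetic stats in place of what each block returns (the shapes are illustrative assumptions: a scalar entropy per layer, one count per expert, and `num_experts = 4` is made up):

```python
import jax.numpy as jnp

num_layers, num_experts = 3, 4

# Synthetic per-layer stats, mimicking what each block would return.
all_router_stats = [
    {"routing_entropy": jnp.asarray(0.5 * i), "routing_counts": jnp.ones(num_experts)}
    for i in range(num_layers)
]

# Stack each stat across layers, as in the hunk above.
router_metrics = {
    "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0),
    "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
}
print(router_metrics["routing_entropy_per_layer"].shape)  # (3,)
print(router_metrics["routing_counts_per_layer"].shape)   # (3, 4)
```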
@@ -399,10 +380,10 @@ def logits(
         mask: AttentionMask | jax.Array | None = None,
     ) -> Float[Array, "B S V"]:
         batch_spec = _batch_spec()
-        hidden = self(token_ids, mask=mask)
+        hidden, _ = self(token_ids, mask=mask)
         return jnp.einsum("bsh,hd->bsd", hidden, self.output_proj, out_sharding=batch_spec)

-    def compute_loss(
+    def next_token_loss(
         self,
         token_ids: Int[Array, "B S"],
         loss_weight: Float[Array, "B S"],
@@ -413,12 +394,7 @@ def compute_loss(
         loss_dtype: jnp.dtype = jnp.float32,
         return_router_metrics: bool = False,
     ) -> jax.Array | tuple[jax.Array, dict[str, jax.Array | Histogram]]:
-        """Compute next-token cross-entropy loss for a batch."""
-        router_metrics: dict[str, jax.Array] | None = None
-        if return_router_metrics:
-            hidden, router_metrics = self(token_ids, mask=mask, return_router_stats=True)
-        else:
-            hidden = self(token_ids, mask=mask)
+        hidden, router_metrics = self(token_ids, mask=mask)
         labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
         loss_weight = loss_weight.astype(loss_dtype)

@@ -432,31 +408,9 @@ def compute_loss(
             dtype=loss_dtype,
         )
         if return_router_metrics:
-            assert router_metrics is not None
             return loss, _summarize_router_metrics(router_metrics)
         return loss

-    def next_token_loss(
-        self,
-        token_ids: Int[Array, "B S"],
-        loss_weight: Float[Array, "B S"],
-        *,
-        mask: AttentionMask | jax.Array | None = None,
-        reduction: str = "mean",
-        logsumexp_weight: float | None = None,
-        loss_dtype: jnp.dtype = jnp.float32,
-        return_router_metrics: bool = False,
-    ) -> jax.Array | tuple[jax.Array, dict[str, jax.Array | Histogram]]:
-        return self.compute_loss(
-            token_ids,
-            loss_weight,
-            mask=mask,
-            reduction=reduction,
-            logsumexp_weight=logsumexp_weight,
-            loss_dtype=loss_dtype,
-            return_router_metrics=return_router_metrics,
-        )
-

 def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]:
     return std * random.truncated_normal(key, -3, 3, shape)
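Taken together, the external surface after this commit looks roughly like the sketch below. This is a hypothetical call site, not code from the repo: `model`, `token_ids`, and `loss_weight` are assumed to exist, and the names mirror the signatures in the hunks above.

```python
# Forward pass: router metrics are now always returned alongside hidden states.
hidden, router_metrics = model(token_ids)

# logits() unpacks and discards the metrics internally.
logits = model.logits(token_ids)

# The old compute_loss/next_token_loss pair collapses into one method; the
# flag only controls whether the summarized router metrics are also returned.
loss = model.next_token_loss(token_ids, loss_weight)
loss, metrics = model.next_token_loss(token_ids, loss_weight, return_router_metrics=True)
```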
