
Commit c027bb7

Name backward flow magic constants
1 parent 4bee1eb commit c027bb7

6 files changed: 119 additions & 52 deletions


docs/design/grug-backward-flow-logging.md

Lines changed: 19 additions & 7 deletions
@@ -49,6 +49,8 @@ canonical Grug base template.
 - `trace_backward_activation(x, name, site=...)`: a convenience wrapper for
   identity-only stream anchors that adds a `jax.named_scope(name)` around
   `log_backward_activation(...)`
+- `BACKWARD_FLOW_SITE_IN` / `BACKWARD_FLOW_SITE_OUT`: named constants for the
+  metric-key site labels
 - `normalize_name_stack(...)`: removes transform wrappers such as `jvp(...)` and
   `transpose(...)` so metric keys stay stable

@@ -58,25 +60,35 @@ registry.

 ```python
 @functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
-def _tagged_identity(metric_prefix: str, site: str, x: jax.Array) -> jax.Array:
+def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> jax.Array:
     return x

-def _tagged_identity_fwd(metric_prefix: str, site: str, x: jax.Array):
-    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, x, site=site, kind="activation"), step=None)
+def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array):
+    levanter.tracker.jit_log(
+        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION),
+        step=None,
+    )
     return x, None

-def _tagged_identity_bwd(metric_prefix: str, site: str, _residual: None, cotangent: jax.Array):
-    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, cotangent, site=site, kind="gradient"), step=None)
+def _tagged_identity_bwd(metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array):
+    levanter.tracker.jit_log(
+        _tensor_metrics(metric_prefix, cotangent, site=site, kind=BACKWARD_FLOW_KIND_GRADIENT),
+        step=None,
+    )
     return (cotangent,)

-def log_backward_activation(x: jax.Array, *, site: str = "out") -> jax.Array:
+def log_backward_activation(
+    x: jax.Array, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
+) -> jax.Array:
     context = _ACTIVE_CONTEXT.get()
     if context is None:
         return x
     name_stack = normalize_name_stack(str(source_info_util.current_name_stack()))
     return _tagged_identity(f"{context.prefix}/{name_stack}", site, x)

-def trace_backward_activation(x: jax.Array, name: str, *, site: str = "out") -> jax.Array:
+def trace_backward_activation(
+    x: jax.Array, name: str, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
+) -> jax.Array:
     with jax.named_scope(name):
         return log_backward_activation(x, site=site)
 ```
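As a sanity check on the design above, here is a minimal, self-contained sketch of the same custom_vjp identity-tagging trick, with `jax.debug.print` standing in for `levanter.tracker.jit_log` and an inline RMS standing in for `_tensor_metrics`; the names in this sketch are illustrative, not part of the library.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def tagged_identity(tag: str, x: jax.Array) -> jax.Array:
    return x  # identity in the forward graph


def _fwd(tag: str, x: jax.Array):
    # Runs on the forward pass: log the activation's scale.
    jax.debug.print(f"{tag}/activation_rms: {{v}}", v=jnp.sqrt(jnp.mean(x * x)))
    return x, None  # an identity needs no residuals


def _bwd(tag: str, _residual: None, cotangent: jax.Array):
    # Runs on the backward pass: log the incoming cotangent's scale,
    # then pass it through unchanged.
    jax.debug.print(f"{tag}/gradient_rms: {{v}}", v=jnp.sqrt(jnp.mean(cotangent * cotangent)))
    return (cotangent,)


tagged_identity.defvjp(_fwd, _bwd)


def loss(x):
    return jnp.sum(tagged_identity("probe", jnp.tanh(x)) ** 2)


# Taking a gradient triggers both the forward and backward prints.
print(jax.grad(loss)(jnp.arange(4.0)))
```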

docs/recipes/add_grug_backward_flow_logging.md

Lines changed: 5 additions & 5 deletions
@@ -40,17 +40,17 @@ explicitly when a variant should opt out. Positive intervals sample that often.
 ## 2) Mark module outputs

 At each named module boundary you want in the graph, wrap the returned activation with
-`log_backward_activation(..., site="out")`. For modules where you want to see what
-backward is sending *into* the module, also mark the input with `site="in"`:
+`log_backward_activation(...)`. For modules where you want to see what backward is
+sending *into* the module, mark the input with `BACKWARD_FLOW_SITE_IN`:

 ```python
-from levanter.analysis.backward_flow import log_backward_activation, trace_backward_activation
+from levanter.analysis.backward_flow import BACKWARD_FLOW_SITE_IN, log_backward_activation, trace_backward_activation

 @named_call
 def __call__(self, x):
-    x = log_backward_activation(x, site="in")
+    x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
     out = ...
-    return log_backward_activation(out, site="out")
+    return log_backward_activation(out)
 ```

 For identity-only stream anchors, use `trace_backward_activation(...)` to add the probe
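One payoff of the `Literal`-typed `site` parameter (see the backward_flow.py diff below) is that a static type checker can reject a misspelled site label that a plain `str` would silently accept. A hypothetical example, assuming a checker such as mypy or pyright:

```python
x = log_backward_activation(x, site="inn")                  # rejected: not a valid BackwardFlowSite
x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)  # accepted, and greppable
```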

experiments/grug/base/model.py

Lines changed: 8 additions & 3 deletions
@@ -14,7 +14,12 @@
 from jax.sharding import reshard
 from jaxtyping import Array, Float, Int, PRNGKeyArray

-from levanter.analysis.backward_flow import is_backward_flow_active, log_backward_activation, trace_backward_activation
+from levanter.analysis.backward_flow import (
+    BACKWARD_FLOW_SITE_IN,
+    is_backward_flow_active,
+    log_backward_activation,
+    trace_backward_activation,
+)
 from levanter.grug.attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention
 from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss
 from levanter.grug.sharding import Pbatch, Pembed_vocab, Plm_head, Plogits, unshard

@@ -77,7 +82,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "CausalSelfAttention":

     @named_call
     def __call__(self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array) -> Float[Array, "B S D"]:
-        x = log_backward_activation(x, site="in")
+        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
         head_dim = self.cfg.inferred_head_dim
         seq_len = x.shape[1]

@@ -106,7 +111,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MLP":

     @named_call
     def __call__(self, x: Float[Array, "B S D"]) -> Float[Array, "B S D"]:
-        x = log_backward_activation(x, site="in")
+        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
         up = jnp.einsum("bsh,hm->bsm", x, self.mlp_up)
         activated = jax.nn.relu(up)
         out = jnp.einsum("bsm,mh->bsh", activated, self.mlp_down, out_sharding=Pbatch)
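For orientation, the metric keys these call sites emit look roughly like the following, derived from `_tensor_metrics` and the default `backward_flow` prefix; the exact name-stack segment depends on `named_call` nesting at the call site, so treat these as key shapes rather than verbatim output.

```python
# Rough key shapes (illustrative, not verbatim):
#   backward_flow/<name_stack>/in_activation_rms   # forward pass at site "in"
#   backward_flow/<name_stack>/in_gradient_rms     # backward pass at site "in"
#   backward_flow/<name_stack>/out_gradient_rms    # backward pass at site "out"
```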

lib/levanter/src/levanter/analysis/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -5,8 +5,14 @@
     "BackwardFlowConfig",
     "BackwardFlowEdge",
     "BackwardFlowGraph",
+    "BACKWARD_FLOW_KIND_ACTIVATION",
+    "BACKWARD_FLOW_KIND_GRADIENT",
     "BackwardFlowPlate",
     "BackwardFlowRenderHints",
+    "BACKWARD_FLOW_SITE_IN",
+    "BACKWARD_FLOW_SITE_OUT",
+    "BackwardFlowSite",
+    "BackwardFlowTensorKind",
     "SummaryStats",
     "cb_compute_entropies",
     "cb_compute_top2_gap",

@@ -32,8 +38,14 @@
     BackwardFlowConfig,
     BackwardFlowEdge,
     BackwardFlowGraph,
+    BACKWARD_FLOW_KIND_ACTIVATION,
+    BACKWARD_FLOW_KIND_GRADIENT,
     BackwardFlowPlate,
     BackwardFlowRenderHints,
+    BACKWARD_FLOW_SITE_IN,
+    BACKWARD_FLOW_SITE_OUT,
+    BackwardFlowSite,
+    BackwardFlowTensorKind,
     SummaryStats,
     backward_flow_graph_from_jaxpr,
     backward_flow_node_stats,
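Because the names are added both to `__all__` and to the package-level import block, either import path should work after this commit:

```python
# Package-root re-export (added by this commit):
from levanter.analysis import BACKWARD_FLOW_SITE_IN, BackwardFlowSite

# Equivalent import from the defining submodule:
# from levanter.analysis.backward_flow import BACKWARD_FLOW_SITE_IN, BackwardFlowSite
```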

lib/levanter/src/levanter/analysis/backward_flow.py

Lines changed: 66 additions & 30 deletions
@@ -12,7 +12,7 @@
 from itertools import pairwise
 import math
 import re
-from typing import Any, Iterable, Literal, Mapping
+from typing import Any, Iterable, Literal, Mapping, TypeAlias

 import jax
 from jax._src.core import Literal as JaxprLiteral

@@ -41,8 +41,14 @@
 )
 _NAME_STACK_PART_RE = re.compile(r"^(?P<wrapper>[A-Za-z_][A-Za-z0-9_]*)\((?P<inner>.*)\)$")
 _STAT_NAMES = ("norm", "rms", "rms_scaled", "mean_abs", "max_abs", "max_abs_scaled", "finite_fraction")
-_FLOW_SITES = ("in", "out")
-_TENSOR_KINDS = ("activation", "gradient")
+BackwardFlowSite: TypeAlias = Literal["in", "out"]
+BackwardFlowTensorKind: TypeAlias = Literal["activation", "gradient"]
+BACKWARD_FLOW_SITE_IN: BackwardFlowSite = "in"
+BACKWARD_FLOW_SITE_OUT: BackwardFlowSite = "out"
+BACKWARD_FLOW_KIND_ACTIVATION: BackwardFlowTensorKind = "activation"
+BACKWARD_FLOW_KIND_GRADIENT: BackwardFlowTensorKind = "gradient"
+_FLOW_SITES = (BACKWARD_FLOW_SITE_IN, BACKWARD_FLOW_SITE_OUT)
+_TENSOR_KINDS = (BACKWARD_FLOW_KIND_ACTIVATION, BACKWARD_FLOW_KIND_GRADIENT)
 _FLOW_DIRECTIONS = ("tb", "lr")
 _DEFAULT_PREFIX = "backward_flow"
 _DEFAULT_RESIDUAL_GAIN_HORIZON = 50

@@ -196,7 +202,7 @@ def normalize_name_stack(name_stack: str) -> str:
     return "/".join(parts)


-def log_backward_activation(x: jax.Array, *, site: str = "out") -> jax.Array:
+def log_backward_activation(x: jax.Array, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT) -> jax.Array:
     """Return ``x`` unchanged while logging activation and backward-gradient scale when enabled."""
     context = _ACTIVE_CONTEXT.get()
     if context is None:

@@ -213,7 +219,9 @@ def log_backward_activation(x: jax.Array, *, site: str = "out") -> jax.Array:
     return _tagged_identity_with_scale(f"{context.prefix}/{name_stack}", site, context.gradient_scale, x)


-def trace_backward_activation(x: jax.Array, name: str, *, site: str = "out") -> jax.Array:
+def trace_backward_activation(
+    x: jax.Array, name: str, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
+) -> jax.Array:
     """Return ``x`` unchanged while logging under an extra JAX named scope."""
     if not name:
         raise ValueError("name must be non-empty")

@@ -223,43 +231,59 @@ def trace_backward_activation(x: jax.Array, name: str, *, site: str = "out") ->


 @functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
-def _tagged_identity(metric_prefix: str, site: str, x: jax.Array) -> jax.Array:
+def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> jax.Array:
     return x


-def _tagged_identity_fwd(metric_prefix: str, site: str, x: jax.Array) -> tuple[jax.Array, None]:
-    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, x, site=site, kind="activation"), step=None)
+def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> tuple[jax.Array, None]:
+    levanter.tracker.jit_log(
+        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
+    )
     return x, None


-def _tagged_identity_bwd(metric_prefix: str, site: str, _residual: None, cotangent: jax.Array) -> tuple[jax.Array]:
-    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, cotangent, site=site, kind="gradient"), step=None)
+def _tagged_identity_bwd(
+    metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array
+) -> tuple[jax.Array]:
+    levanter.tracker.jit_log(
+        _tensor_metrics(metric_prefix, cotangent, site=site, kind=BACKWARD_FLOW_KIND_GRADIENT), step=None
+    )
     return (cotangent,)


 _tagged_identity.defvjp(_tagged_identity_fwd, _tagged_identity_bwd)


 @functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
-def _tagged_identity_with_scale(metric_prefix: str, site: str, gradient_scale: jax.Array, x: jax.Array) -> jax.Array:
+def _tagged_identity_with_scale(
+    metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
+) -> jax.Array:
     return x


 def _tagged_identity_with_scale_fwd(
-    metric_prefix: str, site: str, gradient_scale: jax.Array, x: jax.Array
+    metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
 ) -> tuple[jax.Array, jax.Array]:
-    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, x, site=site, kind="activation"), step=None)
+    levanter.tracker.jit_log(
+        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
+    )
     return x, gradient_scale


 def _tagged_identity_with_scale_bwd(
     metric_prefix: str,
-    site: str,
+    site: BackwardFlowSite,
     gradient_scale: jax.Array,
     cotangent: jax.Array,
 ) -> tuple[jax.Array, jax.Array]:
     levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, cotangent, site=site, kind="gradient", gradient_scale=gradient_scale),
+        _tensor_metrics(
+            metric_prefix,
+            cotangent,
+            site=site,
+            kind=BACKWARD_FLOW_KIND_GRADIENT,
+            gradient_scale=gradient_scale,
+        ),
         step=None,
     )
     return jnp.zeros_like(gradient_scale), cotangent

@@ -494,13 +518,13 @@ def _tensor_metrics(
     metric_prefix: str,
     tensor: jax.Array,
     *,
-    site: str,
-    kind: str,
+    site: BackwardFlowSite,
+    kind: BackwardFlowTensorKind,
     gradient_scale: jax.Array | None = None,
 ) -> dict[str, jax.Array]:
     summary = SummaryStats.from_tensor(tensor)
     metrics = summary.to_metrics(f"{metric_prefix}/{site}_{kind}")
-    if kind == "gradient" and gradient_scale is not None:
+    if kind == BACKWARD_FLOW_KIND_GRADIENT and gradient_scale is not None:
         gradient_scale = jnp.asarray(gradient_scale, dtype=jnp.float32)
         metrics[f"{metric_prefix}/{site}_{kind}_rms_scaled"] = summary.rms * gradient_scale
         metrics[f"{metric_prefix}/{site}_{kind}_max_abs_scaled"] = summary.max_abs * gradient_scale

@@ -1091,49 +1115,61 @@ def _is_supported_metric_name(metric_name: str) -> bool:
     return False


-def _metric_value(stats: Mapping[str, float], site: str, kind: str, metric: str) -> float | None:
+def _metric_value(
+    stats: Mapping[str, float], site: BackwardFlowSite, kind: BackwardFlowTensorKind, metric: str
+) -> float | None:
     return stats.get(f"{site}_{kind}_{metric}")


 def _preferred_gradient_rms(stats: Mapping[str, float]) -> float | None:
-    scaled = _preferred_metric(stats, "gradient", "rms_scaled", preferred_site="in")
+    scaled = _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled", preferred_site=BACKWARD_FLOW_SITE_IN)
     if scaled is not None:
         return scaled
-    return _preferred_metric(stats, "gradient", "rms", preferred_site="in")
+    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms", preferred_site=BACKWARD_FLOW_SITE_IN)


 def _has_scaled_gradient_rms(stats: Mapping[str, float]) -> bool:
-    return any(_metric_value(stats, site, "gradient", "rms_scaled") is not None for site in _FLOW_SITES)
+    return any(
+        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled") is not None for site in _FLOW_SITES
+    )


 def _preferred_gradient_max_abs(stats: Mapping[str, float]) -> float | None:
-    scaled = _preferred_metric(stats, "gradient", "max_abs_scaled", preferred_site="in")
+    scaled = _preferred_metric(
+        stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled", preferred_site=BACKWARD_FLOW_SITE_IN
+    )
     if scaled is not None:
         return scaled
-    return _preferred_metric(stats, "gradient", "max_abs", preferred_site="in")
+    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs", preferred_site=BACKWARD_FLOW_SITE_IN)


 def _has_scaled_gradient_max_abs(stats: Mapping[str, float]) -> bool:
-    return any(_metric_value(stats, site, "gradient", "max_abs_scaled") is not None for site in _FLOW_SITES)
+    return any(
+        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled") is not None for site in _FLOW_SITES
+    )


 def _preferred_activation_rms(stats: Mapping[str, float]) -> float | None:
-    return _preferred_metric(stats, "activation", "rms", preferred_site="out")
+    return _preferred_metric(stats, BACKWARD_FLOW_KIND_ACTIVATION, "rms", preferred_site=BACKWARD_FLOW_SITE_OUT)


 def _preferred_finite_fraction(stats: Mapping[str, float]) -> float | None:
-    gradient_fraction = _preferred_metric(stats, "gradient", "finite_fraction", preferred_site="in")
+    gradient_fraction = _preferred_metric(
+        stats, BACKWARD_FLOW_KIND_GRADIENT, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_IN
+    )
     if gradient_fraction is not None:
         return gradient_fraction
-    return _preferred_metric(stats, "activation", "finite_fraction", preferred_site="out")
+    return _preferred_metric(
+        stats, BACKWARD_FLOW_KIND_ACTIVATION, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_OUT
+    )


 def _preferred_metric(
     stats: Mapping[str, float],
-    kind: str,
+    kind: BackwardFlowTensorKind,
     metric: str,
     *,
-    preferred_site: str,
+    preferred_site: BackwardFlowSite,
 ) -> float | None:
     preferred = _metric_value(stats, preferred_site, kind, metric)
     if preferred is not None:
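The rename is behavior-preserving because the new constants are ordinary strings at runtime; only static checking gets stricter. A minimal sketch of why the logged metric keys are unchanged:

```python
# Literal-typed constants are plain strings at runtime, so comparison and
# f-string formatting behave exactly as the old bare literals did.
from typing import Literal, TypeAlias

Site: TypeAlias = Literal["in", "out"]
SITE_IN: Site = "in"

assert SITE_IN == "in"  # runtime equality with the bare literal
assert f"{SITE_IN}_gradient_rms" == "in_gradient_rms"  # metric keys unchanged
```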

lib/levanter/tests/test_backward_flow.py

Lines changed: 9 additions & 7 deletions
@@ -8,6 +8,8 @@

 import levanter.tracker
 from levanter.analysis.backward_flow import (
+    BACKWARD_FLOW_SITE_IN,
+    BACKWARD_FLOW_SITE_OUT,
     BackwardFlowConfig,
     BackwardFlowEdge,
     BackwardFlowGraph,

@@ -44,8 +46,8 @@ def test_normalize_name_stack_strips_jax_transform_wrappers():
 def test_log_backward_activation_records_activation_and_gradient_metrics():
     @jax.named_call
     def inner(x):
-        x = log_backward_activation(x, site="in")
-        return log_backward_activation(x * 2, site="out")
+        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
+        return log_backward_activation(x * 2, site=BACKWARD_FLOW_SITE_OUT)

     @jax.jit
     def compute_grad(x):

@@ -72,7 +74,7 @@ def compute_grad(x):
 def test_log_backward_activation_records_scaled_gradient_rms_when_configured():
     @jax.named_call
     def inner(x):
-        return log_backward_activation(x * 2, site="out")
+        return log_backward_activation(x * 2, site=BACKWARD_FLOW_SITE_OUT)

     @jax.jit
     def compute_grad(x):

@@ -117,8 +119,8 @@ def compute_grad(x):
 def test_log_backward_activation_allows_callers_to_skip_checkpoint_when_active():
     @jax.named_call
     def inner(x):
-        x = log_backward_activation(x, site="in")
-        return log_backward_activation(jnp.tanh(x * 2), site="out")
+        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
+        return log_backward_activation(jnp.tanh(x * 2), site=BACKWARD_FLOW_SITE_OUT)

     def maybe_checkpointed_inner(x):
         if is_backward_flow_active():

@@ -149,8 +151,8 @@ def init(weight):

     def step(self, carry: jax.Array) -> jax.Array:
         with jax.named_scope("ArrayBlock"):
-            carry = log_backward_activation(carry, site="in")
-            return log_backward_activation(jnp.tanh(carry * self.weight), site="out")
+            carry = log_backward_activation(carry, site=BACKWARD_FLOW_SITE_IN)
+            return log_backward_activation(jnp.tanh(carry * self.weight), site=BACKWARD_FLOW_SITE_OUT)

 def apply_layers(stack: ArrayStacked[Layer], carry: jax.Array) -> jax.Array:
     if is_backward_flow_active():
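A small guard test, not part of this commit, could pin the constants to their string values, since renaming either one would silently change every logged metric key:

```python
# Hypothetical regression test (values taken from the diff above):
from levanter.analysis.backward_flow import BACKWARD_FLOW_SITE_IN, BACKWARD_FLOW_SITE_OUT


def test_site_constants_keep_their_metric_key_values():
    assert BACKWARD_FLOW_SITE_IN == "in"
    assert BACKWARD_FLOW_SITE_OUT == "out"
```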
