
Commit 265073a

Shorten backward flow site constants
1 parent c027bb7 commit 265073a

6 files changed: 50 additions & 75 deletions


docs/design/grug-backward-flow-logging.md

Lines changed: 5 additions & 6 deletions
@@ -49,8 +49,7 @@ canonical Grug base template.
 - `trace_backward_activation(x, name, site=...)`: a convenience wrapper for
   identity-only stream anchors that adds a `jax.named_scope(name)` around
   `log_backward_activation(...)`
-- `BACKWARD_FLOW_SITE_IN` / `BACKWARD_FLOW_SITE_OUT`: named constants for the
-  metric-key site labels
+- `BWD_IN` / `BWD_OUT`: named constants for the metric-key site labels
 - `normalize_name_stack(...)`: removes transform wrappers such as `jvp(...)` and
   `transpose(...)` so metric keys stay stable

@@ -65,20 +64,20 @@ def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -

 def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array):
     levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION),
+        _tensor_metrics(metric_prefix, x, site=site, kind=_BWD_KIND_ACTIVATION),
         step=None,
     )
     return x, None

 def _tagged_identity_bwd(metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array):
     levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, cotangent, site=site, kind=BACKWARD_FLOW_KIND_GRADIENT),
+        _tensor_metrics(metric_prefix, cotangent, site=site, kind=_BWD_KIND_GRADIENT),
         step=None,
     )
     return (cotangent,)

 def log_backward_activation(
-    x: jax.Array, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
+    x: jax.Array, *, site: BackwardFlowSite = BWD_OUT
 ) -> jax.Array:
     context = _ACTIVE_CONTEXT.get()
     if context is None:
@@ -87,7 +86,7 @@ def log_backward_activation(
     return _tagged_identity(f"{context.prefix}/{name_stack}", site, x)

 def trace_backward_activation(
-    x: jax.Array, name: str, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
+    x: jax.Array, name: str, *, site: BackwardFlowSite = BWD_OUT
 ) -> jax.Array:
     with jax.named_scope(name):
         return log_backward_activation(x, site=site)

docs/recipes/add_grug_backward_flow_logging.md

Lines changed: 3 additions & 3 deletions
@@ -41,14 +41,14 @@ explicitly when a variant should opt out. Positive intervals sample that often.

 At each named module boundary you want in the graph, wrap the returned activation with
 `log_backward_activation(...)`. For modules where you want to see what backward is
-sending *into* the module, mark the input with `BACKWARD_FLOW_SITE_IN`:
+sending *into* the module, mark the input with `BWD_IN`:

 ```python
-from levanter.analysis.backward_flow import BACKWARD_FLOW_SITE_IN, log_backward_activation, trace_backward_activation
+from levanter.analysis.backward_flow import BWD_IN, log_backward_activation, trace_backward_activation

 @named_call
 def __call__(self, x):
-    x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
+    x = log_backward_activation(x, site=BWD_IN)
     out = ...
     return log_backward_activation(out)
 ```
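For orientation, here is a minimal end-to-end sketch of the recipe above using the renamed constants. The `SimpleBlock` class, its weight handling, and the `residual_stream` anchor name are illustrative inventions, not part of this change; only `BWD_IN`, `log_backward_activation`, and `trace_backward_activation` come from the diff, and the sketch assumes the `levanter` package from this repository is importable.

```python
import jax
import jax.numpy as jnp

from levanter.analysis.backward_flow import BWD_IN, log_backward_activation, trace_backward_activation


class SimpleBlock:
    """Hypothetical block showing where the two call sites go."""

    def __init__(self, weight: jax.Array):
        self.weight = weight

    def __call__(self, x: jax.Array) -> jax.Array:
        with jax.named_scope("SimpleBlock"):
            # Tag what the backward pass sends *into* the block.
            x = log_backward_activation(x, site=BWD_IN)
            out = jnp.tanh(x @ self.weight)
            # The default site is the output, so no explicit site is needed here.
            return log_backward_activation(out)


def add_residual(x: jax.Array, delta: jax.Array) -> jax.Array:
    # For anchors outside a named module, trace_backward_activation wraps the
    # same logging call in jax.named_scope("residual_stream").
    return trace_backward_activation(x + delta, "residual_stream")
```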

experiments/grug/base/model.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 from jaxtyping import Array, Float, Int, PRNGKeyArray

 from levanter.analysis.backward_flow import (
-    BACKWARD_FLOW_SITE_IN,
+    BWD_IN,
     is_backward_flow_active,
     log_backward_activation,
     trace_backward_activation,
@@ -82,7 +82,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "CausalSelfAttention":

     @named_call
     def __call__(self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array) -> Float[Array, "B S D"]:
-        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
+        x = log_backward_activation(x, site=BWD_IN)
         head_dim = self.cfg.inferred_head_dim
         seq_len = x.shape[1]

@@ -111,7 +111,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MLP":

     @named_call
     def __call__(self, x: Float[Array, "B S D"]) -> Float[Array, "B S D"]:
-        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
+        x = log_backward_activation(x, site=BWD_IN)
         up = jnp.einsum("bsh,hm->bsm", x, self.mlp_up)
         activated = jax.nn.relu(up)
         out = jnp.einsum("bsm,mh->bsh", activated, self.mlp_down, out_sharding=Pbatch)

lib/levanter/src/levanter/analysis/__init__.py

Lines changed: 4 additions & 10 deletions
@@ -5,14 +5,11 @@
     "BackwardFlowConfig",
     "BackwardFlowEdge",
     "BackwardFlowGraph",
-    "BACKWARD_FLOW_KIND_ACTIVATION",
-    "BACKWARD_FLOW_KIND_GRADIENT",
     "BackwardFlowPlate",
     "BackwardFlowRenderHints",
-    "BACKWARD_FLOW_SITE_IN",
-    "BACKWARD_FLOW_SITE_OUT",
     "BackwardFlowSite",
-    "BackwardFlowTensorKind",
+    "BWD_IN",
+    "BWD_OUT",
     "SummaryStats",
     "cb_compute_entropies",
     "cb_compute_top2_gap",
@@ -38,14 +35,11 @@
     BackwardFlowConfig,
     BackwardFlowEdge,
     BackwardFlowGraph,
-    BACKWARD_FLOW_KIND_ACTIVATION,
-    BACKWARD_FLOW_KIND_GRADIENT,
     BackwardFlowPlate,
     BackwardFlowRenderHints,
-    BACKWARD_FLOW_SITE_IN,
-    BACKWARD_FLOW_SITE_OUT,
     BackwardFlowSite,
-    BackwardFlowTensorKind,
+    BWD_IN,
+    BWD_OUT,
     SummaryStats,
     backward_flow_graph_from_jaxpr,
     backward_flow_node_stats,
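A quick import-level sanity sketch of the re-export (assumes the `levanter` package from this repo is importable; the alias name is purely illustrative):

```python
# Both import paths resolve to the same site constants after this change.
from levanter.analysis import BWD_IN, BWD_OUT
from levanter.analysis.backward_flow import BWD_IN as _bwd_in_direct  # deeper path, same value

assert BWD_IN == _bwd_in_direct == "in" and BWD_OUT == "out"
```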

lib/levanter/src/levanter/analysis/backward_flow.py

Lines changed: 26 additions & 44 deletions
@@ -42,13 +42,13 @@
 _NAME_STACK_PART_RE = re.compile(r"^(?P<wrapper>[A-Za-z_][A-Za-z0-9_]*)\((?P<inner>.*)\)$")
 _STAT_NAMES = ("norm", "rms", "rms_scaled", "mean_abs", "max_abs", "max_abs_scaled", "finite_fraction")
 BackwardFlowSite: TypeAlias = Literal["in", "out"]
-BackwardFlowTensorKind: TypeAlias = Literal["activation", "gradient"]
-BACKWARD_FLOW_SITE_IN: BackwardFlowSite = "in"
-BACKWARD_FLOW_SITE_OUT: BackwardFlowSite = "out"
-BACKWARD_FLOW_KIND_ACTIVATION: BackwardFlowTensorKind = "activation"
-BACKWARD_FLOW_KIND_GRADIENT: BackwardFlowTensorKind = "gradient"
-_FLOW_SITES = (BACKWARD_FLOW_SITE_IN, BACKWARD_FLOW_SITE_OUT)
-_TENSOR_KINDS = (BACKWARD_FLOW_KIND_ACTIVATION, BACKWARD_FLOW_KIND_GRADIENT)
+_BackwardFlowTensorKind: TypeAlias = Literal["activation", "gradient"]
+BWD_IN: BackwardFlowSite = "in"
+BWD_OUT: BackwardFlowSite = "out"
+_BWD_KIND_ACTIVATION: _BackwardFlowTensorKind = "activation"
+_BWD_KIND_GRADIENT: _BackwardFlowTensorKind = "gradient"
+_FLOW_SITES = (BWD_IN, BWD_OUT)
+_TENSOR_KINDS = (_BWD_KIND_ACTIVATION, _BWD_KIND_GRADIENT)
 _FLOW_DIRECTIONS = ("tb", "lr")
 _DEFAULT_PREFIX = "backward_flow"
 _DEFAULT_RESIDUAL_GAIN_HORIZON = 50
@@ -202,7 +202,7 @@ def normalize_name_stack(name_stack: str) -> str:
     return "/".join(parts)


-def log_backward_activation(x: jax.Array, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT) -> jax.Array:
+def log_backward_activation(x: jax.Array, *, site: BackwardFlowSite = BWD_OUT) -> jax.Array:
     """Return ``x`` unchanged while logging activation and backward-gradient scale when enabled."""
     context = _ACTIVE_CONTEXT.get()
     if context is None:
@@ -219,9 +219,7 @@ def log_backward_activation(x: jax.Array, *, site: BackwardFlowSite = BACKWARD_F
     return _tagged_identity_with_scale(f"{context.prefix}/{name_stack}", site, context.gradient_scale, x)


-def trace_backward_activation(
-    x: jax.Array, name: str, *, site: BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
-) -> jax.Array:
+def trace_backward_activation(x: jax.Array, name: str, *, site: BackwardFlowSite = BWD_OUT) -> jax.Array:
     """Return ``x`` unchanged while logging under an extra JAX named scope."""
     if not name:
         raise ValueError("name must be non-empty")
@@ -236,18 +234,14 @@ def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -


 def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> tuple[jax.Array, None]:
-    levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
-    )
+    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, x, site=site, kind=_BWD_KIND_ACTIVATION), step=None)
     return x, None


 def _tagged_identity_bwd(
     metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array
 ) -> tuple[jax.Array]:
-    levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, cotangent, site=site, kind=BACKWARD_FLOW_KIND_GRADIENT), step=None
-    )
+    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, cotangent, site=site, kind=_BWD_KIND_GRADIENT), step=None)
     return (cotangent,)


@@ -264,9 +258,7 @@ def _tagged_identity_with_scale(
 def _tagged_identity_with_scale_fwd(
     metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
 ) -> tuple[jax.Array, jax.Array]:
-    levanter.tracker.jit_log(
-        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
-    )
+    levanter.tracker.jit_log(_tensor_metrics(metric_prefix, x, site=site, kind=_BWD_KIND_ACTIVATION), step=None)
     return x, gradient_scale


@@ -281,7 +273,7 @@ def _tagged_identity_with_scale_bwd(
             metric_prefix,
             cotangent,
             site=site,
-            kind=BACKWARD_FLOW_KIND_GRADIENT,
+            kind=_BWD_KIND_GRADIENT,
             gradient_scale=gradient_scale,
         ),
         step=None,
@@ -519,12 +511,12 @@ def _tensor_metrics(
     tensor: jax.Array,
     *,
     site: BackwardFlowSite,
-    kind: BackwardFlowTensorKind,
+    kind: _BackwardFlowTensorKind,
     gradient_scale: jax.Array | None = None,
 ) -> dict[str, jax.Array]:
     summary = SummaryStats.from_tensor(tensor)
     metrics = summary.to_metrics(f"{metric_prefix}/{site}_{kind}")
-    if kind == BACKWARD_FLOW_KIND_GRADIENT and gradient_scale is not None:
+    if kind == _BWD_KIND_GRADIENT and gradient_scale is not None:
         gradient_scale = jnp.asarray(gradient_scale, dtype=jnp.float32)
         metrics[f"{metric_prefix}/{site}_{kind}_rms_scaled"] = summary.rms * gradient_scale
         metrics[f"{metric_prefix}/{site}_{kind}_max_abs_scaled"] = summary.max_abs * gradient_scale
@@ -1116,57 +1108,47 @@ def _is_supported_metric_name(metric_name: str) -> bool:


 def _metric_value(
-    stats: Mapping[str, float], site: BackwardFlowSite, kind: BackwardFlowTensorKind, metric: str
+    stats: Mapping[str, float], site: BackwardFlowSite, kind: _BackwardFlowTensorKind, metric: str
 ) -> float | None:
     return stats.get(f"{site}_{kind}_{metric}")


 def _preferred_gradient_rms(stats: Mapping[str, float]) -> float | None:
-    scaled = _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled", preferred_site=BACKWARD_FLOW_SITE_IN)
+    scaled = _preferred_metric(stats, _BWD_KIND_GRADIENT, "rms_scaled", preferred_site=BWD_IN)
     if scaled is not None:
         return scaled
-    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms", preferred_site=BACKWARD_FLOW_SITE_IN)
+    return _preferred_metric(stats, _BWD_KIND_GRADIENT, "rms", preferred_site=BWD_IN)


 def _has_scaled_gradient_rms(stats: Mapping[str, float]) -> bool:
-    return any(
-        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled") is not None for site in _FLOW_SITES
-    )
+    return any(_metric_value(stats, site, _BWD_KIND_GRADIENT, "rms_scaled") is not None for site in _FLOW_SITES)


 def _preferred_gradient_max_abs(stats: Mapping[str, float]) -> float | None:
-    scaled = _preferred_metric(
-        stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled", preferred_site=BACKWARD_FLOW_SITE_IN
-    )
+    scaled = _preferred_metric(stats, _BWD_KIND_GRADIENT, "max_abs_scaled", preferred_site=BWD_IN)
     if scaled is not None:
         return scaled
-    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs", preferred_site=BACKWARD_FLOW_SITE_IN)
+    return _preferred_metric(stats, _BWD_KIND_GRADIENT, "max_abs", preferred_site=BWD_IN)


 def _has_scaled_gradient_max_abs(stats: Mapping[str, float]) -> bool:
-    return any(
-        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled") is not None for site in _FLOW_SITES
-    )
+    return any(_metric_value(stats, site, _BWD_KIND_GRADIENT, "max_abs_scaled") is not None for site in _FLOW_SITES)


 def _preferred_activation_rms(stats: Mapping[str, float]) -> float | None:
-    return _preferred_metric(stats, BACKWARD_FLOW_KIND_ACTIVATION, "rms", preferred_site=BACKWARD_FLOW_SITE_OUT)
+    return _preferred_metric(stats, _BWD_KIND_ACTIVATION, "rms", preferred_site=BWD_OUT)


 def _preferred_finite_fraction(stats: Mapping[str, float]) -> float | None:
-    gradient_fraction = _preferred_metric(
-        stats, BACKWARD_FLOW_KIND_GRADIENT, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_IN
-    )
+    gradient_fraction = _preferred_metric(stats, _BWD_KIND_GRADIENT, "finite_fraction", preferred_site=BWD_IN)
     if gradient_fraction is not None:
         return gradient_fraction
-    return _preferred_metric(
-        stats, BACKWARD_FLOW_KIND_ACTIVATION, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_OUT
-    )
+    return _preferred_metric(stats, _BWD_KIND_ACTIVATION, "finite_fraction", preferred_site=BWD_OUT)


 def _preferred_metric(
     stats: Mapping[str, float],
-    kind: BackwardFlowTensorKind,
+    kind: _BackwardFlowTensorKind,
     metric: str,
     *,
     preferred_site: BackwardFlowSite,
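The rename leaves the underlying mechanism untouched: a custom-VJP identity whose forward rule logs activation stats and whose backward rule logs cotangent stats under `{prefix}/{site}_{kind}_*` keys. Below is a self-contained sketch of that pattern with the new constant names, using `jax.debug.print` as a stand-in for `levanter.tracker.jit_log` so it runs without Levanter; every `_demo` name is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

BWD_IN, BWD_OUT = "in", "out"  # values match the renamed constants
_BWD_KIND_ACTIVATION, _BWD_KIND_GRADIENT = "activation", "gradient"


def _log_rms_demo(prefix: str, tensor: jax.Array, *, site: str, kind: str) -> None:
    # Mirrors the f"{metric_prefix}/{site}_{kind}_..." key layout built by _tensor_metrics.
    jax.debug.print(prefix + "/" + site + "_" + kind + "_rms = {v}", v=jnp.sqrt(jnp.mean(tensor**2)))


@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _tagged_identity_demo(prefix: str, site: str, x: jax.Array) -> jax.Array:
    return x  # identity on the forward value


def _fwd_demo(prefix: str, site: str, x: jax.Array):
    _log_rms_demo(prefix, x, site=site, kind=_BWD_KIND_ACTIVATION)
    return x, None


def _bwd_demo(prefix: str, site: str, _residual: None, cotangent: jax.Array):
    _log_rms_demo(prefix, cotangent, site=site, kind=_BWD_KIND_GRADIENT)
    return (cotangent,)  # identity on the backward cotangent


_tagged_identity_demo.defvjp(_fwd_demo, _bwd_demo)

# Logs an activation RMS of 1.0 on the way forward ("block0/out_activation_rms")
# and a gradient RMS of 2.0 on the way back ("block0/out_gradient_rms").
jax.grad(lambda x: jnp.sum(_tagged_identity_demo("block0", BWD_OUT, x) ** 2))(jnp.ones(4))
```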

lib/levanter/tests/test_backward_flow.py

Lines changed: 9 additions & 9 deletions
@@ -8,12 +8,12 @@

 import levanter.tracker
 from levanter.analysis.backward_flow import (
-    BACKWARD_FLOW_SITE_IN,
-    BACKWARD_FLOW_SITE_OUT,
     BackwardFlowConfig,
     BackwardFlowEdge,
     BackwardFlowGraph,
     BackwardFlowPlate,
+    BWD_IN,
+    BWD_OUT,
     SummaryStats,
     backward_flow_graph_from_jaxpr,
     capture_backward_flow,
@@ -46,8 +46,8 @@ def test_normalize_name_stack_strips_jax_transform_wrappers():
 def test_log_backward_activation_records_activation_and_gradient_metrics():
     @jax.named_call
     def inner(x):
-        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
-        return log_backward_activation(x * 2, site=BACKWARD_FLOW_SITE_OUT)
+        x = log_backward_activation(x, site=BWD_IN)
+        return log_backward_activation(x * 2, site=BWD_OUT)

     @jax.jit
     def compute_grad(x):
@@ -74,7 +74,7 @@ def compute_grad(x):
 def test_log_backward_activation_records_scaled_gradient_rms_when_configured():
     @jax.named_call
     def inner(x):
-        return log_backward_activation(x * 2, site=BACKWARD_FLOW_SITE_OUT)
+        return log_backward_activation(x * 2, site=BWD_OUT)

     @jax.jit
     def compute_grad(x):
@@ -119,8 +119,8 @@ def compute_grad(x):
 def test_log_backward_activation_allows_callers_to_skip_checkpoint_when_active():
     @jax.named_call
     def inner(x):
-        x = log_backward_activation(x, site=BACKWARD_FLOW_SITE_IN)
-        return log_backward_activation(jnp.tanh(x * 2), site=BACKWARD_FLOW_SITE_OUT)
+        x = log_backward_activation(x, site=BWD_IN)
+        return log_backward_activation(jnp.tanh(x * 2), site=BWD_OUT)

     def maybe_checkpointed_inner(x):
         if is_backward_flow_active():
@@ -151,8 +151,8 @@ def init(weight):

     def step(self, carry: jax.Array) -> jax.Array:
         with jax.named_scope("ArrayBlock"):
-            carry = log_backward_activation(carry, site=BACKWARD_FLOW_SITE_IN)
-            return log_backward_activation(jnp.tanh(carry * self.weight), site=BACKWARD_FLOW_SITE_OUT)
+            carry = log_backward_activation(carry, site=BWD_IN)
+            return log_backward_activation(jnp.tanh(carry * self.weight), site=BWD_OUT)

     def apply_layers(stack: ArrayStacked[Layer], carry: jax.Array) -> jax.Array:
         if is_backward_flow_active():
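One pattern from these tests worth calling out at call sites: gating `jax.checkpoint` on `is_backward_flow_active()`, as `test_log_backward_activation_allows_callers_to_skip_checkpoint_when_active` does. A minimal sketch follows; the block itself is illustrative, and the rationale in the comment (avoiding a second pass over the tagged identities during rematerialization) is an assumption rather than something stated in this diff.

```python
import jax
import jax.numpy as jnp

from levanter.analysis.backward_flow import BWD_IN, BWD_OUT, is_backward_flow_active, log_backward_activation


def block(x: jax.Array) -> jax.Array:
    x = log_backward_activation(x, site=BWD_IN)
    return log_backward_activation(jnp.tanh(x * 2), site=BWD_OUT)


def maybe_checkpointed_block(x: jax.Array) -> jax.Array:
    # Mirror the test: only rematerialize when backward-flow logging is inactive,
    # presumably so the forward-side logging is not re-run during recompute.
    if is_backward_flow_active():
        return block(x)
    return jax.checkpoint(block)(x)
```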

0 commit comments
