Commit a25b424

Clarify backward flow site semantics
1 parent 265073a commit a25b424

2 files changed

Lines changed: 13 additions & 4 deletions

docs/design/grug-backward-flow-logging.md

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,10 @@ canonical Grug base template.
 - `trace_backward_activation(x, name, site=...)`: a convenience wrapper for
 identity-only stream anchors that adds a `jax.named_scope(name)` around
 `log_backward_activation(...)`
-- `BWD_IN` / `BWD_OUT`: named constants for the metric-key site labels
+- `BWD_IN` / `BWD_OUT`: named constants for forward input/output boundary labels.
+The backward value at `BWD_OUT` is the cotangent with respect to the returned
+activation, and the backward value at `BWD_IN` is the cotangent with respect to
+the input activation.
 - `normalize_name_stack(...)`: removes transform wrappers such as `jvp(...)` and
 `transpose(...)` so metric keys stay stable

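To make the `BWD_OUT` / `BWD_IN` distinction concrete, here is a minimal sketch in plain JAX. It does not use the Levanter helpers named in this commit; `module`, `loss_from_output`, and the array values are hypothetical names and data chosen only for illustration. It computes the cotangent at the forward output boundary (`dL/dout`) and then pulls it back through the module to get the cotangent at the forward input boundary (`dL/dx`).

```python
import jax
import jax.numpy as jnp


def module(x, w):
    # Hypothetical toy "module": a linear map followed by tanh.
    return jnp.tanh(x @ w)


def loss_from_output(out):
    # Hypothetical downstream computation that consumes the module output.
    return jnp.sum(out**2)


x = jnp.array([[0.5, -1.0]])
w = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Forward pass through the module, keeping its VJP (backward) function.
out, module_vjp = jax.vjp(lambda x_: module(x_, w), x)

# Forward output boundary (the BWD_OUT position): cotangent w.r.t. the
# returned activation, i.e. dL/dout. This is what downstream sends in.
dL_dout = jax.grad(loss_from_output)(out)

# Forward input boundary (the BWD_IN position): cotangent w.r.t. the
# input activation, i.e. dL/dx. This is what the module's backward pass
# sends upstream.
(dL_dx,) = module_vjp(dL_dout)

# Cross-check against differentiating the whole composition at once.
dL_dx_direct = jax.grad(lambda x_: loss_from_output(module(x_, w)))(x)
```

In this sketch `dL_dout` has the shape of `out` and `dL_dx` has the shape of `x`, so the two sites generally report different tensors even for the same module.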
docs/recipes/add_grug_backward_flow_logging.md

Lines changed: 9 additions & 3 deletions
@@ -37,11 +37,17 @@ class GrugTrainerConfig:
 The reusable `BackwardFlowConfig()` default is still disabled with `interval=0`; use that
 explicitly when a variant should opt out. Positive intervals sample that often.

-## 2) Mark module outputs
+## 2) Mark module boundaries

 At each named module boundary you want in the graph, wrap the returned activation with
-`log_backward_activation(...)`. For modules where you want to see what backward is
-sending *into* the module, mark the input with `BWD_IN`:
+`log_backward_activation(...)`. Omitting `site` means `BWD_OUT`, the forward output
+boundary. The backward metric there is the cotangent with respect to the returned
+activation, such as `dL/dout`.
+
+For modules where you also want the gradient at the forward input boundary, mark the
+input with `BWD_IN`. The backward metric there is the cotangent with respect to that
+input, such as `dL/dx`; in reverse-mode terms it is what the module's backward pass sends
+upstream, not what downstream sends into the module.

 ```python
 from levanter.analysis.backward_flow import BWD_IN, log_backward_activation, trace_backward_activation
