
Commit c4bdde4

[Docs] Add HBM optimization guide and cross-links (#3595)
Add a technical reference guide for fitting JAX/Levanter/Haliax training in HBM and link it from high-signal docs so OOM guidance is easy to find.

- Add docs/references/hbm-optimization.md covering sharding, activation checkpointing/offloading, optimizer offload, nested sqrt(n) remat, batch/sequence controls, and practical memory tuning tips.
- Add a Technical Reference nav entry in mkdocs.yml.
- Link the guide from experiments/grug/README.md, docs/tutorials/train-an-lm.md, docs/tutorials/local-gpu.md, and .agents/skills/change-grug/SKILL.md.

Testing:

- ./infra/pre-commit.py --all-files --fix

Fixes #3594
1 parent f74cda9 commit c4bdde4

File tree

6 files changed: +182 -0 lines changed


.agents/skills/change-grug/SKILL.md

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ Keep it grug-style:
- Equinox modules with `init` + `__call__`
- minimal config knobs
- keep legibility first; if a block gets hard to read, introduce a small local helper instead of adding framework indirection
- when HBM is tight, use `docs/references/hbm-optimization.md` before introducing bespoke memory hacks

### 5) Delete stale paths

docs/references/hbm-optimization.md

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# Making Things Fit in HBM

This guide is a practical checklist for JAX/Levanter/Haliax training runs that are close to OOM.

The main knobs are:

1. Shard more.
2. Checkpoint and offload activations.
3. Offload optimizer/parameter state.
4. Use model parallelism where it actually helps.
5. Use nested (`sqrt(n)`) checkpointing for scanned stacks.
6. Reduce per-device batch (and sequence length if needed).

## 1) Shard More (Usually the First Lever)

If arrays are accidentally replicated instead of partitioned, HBM disappears fast.

Use explicit placement at boundaries:

- `hax.shard(...)` for Haliax `NamedArray` trees.
- `jax.device_put(...)` for explicit initial placement.
- `jax.sharding.reshard(...)` when you need to change sharding mid-pipeline.
- For LMs, explicitly shard output projection / vocab-axis tensors so logits are partitioned rather than replicated.

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# Example: shard parameters across data/model axes instead of replicating.
# Assumes `mesh` is an existing jax.sharding.Mesh with "data" and "model"
# axes and `params` is the parameter pytree.
param_sharding = NamedSharding(mesh, P("data", "model"))
params = jax.device_put(params, param_sharding)
```

For FSDP-style setups, confirm large parameter tensors are split across the data axis rather than replicated.
In classic Levanter/Haliax codepaths, this is usually handled for you, but custom tensors and custom losses may still need explicit resharding.
## 2) Activation Checkpointing and Activation Offloading

Checkpointing (rematerialization) trades compute for memory by saving fewer intermediates in the forward pass and recomputing them in the backward pass.

Activation offloading is a variant: selected activations are moved from device memory to pinned host memory after the forward pass, then moved back before the backward pass.

Conceptually, with JAX checkpoint policies you choose, per named intermediate, whether to:

- Save it on device.
- Offload it to host.
- Recompute it.

In Haliax/Levanter scanned stacks, this is typically exposed via `gradient_checkpointing` policies (e.g. standard recompute, offload variants, nested variants).

References:

- [JAX: Gradient checkpointing (`jax.checkpoint` / `jax.remat`)](https://docs.jax.dev/en/latest/gradient-checkpointing.html)
- [JAX Memories and Host Offloading](https://docs.jax.dev/en/latest/notebooks/host-offloading.html)
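As a concrete sketch of the per-intermediate choice, here is plain JAX (the toy `block` function and the tag names are illustrative; Haliax's `gradient_checkpointing` policies wrap the same machinery):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Toy two-matmul block; checkpoint_name tags intermediates so a policy
# can refer to them by name.
def block(x, w1, w2):
    hidden = checkpoint_name(jnp.tanh(x @ w1), "hidden")
    return checkpoint_name(hidden @ w2, "out")

# Save "out" on device; recompute everything else in the backward pass.
# The offload variant is jax.checkpoint_policies.save_and_offload_only_these_names(...),
# which additionally moves named intermediates to pinned host memory.
policy = jax.checkpoint_policies.save_only_these_names("out")
block_remat = jax.checkpoint(block, policy=policy)
```

Gradients of `block_remat` behave exactly like gradients of `block`; only the memory/compute tradeoff changes.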
## 3) Explicit Offloading of Optimizer State (and Sometimes Params)

Optimizer state is often one of the largest memory consumers (especially Adam-family optimizers).

A common pattern is:

1. Keep optimizer state in pinned host memory between steps.
2. Bring it to device only for the update math.
3. Return the updated state back to host.

```python
import functools

import jax
import optax

# Sketch: `params_sharding`, `opt_state`, `loss_fn`, and `optimizer` are
# assumed to already exist in your training setup.
s_dev = params_sharding
s_host = s_dev.with_memory_kind("pinned_host")
opt_state = jax.device_put(opt_state, s_host)

@functools.partial(jax.jit, donate_argnums=(0, 1), out_shardings=(s_dev, s_host))
def train_step(params, opt_state, batch):
    # Bring optimizer state onto device only for the update math.
    opt_state = jax.device_put(opt_state, s_dev)
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Move the updated state back to pinned host memory before returning.
    return params, jax.device_put(opt_state, s_host)
```
This usually buys substantial HBM headroom, at the cost of transfer bandwidth/latency.

Reference:

- [JAX Memories and Host Offloading (optimizer state + parameter offloading)](https://docs.jax.dev/en/latest/notebooks/host-offloading.html)

## 4) Model Parallelism Can Beat "Max FSDP" in Some Regimes

Sometimes parameter tensors or activations are too large even with aggressive data-axis sharding.
In that case, giving devices to model/tensor-parallel axes can reduce peak HBM even though it reduces the FSDP degree.

Rule of thumb: sweep a small grid of mesh shapes (for example, more `data` vs more `model`) and compare:

- Peak HBM
- Step time
- Achievable global batch

The best throughput-at-memory-budget point is often not the "maximum data parallel" point.
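The sweep can be sketched like this (the `mesh_candidates` helper is hypothetical, not a Levanter API; the profiled step is left as a comment):

```python
import jax
import numpy as np
from jax.sharding import Mesh

def mesh_candidates():
    """Yield (data, model) mesh shapes that factor the available device count."""
    n = len(jax.devices())
    for model in range(1, n + 1):
        if n % model == 0:
            yield n // model, model

for data, model in mesh_candidates():
    mesh = Mesh(np.array(jax.devices()).reshape(data, model), ("data", "model"))
    # Run one profiled train step under `mesh` here; record peak HBM,
    # step time, and achievable global batch, then pick the best tradeoff.
```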
## 5) `sqrt(n)` Checkpointing for Scanned Layer Stacks

For a stack of length `N`, nested checkpointing chunks the work into blocks of size `B` and stores only block boundaries.

When `B ~= sqrt(N)`, memory for saved boundaries is `O(sqrt(N))` instead of `O(N)`, with recomputation overhead.

This is useful for deep scanned stacks where plain checkpointing/offloading still does not fit.
In Haliax scanned modules, nested checkpointing is available as a policy option.
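A minimal sketch of the scheme in plain JAX, assuming a toy `layer` and a stack whose length is divisible by the block size (Haliax's policy option packages this for you):

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # Illustrative per-layer computation.
    return jnp.tanh(x @ w)

def sqrt_n_remat_stack(x, weights, block_size):
    # weights: stacked per-layer params with leading axis N (N % block_size == 0).
    n = weights.shape[0]
    blocks = weights.reshape(n // block_size, block_size, *weights.shape[1:])

    @jax.checkpoint  # inner intermediates are recomputed, not saved
    def run_block(carry, block_ws):
        def step(c, w):
            return layer(c, w), None
        out, _ = jax.lax.scan(step, carry, block_ws)
        return out

    # Only the N / block_size block-boundary carries are saved by the outer scan.
    def outer_step(carry, block_ws):
        return run_block(carry, block_ws), None

    out, _ = jax.lax.scan(outer_step, x, blocks)
    return out
```

With `block_size ~= sqrt(N)`, saved memory is the block boundaries plus one block's worth of recomputed intermediates.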
## 6) Reduce Per-Device Batch (and Sequence Length)

If you are right at the limit:

- Reduce microbatch/per-device batch size.
- If needed, reduce sequence length.
- Recover the global batch with gradient accumulation.

These are the most direct and reliable HBM controls.
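Gradient accumulation can be sketched as follows (`loss_fn(params, batch)` is an assumed per-microbatch mean loss, and `microbatches` is the global batch split along a leading microbatch axis):

```python
import jax
import jax.numpy as jnp

def accumulated_grads(params, microbatches, loss_fn):
    # Accumulate microbatch gradients without materializing all of them at once.
    def step(acc, batch):
        grads = jax.grad(loss_fn)(params, batch)
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, microbatches)
    # Average so the result matches the gradient of one large-batch step
    # (for equal-size microbatches and a mean-style loss).
    return jax.tree_util.tree_map(lambda g: g / microbatches.shape[0], total)
```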
## 7) Buffer Donation (`donate_argnums`)

Donation lets JAX reuse input buffers for outputs at JIT boundaries, reducing peak live memory.

Reference:

- [JAX Buffer Donation](https://docs.jax.dev/en/latest/buffer_donation.html)
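For example (the `sgd_update` function is illustrative):

```python
import functools

import jax
import jax.numpy as jnp

# Donating `params` tells XLA it may reuse that input buffer for the output,
# so the old and new parameter arrays need not be live at the same time.
@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_update(params, grads):
    return params - 0.1 * grads
```

Note that on backends without donation support, JAX falls back to copying and emits a warning; the result is unchanged.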
## 8) Optimizer Choice Matters for Memory

For equal parameter count, optimizer state memory can differ drastically.

- Adam-like methods keep multiple full-size state tensors.
- Memory-lean alternatives (where acceptable for your training regime) can materially reduce HBM pressure.

If you keep an Adam-family optimizer, offloading its state is often the practical compromise.
## 9) Profile Memory Before and After Each Change

Use JAX memory profiling tools to confirm what changed:

- [JAX: Profiling device memory](https://docs.jax.dev/en/latest/device_memory_profiling.html)
- [JAX: GPU memory allocation notes](https://docs.jax.dev/en/latest/gpu_memory_allocation.html)

Memory tuning is much faster when each knob change is measured, not guessed.
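The quickest way to take a snapshot is `jax.profiler.save_device_memory_profile` (the filename here is illustrative):

```python
import jax
import jax.numpy as jnp

# Materialize something, then dump a pprof-format snapshot of live
# device memory; diff snapshots taken before and after a knob change.
x = jnp.ones((256, 256)) @ jnp.ones((256, 256))
x.block_until_ready()
jax.profiler.save_device_memory_profile("memory_after_change.prof")
```

The resulting files can be compared with `pprof`'s `-diff_base` mode to see exactly which allocations a change added or removed.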
## 10) Avoid Giant Temporary Tensors

Large temporaries can dominate peak memory even when parameter state fits.

- Avoid materializing full-size intermediates when a fused/chunked computation exists.
- For language models, the full logits tensor (`batch x seq x vocab`) is often the worst offender.
- Use memory-efficient attention kernels/backends where available in your model stack.
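As a sketch of the chunked-computation idea for the logits case (names and shapes are illustrative; production stacks usually use a fused kernel instead):

```python
import jax
import jax.numpy as jnp

def chunked_xent(hidden, w_vocab, labels, chunk_size=128):
    # Cross-entropy over tokens, computed in row chunks so only a
    # (chunk_size, vocab) slice of the logits is ever live at once.
    n = hidden.shape[0]
    losses = []
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]
        logits = h @ w_vocab  # partial logits, not the full (n, vocab) tensor
        logp = jax.nn.log_softmax(logits)
        idx = jnp.arange(h.shape[0])
        losses.append(-logp[idx, labels[start:start + chunk_size]])
    return jnp.concatenate(losses).mean()
```

Because softmax is row-wise, the chunked result is identical to materializing the full logits tensor.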
## 11) Keep EMA and Other Replicas Off HBM

Extra full-parameter copies (for example, EMA weights) can be expensive in HBM.

- Keep long-lived replicas in host memory when possible.
- Materialize them on-device only when needed (for eval/export windows).
## 12) Use Lower Precision Where Safe

HBM usage scales linearly with dtype size.

- Prefer BF16 activations/weights on hardware where it is standard.
- Be explicit about which states must remain FP32 (often optimizer moments), then offload those if needed.
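The linear scaling is easy to check directly:

```python
import jax.numpy as jnp

x32 = jnp.ones((1024, 1024), dtype=jnp.float32)
x16 = x32.astype(jnp.bfloat16)
# Same tensor, half the bytes: 4 MiB for float32 vs 2 MiB for bfloat16.
```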
## 13) Tune Eval Memory Separately from Train

Evaluation often has different memory pressure than training.

- Set eval batch size independently.
- Reduce concurrent eval tasks/checkpoints when needed.
- Keep eval from overlapping peak-memory parts of training if your pipeline allows it.

docs/tutorials/local-gpu.md

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ If you are using a DGX Spark or similar machine with unified memory, you may nee
echo 'export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5' >> ~/.bashrc
```

For broader JAX/Levanter memory tuning (sharding, checkpointing, offloading), see [Making Things Fit in HBM](../references/hbm-optimization.md).

## Running an Experiment

Now you can run an experiment.

docs/tutorials/train-an-lm.md

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,8 @@ Set up your training configuration by calculating the number of training steps a
)
```

If you hit HBM OOM while scaling model size, batch size, or sequence length, see [Making Things Fit in HBM](../references/hbm-optimization.md) for a practical tuning checklist.

## Creating the Training Pipeline

Connect your model configuration, training parameters, and dataset to create a training pipeline:

experiments/grug/README.md

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ enforces these minimum interfaces:

- Grug principles: [`/.agents/projects/grugformer.md`](../../.agents/projects/grugformer.md)
- Change workflow: [`.agents/skills/change-grug/`](../../.agents/skills/change-grug/SKILL.md)
- HBM/OOM tuning guide: [`/docs/references/hbm-optimization.md`](../../docs/references/hbm-optimization.md)
- Executor mechanics: [`/docs/explanations/executor.md`](../../docs/explanations/executor.md)
- Executor tutorial: [`/docs/tutorials/executor-101.md`](../../docs/tutorials/executor-101.md)
- TPU debug workflow: [`/docs/dev-guide/dev_tpu.md`](../../docs/dev-guide/dev_tpu.md)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ nav:
      - Executor API: references/executor-api.md
      - Default Steps: references/default-steps.md
      - Training Configuration: references/train-config.md
      - HBM Optimization: references/hbm-optimization.md

markdown_extensions:
  - markdown.extensions.footnotes
