Skip to content

Commit 4820092

Browse files
committed
Update pipeline.py
1 parent 5c8b8a2 commit 4820092

File tree

1 file changed

+31
-47
lines changed

1 file changed

+31
-47
lines changed

src/maxtext/layers/pipeline.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -968,24 +968,11 @@ def init_states(self, inputs):
968968
"""Initializes the pipeline execution state and communication buffers.
969969
970970
This sets up the memory needed to pass activations between pipeline stages
971-
(`state_io` and `shift`) and allocates the empty Buffer Sliding Window (BSW)
972-
that will hold the gathered FSDP weights.
971+
(`state_io` and `shift`). BSW (Buffer Sliding Window) is computed locally
972+
inside scan_body each iteration rather than pre-allocated, so that
973+
jax.checkpoint can discard it between iterations to prevent OOM.
973974
"""
974-
loop_state = super().init_states(inputs)
975-
976-
weights = nnx.state(self.layers, _is_static_param)
977-
978-
def get_single_repeat_shape(x):
979-
if x is None:
980-
return None
981-
return jnp.zeros_like(x[0]) if self.config.num_pipeline_repeats > 1 else jnp.zeros_like(x)
982-
983-
bsw = (
984-
jax.tree.map(get_single_repeat_shape, weights),
985-
jax.tree.map(get_single_repeat_shape, weights),
986-
)
987-
988-
return loop_state, bsw
975+
return super().init_states(inputs)
989976

990977
def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):
991978
"""Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch."""
@@ -1371,9 +1358,11 @@ def __call__(
13711358
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
13721359
)
13731360

1374-
loop_state, bsw = self.init_states(inputs)
1361+
loop_state = self.init_states(inputs)
13751362

1376-
# - Full spec (with circular_repeats axis) -> BSW creation via weight_prefetching.
1363+
# Two spec variants needed:
1364+
# - Full spec (with circular_repeats axis) -> BSW creation inside scan_body via
1365+
# from_all_variables_to_repeat_weights + from_repeat_weights_to_bsw.
13771366
# from_repeat_weights_to_bsw's derive_stage_weight_partition_specs drops the
13781367
# first dim (repeat), so the input must still have it.
13791368
# - Stripped logical spec (circular_repeats removed) -> BSW consumption via
@@ -1402,39 +1391,34 @@ def unbox_val(x):
14021391

14031392
_, layers_params, layers_metrics, layers_mutables = nnx.split(layers_state, _is_static_param, nnx.Intermediate, ...)
14041393

1405-
# Pre-populate bsw[1] with iteration-0 weights so the first scan_body
1406-
# slide (next_bsw[0] = current_bsw[1]) picks up the correct weights.
1407-
# bsw[0] is a zero placeholder — it is immediately discarded by the slide.
1408-
init_repeat_weights = self.from_all_variables_to_repeat_weights(layers_params, 0)
1409-
init_w_curr = self.from_repeat_weights_to_bsw(init_repeat_weights, physical_partition_spec_full)
1410-
bsw = (bsw[0], init_w_curr)
1411-
14121394
def scan_body(carry, _):
1413-
current_loop_state, current_bsw, current_layer_mutables = carry
1395+
current_loop_state, current_layer_mutables = carry
14141396
# Fold loop_iteration into RNG keys so each scan step gets a unique
14151397
# dropout mask — mirrors Linen's nn.scan(split_rngs={"random": True}).
14161398
iteration = current_loop_state["loop_iteration"]
14171399
advanced_mutables = _advance_rng_state(current_layer_mutables, iteration)
14181400

1419-
# 1. Async FSDP Prefetch — only gather NEXT repeat's weights (1 all-gather).
1420-
# The current repeat's weights are already in current_bsw[1], carried
1421-
# forward from the previous iteration's prefetch (sliding window).
1422-
# Use FULL spec - weight_prefetching drops the repeat axis internally via
1423-
# derive_stage_weight_partition_specs.
1424-
next_weight = self.weight_prefetching(
1425-
layers_params, physical_partition_spec_full, current_loop_state["loop_iteration"]
1426-
)
1427-
# Sliding window: previous next (current_bsw[1]) becomes current (bsw[0]),
1428-
# freshly prefetched next_weight becomes next (bsw[1]).
1429-
next_bsw = (current_bsw[1], next_weight)
1430-
next_bsw = jax.ad_checkpoint.checkpoint_name(next_bsw, "bsw")
1431-
1432-
# 2. Run Forward & State Shift
1401+
# Compute BOTH current and next weights locally (2 all-gathers per iteration).
1402+
# BSW is NOT carried through scan — it is a body intermediate that
1403+
# jax.checkpoint discards between iterations, preventing OOM.
1404+
# Trade-off: 2 all-gathers/iter instead of 1 (no sliding window).
1405+
# Acceptable until REG-1 (nested scan + custom VJP) restores the optimization.
1406+
cur_repeat_weights = self.from_all_variables_to_repeat_weights(
1407+
layers_params, iteration)
1408+
cur_bsw = self.from_repeat_weights_to_bsw(
1409+
cur_repeat_weights, physical_partition_spec_full)
1410+
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(
1411+
layers_params, iteration + 1)
1412+
nxt_bsw = self.from_repeat_weights_to_bsw(
1413+
nxt_repeat_weights, physical_partition_spec_full)
1414+
bsw = (cur_bsw, nxt_bsw)
1415+
1416+
# Run Forward & State Shift
14331417
# Use STRIPPED logical spec - run_one_iteration re-derives physical from it,
14341418
# and get_current_weights_from_bsw expects specs without the repeat axis.
14351419
new_loop_state, new_layer_state = self.run_one_iteration(
14361420
current_loop_state,
1437-
next_bsw,
1421+
bsw,
14381422
layers_graph,
14391423
layers_metrics,
14401424
advanced_mutables,
@@ -1446,7 +1430,7 @@ def scan_body(carry, _):
14461430
)
14471431

14481432
_, _, new_layer_metrics, new_layer_mutables = nnx.split(new_layer_state, _is_static_param, nnx.Intermediate, ...)
1449-
return (new_loop_state, next_bsw, new_layer_mutables), new_layer_metrics
1433+
return (new_loop_state, new_layer_mutables), new_layer_metrics
14501434

14511435
if self.config.set_remat_policy_on_pipeline_iterations:
14521436
scan_body = jax.checkpoint(
@@ -1455,16 +1439,16 @@ def scan_body(carry, _):
14551439

14561440
# Memory Efficient Execution via pure JAX scan
14571441
if self.config.scan_pipeline_iterations:
1458-
(loop_state, bsw, final_layer_mutables), stacked_metrics = jax.lax.scan(
1459-
scan_body, (loop_state, bsw, layers_mutables), None, length=total_iterations
1442+
(loop_state, final_layer_mutables), stacked_metrics = jax.lax.scan(
1443+
scan_body, (loop_state, layers_mutables), None, length=total_iterations
14601444
)
14611445
else:
1462-
current_carry = (loop_state, bsw, layers_mutables)
1446+
current_carry = (loop_state, layers_mutables)
14631447
metrics_history = []
14641448
for _ in range(total_iterations):
14651449
current_carry, step_metrics = scan_body(current_carry, None)
14661450
metrics_history.append(step_metrics)
1467-
loop_state, bsw, final_layer_mutables = current_carry
1451+
loop_state, final_layer_mutables = current_carry
14681452
stacked_metrics = jax.tree.map(lambda *xs: jnp.stack(xs), *metrics_history) if metrics_history else layers_metrics
14691453

14701454
final_layer_state = nnx.State.merge(layers_params, stacked_metrics, final_layer_mutables)

0 commit comments

Comments
 (0)