@@ -968,24 +968,11 @@ def init_states(self, inputs):
968968 """Initializes the pipeline execution state and communication buffers.
969969
970970 This sets up the memory needed to pass activations between pipeline stages
971- (`state_io` and `shift`) and allocates the empty Buffer Sliding Window (BSW)
972- that will hold the gathered FSDP weights.
971+ (`state_io` and `shift`). BSW (Buffer Sliding Window) is computed locally
972+ inside scan_body each iteration rather than pre-allocated, so that
973+ jax.checkpoint can discard it between iterations to prevent OOM.
973974 """
974- loop_state = super ().init_states (inputs )
975-
976- weights = nnx .state (self .layers , _is_static_param )
977-
978- def get_single_repeat_shape (x ):
979- if x is None :
980- return None
981- return jnp .zeros_like (x [0 ]) if self .config .num_pipeline_repeats > 1 else jnp .zeros_like (x )
982-
983- bsw = (
984- jax .tree .map (get_single_repeat_shape , weights ),
985- jax .tree .map (get_single_repeat_shape , weights ),
986- )
987-
988- return loop_state , bsw
975+ return super ().init_states (inputs )
989976
990977 def gather_microbatch_inputs_vmap (self , xs , ids , ids_dim ):
991978 """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch."""
@@ -1371,9 +1358,11 @@ def __call__(
13711358 (self .config .num_pipeline_microbatches , self .pipeline_microbatch_size , self .config .max_target_length )
13721359 )
13731360
1374- loop_state , bsw = self .init_states (inputs )
1361+ loop_state = self .init_states (inputs )
13751362
1376- # - Full spec (with circular_repeats axis) -> BSW creation via weight_prefetching.
1363+ # Two spec variants needed:
1364+ # - Full spec (with circular_repeats axis) -> BSW creation inside scan_body via
1365+ # from_all_variables_to_repeat_weights + from_repeat_weights_to_bsw.
13771366 # from_repeat_weights_to_bsw's derive_stage_weight_partition_specs drops the
13781367 # first dim (repeat), so the input must still have it.
13791368 # - Stripped logical spec (circular_repeats removed) -> BSW consumption via
@@ -1402,39 +1391,34 @@ def unbox_val(x):
14021391
14031392 _ , layers_params , layers_metrics , layers_mutables = nnx .split (layers_state , _is_static_param , nnx .Intermediate , ...)
14041393
1405- # Pre-populate bsw[1] with iteration-0 weights so the first scan_body
1406- # slide (next_bsw[0] = current_bsw[1]) picks up the correct weights.
1407- # bsw[0] is a zero placeholder — it is immediately discarded by the slide.
1408- init_repeat_weights = self .from_all_variables_to_repeat_weights (layers_params , 0 )
1409- init_w_curr = self .from_repeat_weights_to_bsw (init_repeat_weights , physical_partition_spec_full )
1410- bsw = (bsw [0 ], init_w_curr )
1411-
14121394 def scan_body (carry , _ ):
1413- current_loop_state , current_bsw , current_layer_mutables = carry
1395+ current_loop_state , current_layer_mutables = carry
14141396 # Fold loop_iteration into RNG keys so each scan step gets a unique
14151397 # dropout mask — mirrors Linen's nn.scan(split_rngs={"random": True}).
14161398 iteration = current_loop_state ["loop_iteration" ]
14171399 advanced_mutables = _advance_rng_state (current_layer_mutables , iteration )
14181400
1419- # 1. Async FSDP Prefetch — only gather NEXT repeat's weights (1 all-gather).
1420- # The current repeat's weights are already in current_bsw[1], carried
1421- # forward from the previous iteration's prefetch (sliding window).
1422- # Use FULL spec - weight_prefetching drops the repeat axis internally via
1423- # derive_stage_weight_partition_specs.
1424- next_weight = self .weight_prefetching (
1425- layers_params , physical_partition_spec_full , current_loop_state ["loop_iteration" ]
1426- )
1427- # Sliding window: previous next (current_bsw[1]) becomes current (bsw[0]),
1428- # freshly prefetched next_weight becomes next (bsw[1]).
1429- next_bsw = (current_bsw [1 ], next_weight )
1430- next_bsw = jax .ad_checkpoint .checkpoint_name (next_bsw , "bsw" )
1431-
1432- # 2. Run Forward & State Shift
1401+ # Compute BOTH current and next weights locally (2 all-gathers per iteration).
1402+ # BSW is NOT carried through scan — it is a body intermediate that jax.checkpoint
1403+ # (applied only when set_remat_policy_on_pipeline_iterations is set) discards between iterations, preventing OOM.
1404+ # Trade-off: 2 all-gathers/iter instead of 1 (no sliding window).
1405+ # Acceptable until REG-1 (nested scan + custom VJP) restores the optimization.
1406+ cur_repeat_weights = self .from_all_variables_to_repeat_weights (
1407+ layers_params , iteration )
1408+ cur_bsw = self .from_repeat_weights_to_bsw (
1409+ cur_repeat_weights , physical_partition_spec_full )
1410+ nxt_repeat_weights = self .from_all_variables_to_repeat_weights (
1411+ layers_params , iteration + 1 )
1412+ nxt_bsw = self .from_repeat_weights_to_bsw (
1413+ nxt_repeat_weights , physical_partition_spec_full )
1414+ bsw = (cur_bsw , nxt_bsw )
1415+
1416+ # Run Forward & State Shift
14331417 # Use STRIPPED logical spec - run_one_iteration re-derives physical from it,
14341418 # and get_current_weights_from_bsw expects specs without the repeat axis.
14351419 new_loop_state , new_layer_state = self .run_one_iteration (
14361420 current_loop_state ,
1437- next_bsw ,
1421+ bsw ,
14381422 layers_graph ,
14391423 layers_metrics ,
14401424 advanced_mutables ,
@@ -1446,7 +1430,7 @@ def scan_body(carry, _):
14461430 )
14471431
14481432 _ , _ , new_layer_metrics , new_layer_mutables = nnx .split (new_layer_state , _is_static_param , nnx .Intermediate , ...)
1449- return (new_loop_state , next_bsw , new_layer_mutables ), new_layer_metrics
1433+ return (new_loop_state , new_layer_mutables ), new_layer_metrics
14501434
14511435 if self .config .set_remat_policy_on_pipeline_iterations :
14521436 scan_body = jax .checkpoint (
@@ -1455,16 +1439,16 @@ def scan_body(carry, _):
14551439
14561440 # Memory Efficient Execution via pure JAX scan
14571441 if self .config .scan_pipeline_iterations :
1458- (loop_state , bsw , final_layer_mutables ), stacked_metrics = jax .lax .scan (
1459- scan_body , (loop_state , bsw , layers_mutables ), None , length = total_iterations
1442+ (loop_state , final_layer_mutables ), stacked_metrics = jax .lax .scan (
1443+ scan_body , (loop_state , layers_mutables ), None , length = total_iterations
14601444 )
14611445 else :
1462- current_carry = (loop_state , bsw , layers_mutables )
1446+ current_carry = (loop_state , layers_mutables )
14631447 metrics_history = []
14641448 for _ in range (total_iterations ):
14651449 current_carry , step_metrics = scan_body (current_carry , None )
14661450 metrics_history .append (step_metrics )
1467- loop_state , bsw , final_layer_mutables = current_carry
1451+ loop_state , final_layer_mutables = current_carry
14681452 stacked_metrics = jax .tree .map (lambda * xs : jnp .stack (xs ), * metrics_history ) if metrics_history else layers_metrics
14691453
14701454 final_layer_state = nnx .State .merge (layers_params , stacked_metrics , final_layer_mutables )
0 commit comments