Skip to content

Commit 7d791db

Browse files
committed
address second round of claude review
- training_loop: skip train_dataloader on non-master ranks; batch the per-iter pickle-based all_gather_object to log-chunk boundaries; init input_ids on device.
- bypass_checkpoint_utils: atomic latest-symlink replacement via tmp+rename.
- sewing_kit/utils: rewrite batched_normalized_mse_loss to relative-L2 (sum((x-t)^2) / (sum(t^2)+eps)); drop epsilon-offset-inside-MSE trick.
- child_init: hand each pruning mixin its own keys_to_remove copy, merge after, so composition order can't corrupt the state dict.
- puzzletron_nas_plugin: warn with the offending directory path when reusing existing activation scores / pruned checkpoints.
- pruning_utils: probe llm_config in _lm_attrs for InternVL-style configs.
- hydra_utils: route 5-arg warmup_steps calls with a fractional 4th arg to the legacy 4-arg path to avoid ZeroDivisionError.
- tests: relative-L2 properties (scale-invariance, zero-both, finiteness on zero target); new regression test pinning FunctionTarget kwarg dispatch.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
1 parent d3e2077 commit 7d791db

9 files changed

Lines changed: 309 additions & 94 deletions

File tree

modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Checkpoint utilities for bypass distillation."""
1717

18+
import os
1819
import re
1920
from collections import OrderedDict
2021
from pathlib import Path
@@ -222,10 +223,15 @@ def save_bypass_checkpoint(
222223
save_checkpoint_from_shards(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor)
223224

224225
if dist.is_master():
225-
# Create 'latest' symlink
226+
# Create 'latest' symlink via tmp-symlink + atomic rename so concurrent
227+
# readers on a shared filesystem never observe a missing `latest`. The
228+
# plain unlink + symlink_to pair leaves a brief window where the link
229+
# doesn't exist; Path.replace (== os.replace) is atomic on POSIX.
226230
latest_symlink = Path(cfg.bypass.experiment_dir) / "latest"
227-
latest_symlink.unlink(missing_ok=True)
228-
latest_symlink.symlink_to(checkpoint_dir.name)
231+
tmp_symlink = latest_symlink.with_name(f".latest_tmp_{os.getpid()}")
232+
tmp_symlink.unlink(missing_ok=True)
233+
tmp_symlink.symlink_to(checkpoint_dir.name)
234+
tmp_symlink.replace(latest_symlink)
229235
# Save config args json
230236
json_dump(cfg.bypass, checkpoint_dir / "args.json")
231237
# Save completed file

modelopt/torch/puzzletron/bypass_distillation/training_loop.py

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint
5959
from .bypass_utils import get_distributed_modules_ownership, set_experiment_dir, set_experiment_id
60-
from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal
60+
from .data_classes import GlobalRank, IterNum, IterStatistics, TimeToSaveSignal
6161
from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership
6262

6363
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -118,6 +118,30 @@ def launch_bypass_distillation(hydra_cfg: DictConfig) -> None:
118118
mprint("Bypass distillation sweep completed")
119119

120120

121+
def _flush_loss_buffer(
122+
local_buffer: dict[int, dict[str, float]],
123+
stitched_losses_history: Optional[dict[int, dict[str, float]]],
124+
) -> None:
125+
"""All-gather buffered per-iter losses and merge into master's history.
126+
127+
Pickle-based ``all_gather_object`` was previously called on every micro-batch;
128+
batching to log-chunk boundaries reduces that cost ~``iters_per_log_chunk``×.
129+
All ranks must call this so the collective doesn't deadlock; only master
130+
actually accumulates into ``stitched_losses_history``.
131+
"""
132+
if not local_buffer:
133+
return
134+
gathered: list[Optional[dict[int, dict[str, float]]]] = [None] * dist.size()
135+
torch.distributed.all_gather_object(gathered, local_buffer)
136+
if dist.is_master():
137+
assert stitched_losses_history is not None
138+
for rank_buf in gathered:
139+
if rank_buf is None:
140+
continue
141+
for it, losses in rank_buf.items():
142+
stitched_losses_history.setdefault(it, {}).update(losses)
143+
144+
121145
def train(
122146
cfg: DictConfig,
123147
descriptor: ModelDescriptor,
@@ -126,7 +150,7 @@ def train(
126150
teacher_stitched_model: StitchedModule,
127151
stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor],
128152
stitched_modules_process_ownership: StitchedModulesProcessOwnership,
129-
train_dataloader: DataLoader,
153+
train_dataloader: Optional[DataLoader],
130154
val_dataloader: Optional[DataLoader],
131155
student_model_config: PretrainedConfig,
132156
skip_first_batches: int = 0,
@@ -211,13 +235,18 @@ def train(
211235
f"Grad scaling status: {'enabled' if cfg.bypass.training.use_grad_scaling else 'disabled'}"
212236
)
213237

214-
train_iterator = iter(train_dataloader)
238+
# Only master consumes the dataloader — `next(train_iterator)` is gated by
239+
# `if dist.is_master()` further down. Building the iterator (or running
240+
# skip_first_batches against it) on non-master ranks wastes startup time
241+
# and memory proportional to the dataset, since each tokenizes the full
242+
# corpus only to throw it away.
243+
train_iterator = iter(train_dataloader) if dist.is_master() else None
215244

216245
# Advance past the first `skip_first_batches` batches before the training loop
217246
# starts. Used either to skip a known-bad batch range during debugging, or to
218247
# roll the data iterator forward when resuming a run (model + optimizer state
219248
# are restored from the checkpoint, but the dataloader itself starts fresh).
220-
if skip_first_batches > 0:
249+
if dist.is_master() and skip_first_batches > 0:
221250
mprint(f"Skipping first {skip_first_batches} batches before training")
222251
for _ in range(skip_first_batches):
223252
next(train_iterator)
@@ -233,8 +262,21 @@ def train(
233262
best_steps_by_name: dict[str, int] = dict(cfg.bypass.get("best_steps_by_name", {}))
234263
# Anchor for the "Δ from initial" column: per-block loss from the first log chunk.
235264
initial_losses_by_name: dict[str, float] = dict(cfg.bypass.get("initial_losses_by_name", {}))
236-
# Buffer variables
237-
input_ids = torch.zeros(1, 1, dtype=torch.int64)
265+
266+
# log_interval is in optimizer-step units; multiply by grad_accum to land in
267+
# micro-batch units, which is what the per-iter loss collection counts.
268+
iters_per_log_chunk = (
269+
cfg.bypass.training.log_interval * cfg.bypass.training.grad_accumulation_steps
270+
)
271+
# Per-rank local buffer of {iter_num: {block_name: loss}}. We accumulate
272+
# losses locally on every rank and only collate them via all_gather_object
273+
# at log-chunk boundaries — the object collective is pickle-based and
274+
# was previously the per-iter sync cost. See `_flush_loss_buffer` above.
275+
local_losses_buffer: dict[int, dict[str, float]] = {}
276+
# Buffer variables. Initialise on the active device so non-master ranks
277+
# never hand a CPU tensor to a downstream GPU op if the master-only-fetch
278+
# invariant is ever relaxed (today only master replaces this in the loop).
279+
input_ids = torch.zeros(1, 1, dtype=torch.int64, device=device)
238280

239281
aprint(
240282
f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}"
@@ -247,6 +289,11 @@ def train(
247289
# and incremented at the END of each iteration, so we must use `>` (not `>=`)
248290
# to ensure step `max_steps` itself runs before exiting.
249291
if cfg.bypass.step_num > cfg.bypass.training.max_steps:
292+
# Drain any residual buffered losses (< log-chunk boundary) so the
293+
# final partial chunk's stats reach master and can be logged before
294+
# the function returns. Must run on every rank — collective op.
295+
_flush_loss_buffer(local_losses_buffer, stitched_losses_history)
296+
local_losses_buffer.clear()
250297
if (
251298
cfg.bypass.model.model_overrides.save_checkpoint_when_done
252299
and not cfg.bypass.disable_checkpoint_save
@@ -386,25 +433,17 @@ def train(
386433
else:
387434
iter_stitched_module_losses = {}
388435

389-
# Collect losses from all ranks using all_gather_object
390-
local_training_stats = LocalTrainingStats(
391-
iter_num=cfg.bypass.iter_num,
392-
stitched_module_losses=iter_stitched_module_losses,
393-
)
394-
all_training_stats = [None] * dist.size()
395-
torch.distributed.all_gather_object(all_training_stats, local_training_stats)
396-
397-
if dist.is_master():
398-
if cfg.bypass.iter_num == resumed_iter_num:
399-
mprint(f"Starting from iter {cfg.bypass.iter_num}")
436+
if dist.is_master() and cfg.bypass.iter_num == resumed_iter_num:
437+
mprint(f"Starting from iter {cfg.bypass.iter_num}")
400438

401-
# Merge all stats into the losses history
402-
assert stitched_losses_history is not None
403-
merged_losses: dict[str, float] = {}
404-
for stats in all_training_stats:
405-
if stats is not None:
406-
merged_losses.update(stats.stitched_module_losses)
407-
stitched_losses_history[cfg.bypass.iter_num] = merged_losses
439+
# Buffer this rank's per-block losses locally. The collate-across-ranks
440+
# gather happens only at log-chunk boundaries (`_flush_loss_buffer`),
441+
# which cuts the per-iter pickle-based all_gather_object cost down to
442+
# one gather per `iters_per_log_chunk` micro-batches.
443+
local_losses_buffer[cfg.bypass.iter_num] = iter_stitched_module_losses
444+
if len(local_losses_buffer) >= iters_per_log_chunk:
445+
_flush_loss_buffer(local_losses_buffer, stitched_losses_history)
446+
local_losses_buffer.clear()
408447

409448
cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter
410449
iter_t1 = time.time()
@@ -441,11 +480,9 @@ def train(
441480
# Logging
442481
if dist.is_master():
443482
assert stitched_losses_history is not None
444-
# log_interval is in optimizer-step units; the underlying history is
445-
# per-iter (micro-batch), so the chunk window is grad_accum × wider.
446-
iters_per_log_chunk = (
447-
cfg.bypass.training.log_interval * cfg.bypass.training.grad_accumulation_steps
448-
)
483+
# `iters_per_log_chunk` is computed once before the loop (in
484+
# micro-batch units = log_interval × grad_accum) and reused for
485+
# both the gather-batching threshold and this log drain.
449486
while len(stitched_losses_history) >= iters_per_log_chunk:
450487
lowest_iter = next(iter(stitched_losses_history.keys()))
451488

@@ -830,23 +867,37 @@ def run_bypassed_training(cfg: DictConfig):
830867
load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn
831868
)
832869

833-
train_dataloader = create_train_dataloader(
834-
seed=seed,
835-
tokenizer=tokenizer,
836-
block_size=cfg.bypass.data.block_size,
837-
dataset_path=cfg.dataset_path,
838-
content_field=cfg.bypass.data.data_column,
839-
fim_rate=cfg.bypass.data.fim_rate,
840-
fim_spm_rate=cfg.bypass.data.fim_spm_rate,
841-
micro_batch_size=cfg.bypass.training.micro_batch_size,
842-
load_dataset_fn=load_dataset_fn,
843-
keep_in_memory=cfg.bypass.data.keep_in_memory,
844-
source_datasets_to_discard=cfg.bypass.data.get("source_datasets_to_discard", tuple()),
845-
bos_rate=cfg.bypass.data.bos_rate,
846-
shuffle_seed=cfg.bypass.data.shuffle_train_data_seed,
847-
)
870+
# Only master ever fetches from the train dataloader (training_loop.train
871+
# gates `next(train_iterator)` on `dist.is_master()`), so skip the
872+
# potentially-large HF dataset load + tokenisation on non-master ranks.
873+
if dist.is_master():
874+
train_dataloader = create_train_dataloader(
875+
seed=seed,
876+
tokenizer=tokenizer,
877+
block_size=cfg.bypass.data.block_size,
878+
dataset_path=cfg.dataset_path,
879+
content_field=cfg.bypass.data.data_column,
880+
fim_rate=cfg.bypass.data.fim_rate,
881+
fim_spm_rate=cfg.bypass.data.fim_spm_rate,
882+
micro_batch_size=cfg.bypass.training.micro_batch_size,
883+
load_dataset_fn=load_dataset_fn,
884+
keep_in_memory=cfg.bypass.data.keep_in_memory,
885+
source_datasets_to_discard=cfg.bypass.data.get(
886+
"source_datasets_to_discard", tuple()
887+
),
888+
bos_rate=cfg.bypass.data.bos_rate,
889+
shuffle_seed=cfg.bypass.data.shuffle_train_data_seed,
890+
)
891+
else:
892+
train_dataloader = None
848893

849894
val_dataloader = None
895+
# Note: val_dataloader is still constructed on every rank even though only
896+
# master reads from it inside calculate_losses_pipeline. The validation
897+
# block uses `val_dataloader is not None` as a "validation enabled" gate
898+
# that must agree across ranks — and calculate_losses_pipeline itself is
899+
# pipeline-parallel and requires every rank to enter it. Skipping
900+
# construction on non-master ranks would break those invariants.
850901
if not cfg.bypass.disable_validation:
851902
val_dataloader = create_validation_dataloader(
852903
accelerator=None,

modelopt/torch/puzzletron/pruning/pruning_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def _lm_attrs(config):
7272
7373
VL configs nest language-model fields like ``num_attention_heads``, ``head_dim``,
7474
and ``hidden_size`` under a sub-config. The attribute name varies by family —
75-
``text_config`` (Qwen3-VL, Llava, Idefics) and ``language_config`` (Llama-4 and
76-
a handful of others) are both common. Probe both before falling back to the
77-
raw config.
75+
``text_config`` (Qwen3-VL, Llava, Idefics), ``language_config`` (Llama-4 and a
76+
handful of others), and ``llm_config`` (InternVL and friends) are all common.
77+
Probe each before falling back to the raw config.
7878
"""
79-
for attr in ("text_config", "language_config"):
79+
for attr in ("text_config", "language_config", "llm_config"):
8080
sub = getattr(config, attr, None)
8181
if sub is not None:
8282
return sub

modelopt/torch/puzzletron/puzzletron_nas_plugin.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv
219219
activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir)
220220
if activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth")):
221221
mprint(
222-
f"Puzzletron Progress {score_step}/{N}: pruning activation scores already exist, skipping scoring"
222+
f"Puzzletron Progress {score_step}/{N}: pruning activation scores already "
223+
f"exist at {activations_log_dir} — delete this directory to re-score with "
224+
f"the current config."
223225
)
224226
dist.barrier()
225227
else:
@@ -231,7 +233,9 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv
231233
if dist.is_master():
232234
if pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir()):
233235
mprint(
234-
f"Puzzletron Progress {prune_step}/{N}: pruned checkpoints already exist, skipping pruning"
236+
f"Puzzletron Progress {prune_step}/{N}: pruned checkpoints already "
237+
f"exist at {pruned_ckpts_dir} — delete this directory to re-prune with "
238+
f"the current config."
235239
)
236240
else:
237241
mprint(

modelopt/torch/puzzletron/sewing_kit/utils.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import torch._dynamo
3636
import torch.distributed
3737
import torch.nn as nn
38-
import torch.nn.functional as F
3938
import torch.utils._pytree as pytree
4039
from torch import Tensor
4140
from torch._subclasses import FakeTensor, FakeTensorMode
@@ -483,29 +482,14 @@ def batched_normalized_mse_loss(
483482
epsilon: float = 1e-6,
484483
batch_dims: Sequence[int] = (0,),
485484
) -> torch.Tensor:
486-
"""Like normalized_mse_loss, but normalization is done on non-batch dims, then averaged.
487-
488-
Useful when activations within a batch item should be normalized independently
489-
rather than normalizing across the full batch.
490-
491-
Note: this slightly diverges from the original Puzzle implementation. With
492-
per-batch-element normalization, an all-zero target slice produces a
493-
denominator of ``epsilon ** 2 ~= 1e-12``, which then explodes the loss for
494-
that slice (the global-reduction variant in ``normalized_mse_loss`` dilutes
495-
it across non-zero elements, hiding the issue). We clamp the denominator
496-
to a floor of ``epsilon`` so the per-element minimum matches the intent of
497-
the epsilon term. The clamp only triggers on near-zero target slices —
498-
typical activations are unaffected.
499-
500-
The denominator uses ``MSE(target, epsilon_tensor)`` rather than
501-
``mean(target ** 2)`` for consistency with ``normalized_mse_loss``; the
502-
``clamp(min=epsilon)`` below already handles zero-target slices, so the
503-
epsilon offset inside the MSE is redundant but harmless at ``1e-6``.
485+
"""Per-batch-element relative-L2 loss.
486+
487+
For each batch element, computes ``||input - target||^2 / (||target||^2 + eps)``
488+
over the non-batch dims, then averages across batch elements. The additive
489+
``epsilon`` in the denominator handles all-zero target slices without a hard
490+
clamp, and the loss stays approximately scale-invariant when ``||target||^2 >> eps``.
504491
"""
505-
norm_dims = list(set(range(input.ndim)) - set(batch_dims))
506-
norm_of_target_vectors = F.mse_loss(
507-
target, torch.zeros_like(target) + epsilon, reduction="none"
508-
).mean(norm_dims)
509-
norm_of_target_vectors = norm_of_target_vectors.clamp(min=epsilon)
510-
loss = F.mse_loss(input, target, reduction="none").mean(norm_dims) / norm_of_target_vectors
511-
return loss.mean()
492+
norm_dims = [d for d in range(input.ndim) if d not in batch_dims]
493+
num = ((input - target) ** 2).sum(dim=norm_dims)
494+
den = (target**2).sum(dim=norm_dims) + epsilon
495+
return (num / den).mean()

modelopt/torch/puzzletron/tools/bypassed_training/child_init.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,14 @@ def _process_single_layer(
8686
# Delegate to pruning_mixin if available (supports a single mixin or a list of mixins).
8787
# When the bypass factory composes multiple mixins (e.g. experts_removal + kv_heads),
8888
# it passes them as a list so each can contribute its slice of the layer state dict.
89+
# Each mixin gets its own copy of keys_to_remove and the copies are merged afterward,
90+
# so ordering between mixins can't corrupt the state dict even if a future pair of
91+
# mixins ever happens to touch overlapping keys.
8992
if pruning_mixin is not None:
9093
_mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin]
94+
merged_keys_to_remove = dict(keys_to_remove)
9195
for _mixin in _mixins:
96+
mixin_keys = dict(keys_to_remove)
9297
_layer_out = _mixin.prune_single_layer(
9398
layer_idx=layer_idx,
9499
parent_state_dict=parent_state_dict,
@@ -104,10 +109,11 @@ def _process_single_layer(
104109
is_original_mha=is_original_mha,
105110
head_size=head_size,
106111
hidden_size=hidden_size,
107-
keys_to_remove=keys_to_remove,
112+
keys_to_remove=mixin_keys,
108113
)
109114
layer_out_state_dict.update(_layer_out)
110-
return layer_out_state_dict, keys_to_remove
115+
merged_keys_to_remove.update(mixin_keys)
116+
return layer_out_state_dict, merged_keys_to_remove
111117

112118
# Legacy inline processing (fallback when no pruning_mixin)
113119

modelopt/torch/puzzletron/tools/hydra_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def _warmup_steps_resolver(*args):
5454
if len(args) == 4:
5555
t, b, m, p = args
5656
return warmup_steps(t, b, m, pct=p)
57+
# A 5-arg call where the 4th arg is a fractional float almost certainly
58+
# means `pct` landed in the `grad_accum` slot — `int(0.05) == 0` would
59+
# later raise ZeroDivisionError inside `warmup_steps`. Treat it as legacy.
60+
if len(args) == 5 and isinstance(args[3], float) and args[3] < 1.0:
61+
t, b, m, p, _ = args
62+
return warmup_steps(t, b, m, pct=p)
5763
return warmup_steps(*args)
5864

5965

0 commit comments

Comments
 (0)