Commit 00c7d13

Split bypass Puzzletron integration
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
1 parent 470fe16 commit 00c7d13

21 files changed

Lines changed: 2769 additions & 64 deletions
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Bypass Distillation Tutorial: Nemotron-3-Nano-30B-A3B (KV-heads-only)
2+
3+
A minimal end-to-end demonstration that **bypass distillation improves quality** at the same compression budget. The setup is a **toy pruning task on a real production model** — we compress only KV heads (12 → 9, a modest 25% reduction) so a single comparison surfaces the bypass benefit cleanly without needing extensive downstream evaluation. The model itself ([Nemotron-3-Nano-30B-A3B-Base-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16)) is a real 30B-A3B MoE-Mamba hybrid, not a tiny stand-in.
4+
5+
## What this tutorial does
6+
7+
The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare:
8+
9+
1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training).
10+
2. **With bypass** — the bypass step runs ~50M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher.
11+
12+
Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head}`. FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change.
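
To make the selection problem concrete, here is a minimal brute-force sketch of the choice the solver faces. It is not the Puzzletron MIP formulation. The per-layer costs are invented for illustration, and treating `target_num_kv_heads` as an upper bound is an assumption.

```python
from itertools import product

NUM_ATTENTION_LAYERS = 6      # from the tutorial text
HEAD_OPTIONS = (2, 1)         # teacher 2-head block or pruned 1-head block
TARGET_NUM_KV_HEADS = 9       # same value as `target_num_kv_heads`

# Hypothetical quality cost of using the 1-head variant in each layer
# (lower is better). Real values come from the replacement-scoring step.
one_head_cost = [0.30, 0.05, 0.20, 0.02, 0.40, 0.10]


def best_assignment(costs, target):
    """Try every per-layer choice and keep the cheapest one within budget."""
    best = None
    for choice in product(HEAD_OPTIONS, repeat=len(costs)):
        if sum(choice) > target:  # assumed: the target is an upper bound
            continue
        cost = sum(c for c, heads in zip(costs, choice) if heads == 1)
        if best is None or cost < best[0]:
            best = (cost, choice)
    return best


cost, assignment = best_assignment(one_head_cost, TARGET_NUM_KV_HEADS)
print(assignment, sum(assignment))  # (2, 1, 2, 1, 2, 1) 9
```

With 6 layers and 2 options there are only 64 candidates, and reaching 9 heads means exactly three layers drop to one head. A real MIP solver matters once FFN and MoE variants enlarge the search space.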
13+
14+
**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`).
15+
16+
## Hardware & install
17+
18+
- 8×H100 80GB (the teacher needs ≥60 GiB for activation scoring on a 4096 context).
19+
- Container: `nvcr.io/nvidia/nemo:26.04` or later.
20+
- `pip install -e ".[dev]"` from the modelopt repo root.
21+
- Mamba kernels (required by Nemotron-3-Nano's hybrid backbone):
22+
23+
```bash
24+
pip install "mamba-ssm[causal-conv1d]" --no-build-isolation
25+
```
26+
27+
- HF auth set up so the model is downloadable: `huggingface-cli login`.
28+
29+
## Step A — pipeline without bypass
30+
31+
Edit `examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml` to point `puzzle_dir` and `dataset_path` at writable locations, then:
32+
33+
```bash
34+
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
35+
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
36+
```
37+
38+
This runs the 8-step puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). With `bypass:` added in Step B the pipeline grows to 9 steps; without it, the bypass step is skipped and progress prints `N/8`. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring).
39+
40+
When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain:
41+
42+
```text
43+
validate_model_with_kl_div(model_name='teacher', ...)
44+
Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...}
45+
...
46+
validate_model_with_kl_div(model_name='solution_0', ...)
47+
Average losses = {..., 'token_accuracy_top_1': ..., ...}
48+
```
49+
50+
Record the teacher's `token_accuracy_top_1` and `solution_0`'s `token_accuracy_top_1`. **Move or rename `${puzzle_dir}/single_sequence_replacement_solutions--validation/` and `${puzzle_dir}/mip/` aside** before Step B if you want to keep the no-bypass artifacts — Step B reuses the same `puzzle_dir` and the library/scoring/MIP outputs will be overwritten.
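
One way to set the no-bypass artifacts aside is a small helper like the sketch below. The function name and the `_no_bypass` suffix are illustrative choices, not part of the repo. Only the two directory names come from the text above.

```python
import shutil
from pathlib import Path

# Directory names taken from the tutorial text.
STEP_A_OUTPUTS = ("single_sequence_replacement_solutions--validation", "mip")


def stash_no_bypass_outputs(puzzle_dir, suffix="_no_bypass"):
    """Rename Step A outputs so Step B does not overwrite them."""
    moved = []
    for name in STEP_A_OUTPUTS:
        src = Path(puzzle_dir) / name
        if not src.exists():
            continue  # nothing to preserve for this output
        dst = src.with_name(src.name + suffix)
        if dst.exists():
            raise FileExistsError(f"{dst} already exists; pick another suffix")
        shutil.move(str(src), str(dst))
        moved.append(dst)
    return moved
```

Calling `stash_no_bypass_outputs("/path/to/puzzle_dir")` returns the renamed paths. It raises rather than merging into an existing stash.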
51+
52+
## Step B — pipeline with bypass
53+
54+
Use the bypass-enabled config, which overrides the base config's empty `- bypass:` entry with `bypass: defaults`:
55+
56+
```yaml
57+
defaults:
58+
- nemotron-3-nano-30b-a3b
59+
- override bypass: defaults
60+
- _self_
61+
```
62+
63+
Run the bypass config:
64+
65+
```bash
66+
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
67+
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml
68+
```
69+
70+
Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~50M tokens) and the downstream library/scoring/MIP rerun.
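
The general idea behind skip-if-done caching can be sketched as follows. This is a generic illustration with a hypothetical marker file and function name. The actual Puzzletron mechanism may detect completed steps differently.

```python
from pathlib import Path


def run_step_once(step_name, output_dir, step_fn):
    """Run step_fn unless a completion marker from an earlier run exists."""
    out = Path(output_dir)
    marker = out / f".{step_name}.done"  # hypothetical marker file
    if marker.exists():
        return False  # cached from a previous run: skip
    out.mkdir(parents=True, exist_ok=True)
    step_fn()
    marker.touch()  # written only after the step succeeds
    return True
```

Under this pattern, a second launch against the same `puzzle_dir` repeats only the steps whose outputs are missing, which matches the behavior described above.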
71+
72+
Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs/<bypass_experiment_id>/` and creates a symlink `${puzzle_dir}/ckpts/<bypass_experiment_id>` that the replacement library builder picks up automatically.
73+
74+
Capture `solution_0`'s `token_accuracy_top_1` from the new realize-model log section.
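
If you prefer not to read the log by eye, a small parser along these lines can extract the metrics. It assumes the two-line layout shown in the Step A log excerpt, a `validate_model_with_kl_div(model_name=...)` line followed by an `Average losses = {...}` line. Real log lines may carry extra prefixes or fields, and the sample text below is made up in that assumed format using values from the results table.

```python
import ast
import re

MODEL_RE = re.compile(r"validate_model_with_kl_div\(model_name='([^']+)'")
LOSSES_RE = re.compile(r"Average losses = (\{.*\})")


def parse_average_losses(log_text):
    """Map each model name to the metrics dict logged after it."""
    results = {}
    current = None
    for line in log_text.splitlines():
        match = MODEL_RE.search(line)
        if match:
            current = match.group(1)
            continue
        match = LOSSES_RE.search(line)
        if match and current is not None:
            # A later entry for the same model replaces an earlier one.
            results[current] = ast.literal_eval(match.group(1))
            current = None
    return results


sample_log = """\
validate_model_with_kl_div(model_name='teacher', ...)
Average losses = {'lm_loss': 0.5950, 'token_accuracy_top_1': 0.8468}
validate_model_with_kl_div(model_name='solution_0', ...)
Average losses = {'lm_loss': 0.6055, 'token_accuracy_top_1': 0.8441}
"""
metrics = parse_average_losses(sample_log)
print(metrics["solution_0"]["token_accuracy_top_1"])  # 0.8441
```

Because later entries replace earlier ones, parsing a log that contains both runs yields the most recent `solution_0` section.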
75+
76+
## Results
77+
78+
Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on Nemotron-3-Nano-30B-A3B-Base-BF16:
79+
80+
| Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` |
81+
|------------------------------|----------------------:|----------:|-----------------------:|
82+
| Teacher | 12 | 0.5950 | 0.8468 |
83+
| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 |
84+
| Pruned, **with bypass** (50M-token BLD) | 9 | **0.6055**| **0.8441** |
85+
86+
**Bypass closes roughly three quarters of the regression gap** at this compression budget (74% on `lm_loss`, 72% on `token_accuracy_top_1`):
87+
88+
- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**.
89+
- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**.
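
The recovery percentages follow directly from the table. Here is a short check, where the function name is an illustrative choice:

```python
def gap_recovered(teacher, no_bypass, with_bypass):
    """Fraction of the teacher-to-pruned gap that bypass closes."""
    gap_before = abs(no_bypass - teacher)
    gap_after = abs(with_bypass - teacher)
    return (gap_before - gap_after) / gap_before


# Values copied from the results table.
lm_loss_recovery = gap_recovered(0.5950, 0.6347, 0.6055)
top1_recovery = gap_recovered(0.8468, 0.8373, 0.8441)
print(f"lm_loss: {lm_loss_recovery:.0%}, top-1: {top1_recovery:.0%}")
# lm_loss: 74%, top-1: 72%
```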
90+
91+
For 50M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher.
92+
93+
## Going further: full accuracy recovery
94+
95+
Bypass distillation is Stage 1 of the PUZZLE pipeline — local, per-block KD that tightens the replacement library. For larger compression targets (or more aggressive KV pruning) you'll want Stage 2: **global knowledge distillation** on the realized student. See [`examples/pruning/puzzletron/`](../pruning/puzzletron/) for the Megatron-Bridge recipe and concrete MMLU recovery numbers.

examples/puzzletron/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of
1111

1212
In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md).
1313

14-
> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
14+
> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100, Nemotron-3-Nano-30B-A3B-Base-BF16 on 8x H100 — see the [bypass distillation tutorial](Nemotron-3-Nano-30B-A3B-Base-BF16.md)). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
1515
1616
## Environment
1717

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# @package bypass
2+
# Bypass Distillation Configuration
3+
# This config defines parameters for blockwise local distillation (BLD),
4+
# which trains alternative transformer block configurations using per-block
5+
# knowledge distillation from a teacher model.
6+
7+
# Runtime Configuration
8+
dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability
9+
seed: 42 # Random seed for reproducibility
10+
11+
# Experiment Tracking
12+
experiment_id: # Unique identifier for this experiment. Will be dynamically set
13+
experiment_dir: # Directory for this experiment. Will be dynamically set
14+
iter_num: 1 # Current iteration number
15+
step_num: 1 # Current step number within iteration
16+
token_count: 0 # Token count tracker (auto-updated during training)
17+
18+
# Data Configuration
19+
data:
20+
data_column: "messages"
21+
block_size: 512 # Sequence length (tokens per sample)
22+
bos_rate: 0.5
23+
fim_rate: 0
24+
fim_spm_rate: 0
25+
source_datasets_to_discard: []
26+
load_from_disk: true # Load preprocessed data from disk or from stream
27+
keep_in_memory: false
28+
val_dataset_name: valid
29+
max_eval_samples: 4
30+
eval_samples_per_process: # Samples per GPU during distributed eval (auto if null)
31+
shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data
32+
33+
# Training Configuration
34+
training:
35+
learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001)
36+
training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check)
37+
micro_batch_size: 2
38+
val_micro_batch_size: 1
39+
warmup_ratio: 0.05
40+
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}} # Auto-calculated warmup steps
41+
min_lr_factor: 1e-5
42+
grad_accumulation_steps: 1
43+
skip_first_batches: 0 # Use for debugging, or to skip a few batches that cause crashes or optimization issues.
44+
weight_decay: 0.1
45+
decay_lr: true
46+
beta1: 0.9
47+
beta2: 0.95
48+
use_grad_scaling: false
49+
grad_clip: 1.0
50+
grad_clip_type: norm
51+
clipping_count: 0
52+
log_interval: 5
53+
eval_interval: 5
54+
55+
# Model Loading Configuration
56+
resume_checkpoint_path: # Path to resume training from checkpoint
57+
find_last_ckpt_for_resume: true # Auto-resume by finding last checkpoint (bool)
58+
parameter_count:
59+
init_checkpoint_path: # Path to initialize weights from
60+
61+
model:
62+
student_weights_dtype: "bf16" # Student model weight precision
63+
64+
model_overrides:
65+
delete_old_checkpoints: true # Clean up old checkpoints to save disk space
66+
save_interval_seconds: 12900 # Save checkpoint every ~3.6 hours (12900 s)
67+
save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled)
68+
save_checkpoint_when_done: true # Save final checkpoint when training completes
69+
70+
# Architecture modifications for student model
71+
model_config_overrides:
72+
ffn:
73+
- intermediate_size:
74+
no_op: # Disable FFN entirely (true/false)
75+
attention:
76+
- num_key_value_heads: # Number of kv-heads (for GQA)
77+
no_op: # Disable attention entirely (true/false)
78+
79+
# Model Factory Configuration - Controls student model creation and initialization
80+
model_factory:
81+
factory: bypass_factory_fn # Unified factory supporting all layer types
82+
block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss
83+
gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode
84+
mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode
85+
mlp_init_config: # Configuration for MLP initialization (if needed)
86+
activations_log_dir: # Directory with activation statistics (required for PruneByActivationsLog)
87+
linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc.
88+
submodule_for_loss_calculation: # Specific submodule for loss calc.
89+
keys_to_learn: # Subblock(s) to train: entire_block, subblock_attention, subblock_ffn, subblock_mamba, or a list of those.
90+
91+
# Validation Configuration
92+
disable_initial_validate: false
93+
validate_teacher_model: true
94+
validate_student_model: true
95+
disable_validation: false # Enable validation to exercise all code paths
96+
best_val_loss: 1e+9 # Track best validation loss achieved
97+
98+
# Performance Optimization
99+
compile: false # Use PyTorch compilation
100+
disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available)
101+
teacher_model_load_on_cpu: false
102+
103+
# Checkpoint Management
104+
save_checkpoint_before_training: false # Save initial checkpoint before training
105+
disable_checkpoint_save: false # Disable all checkpoint saving
106+
save_best_ckpt: true # Save checkpoint when validation improves
107+
kill_after_first_save: false # Exit after first checkpoint save (for testing)
108+
realize_best_or_latest: "best"
109+
110+
wandb_log: false
111+
wandb:
112+
project:
113+
entity:
114+
115+
# Multiple bypass configurations to train sequentially.
116+
# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn.
117+
# If empty or absent, a single run uses the settings above.
118+
configs:
119+
- model_config_overrides:
120+
ffn:
121+
- intermediate_size: 3072
122+
attention:
123+
- num_key_value_heads: 8
124+
keys_to_learn: subblock_ffn
125+
- model_config_overrides:
126+
ffn:
127+
- intermediate_size: 5888
128+
attention:
129+
- num_key_value_heads: 8
130+
keys_to_learn: subblock_ffn
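
To see what the training defaults above imply, the sketch below estimates the step count from `training_tokens`, `block_size`, `micro_batch_size`, and `grad_accumulation_steps`. The formula, the `world_size` factor, and the rounding are assumptions for illustration. The repo's `warmup_steps` resolver may compute these values differently.

```python
import math


def optimizer_steps(training_tokens, block_size, micro_batch_size,
                    grad_accumulation_steps, world_size=1):
    """Return (tokens per optimizer step, number of steps), rounding up."""
    tokens_per_step = (block_size * micro_batch_size
                       * grad_accumulation_steps * world_size)
    return tokens_per_step, math.ceil(training_tokens / tokens_per_step)


# Defaults from the config above, on a single process.
tokens_per_step, steps = optimizer_steps(1e4, 512, 2, 1)
warmup = math.ceil(steps * 0.05)  # warmup_ratio: 0.05; rounding is assumed
print(tokens_per_step, steps, warmup)  # 1024 10 1
```

Under these assumptions the default 10K-token budget is about ten optimizer steps, consistent with the config's "sanity check" comment. The ~50M-token run in the tutorial would need `training_tokens` raised accordingly.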
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
defaults:
2+
- pruning: kv_heads_pruning
3+
- scoring: ../validate_solutions_defaults
4+
- realize_model: ../validate_solutions_defaults
5+
- bypass:
6+
- override hydra/hydra_logging: disabled
7+
- _self_
8+
9+
puzzle_dir: ???
10+
descriptor: nemotron_h
11+
teacher_dir: ${puzzle_dir}/ckpts/teacher/
12+
replacement_library_path: ${puzzle_dir}/replacement_library.json
13+
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2
14+
15+
skip_realize_model: false
16+
17+
# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library
18+
# exposes {teacher 2-head, 1-head} per attention layer; FFN and Mamba
19+
# blocks are copied verbatim from the teacher.
20+
build_replacement_library:
21+
add_ffn_no_ops: false
22+
add_attention_no_ops: false
23+
24+
calc_subblock_stats:
25+
batch_sizes: [64, 96, 128]
26+
prefill_seq_len: 4096
27+
generation_seq_len: 4096
28+
num_active_tokens_override: # Optional override for sequence lengths
29+
prefill_queue_size: 0
30+
allocate_prefill_query: false
31+
benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
32+
merge_with_existing_stats: false
33+
subblock_stats_filename: "subblock_stats.json"
34+
moe_stats_filename: "moe_stats.json"
35+
runtime_stats:
36+
backend: trt_torch
37+
38+
scoring:
39+
descriptor: ${descriptor}
40+
solutions_to_validate:
41+
skip_existing_solutions: true
42+
43+
replacement_library_path: ${replacement_library_path}
44+
solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
45+
teacher_dir: ${to_path:${teacher_dir}}
46+
output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
47+
48+
eval_samples: 128
49+
micro_batch_size: 1
50+
seed: 42
51+
shuffle_seed: 444
52+
dataset_path: ${dataset_path}
53+
54+
mip:
55+
single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
56+
subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
57+
output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
58+
gathered_metrics_path:
59+
puzzle_profile:
60+
61+
objective: metrics.cosine_embedding_loss_hidden_states
62+
bigger_is_better: false
63+
64+
subblock_stats_args:
65+
- batch_size: 96
66+
weights_dtype: torch.bfloat16
67+
activations_dtype: torch.bfloat16
68+
kv_cache_dtype: torch.bfloat16
69+
70+
report_additional_costs:
71+
- stats.memory_mib
72+
- stats.num_params
73+
- stats.num_kv_heads
74+
- stats.has_attention
75+
- stats.has_ffn
76+
- stats.kv_cache_memory_mib
77+
- stats.attention_memory_mib
78+
- stats.ffn_memory_mib
79+
- stats.ffn_num_params
80+
- stats.attention_num_params
81+
82+
human_constraints:
83+
target_num_kv_heads: 9 # toy KV-heads-only target; see nemotron-3-nano-30b-a3b.yaml
84+
85+
mip_constraints:
86+
metric_overrides:
87+
max_seconds_per_solution: 60
88+
89+
realize_model:
90+
descriptor: ${descriptor}
91+
teacher_dir: ${to_path:${teacher_dir}}
92+
tokenizer_name: ${to_path:${teacher_dir}}
93+
replacement_library_path: ${replacement_library_path}
94+
save_models: true
95+
solutions_path: # Filled dynamically
96+
97+
# Validate params
98+
skip_validation: false
99+
eval_samples: 128
100+
micro_batch_size: 1
101+
seed: 42
102+
shuffle_seed: 444
103+
dataset_path: ${dataset_path}
104+
105+
nccl_timeout_minutes: ${timedelta_minutes:10}
106+
107+
# This section redirects Hydra outputs
108+
hydra:
109+
run:
110+
dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
