Skip to content

Commit 7ad01c0

Browse files
Enable active-param and memory based Minitron pruning constraint + rich logging
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 50706d1 commit 7ad01c0

12 files changed

Lines changed: 1757 additions & 226 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Changelog
1919

2020
- Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.
2121
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
22+
- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model.
2223

2324
0.44 (2026-05-xx)
2425
^^^^^^^^^^^^^^^^^

examples/megatron_bridge/README.md

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,18 @@ hf auth login --token <your token>
5353

5454
This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md).
5555

56-
Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults:
56+
The script supports three NAS-based pruning targets and one manual export mode:
57+
58+
| Mode | Flag | Description |
59+
| :---: | :---: | :--- |
60+
| NAS | `--prune_target_params` | Prune to a target total parameter count |
61+
| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. |
62+
| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision |
63+
| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. |
64+
65+
Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision.
66+
67+
**Prune by total parameter count** — prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults:
5768
1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
5869
at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
5970
top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
@@ -67,8 +78,29 @@ torchrun --nproc_per_node 2 prune_minitron.py \
6778
--output_hf_path /tmp/Qwen3-8B-Pruned-6B
6879
```
6980

70-
Example usage for manually pruning to a specific architecture using following defaults:
71-
1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration.
81+
**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params):
82+
83+
```bash
84+
torchrun --nproc_per_node 2 prune_minitron.py \
85+
--pp_size 2 \
86+
--hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
87+
--prune_target_active_params 3e9 \
88+
--output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active
89+
```
90+
91+
**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16):
92+
93+
```bash
94+
torchrun --nproc_per_node 2 prune_minitron.py \
95+
--pp_size 2 \
96+
--hf_model_name_or_path Qwen/Qwen3-8B \
97+
--prune_target_memory_mb 12288 \
98+
--seq_length 4096 \
99+
--calib_mbs 1 \
100+
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
101+
```
102+
103+
**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation):
72104

73105
```bash
74106
torchrun --nproc_per_node 2 prune_minitron.py \

examples/megatron_bridge/prune_minitron.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
# limitations under the License.
1515
"""Example script for pruning a GPT / Mamba model using Minitron algorithm on a Megatron-Bridge model (load from HF).
1616
17+
Supports three NAS-based pruning targets (can be combined):
18+
--prune_target_params Total parameter count (e.g. 6e9 for 6B total params)
19+
--prune_target_active_params Active parameter count for MoE models (e.g. 3e9 for 3B active params)
20+
--prune_target_memory_mb Memory footprint in MB (uses --seq_length for KV-cache estimate, assumes BF16)
21+
1722
Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2)
1823
while skipping pruning of num_attention_heads using following defaults:
1924
1024 samples from nemotron-post-training-dataset-v2 for calibration,
@@ -47,7 +52,7 @@
4752
import modelopt.torch.opt as mto
4853
import modelopt.torch.prune as mtp
4954
import modelopt.torch.utils.distributed as dist
50-
from modelopt.torch.utils import get_supported_datasets, num2hrb, print_rank_0, warn_rank_0
55+
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
5156
from modelopt.torch.utils.plugins.mbridge import (
5257
get_hf_mbridge_calibration_loop,
5358
load_mbridge_model_from_hf,
@@ -105,7 +110,6 @@ def get_args() -> argparse.Namespace:
105110
)
106111
parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
107112
parser.add_argument("--seq_length", type=int, default=4096)
108-
109113
# Pruning parameters
110114
parser.add_argument(
111115
"--prune_intermediate_ckpt",
@@ -117,23 +121,40 @@ def get_args() -> argparse.Namespace:
117121
),
118122
)
119123

120-
target_group = parser.add_mutually_exclusive_group(required=True)
121-
target_group.add_argument(
124+
parser.add_argument(
122125
"--prune_export_config",
123126
type=str,
124127
help=(
125128
'Target pruned config as JSON e.g., \'{"hidden_size": 512, "ffn_hidden_size": 2048}\'. '
126129
f"Supported hyperparameters: {mtp.mcore_minitron.SUPPORTED_HPARAMS}. "
127-
"Cannot be used with --prune_target_params."
130+
"Cannot be combined with NAS-based targets."
128131
),
129132
)
130-
target_group.add_argument(
133+
parser.add_argument(
131134
"--prune_target_params",
132135
type=float,
133136
help=(
134-
"Target parameter count for pruning e.g., 6e9 for pruning to 6B params (total params, not active params). "
135-
"Uses Neural Architecture Search (NAS) to find the best pruned model that maximizes the --prune_score_func."
136-
"Cannot be used with --prune_export_config."
137+
"Target total parameter count e.g., 6e9 for 6B params. "
138+
"Uses NAS to find the best pruned model that maximizes --prune_score_func. "
139+
"Can be combined with --prune_target_active_params and/or --prune_target_memory_mb."
140+
),
141+
)
142+
parser.add_argument(
143+
"--prune_target_active_params",
144+
type=float,
145+
help=(
146+
"Target active parameter count e.g., 3e9 for 3B active params (useful for MoE models). "
147+
"Uses NAS to find the best pruned model that maximizes --prune_score_func. "
148+
"Can be combined with --prune_target_params and/or --prune_target_memory_mb."
149+
),
150+
)
151+
parser.add_argument(
152+
"--prune_target_memory_mb",
153+
type=float,
154+
help=(
155+
"Target memory footprint in MB (weights + KV-cache estimated via seq_length and calib_mbs; assumes BF16). "
156+
"Uses NAS to find the best pruned model that maximizes --prune_score_func. "
157+
"Can be combined with --prune_target_params and/or --prune_target_active_params."
137158
),
138159
)
139160

@@ -142,7 +163,7 @@ def get_args() -> argparse.Namespace:
142163
type=str,
143164
default="mmlu_10pct",
144165
help=(
145-
"Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
166+
"Score function to use for NAS-based pruning. Only supports MMLU at the moment. "
146167
"Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
147168
"(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
148169
),
@@ -152,7 +173,7 @@ def get_args() -> argparse.Namespace:
152173
type=int,
153174
default=None,
154175
help=(
155-
"hidden_size / ffn_hidden_size divisor for NAS-based pruning (--prune_target_params). "
176+
"hidden_size / ffn_hidden_size divisor for NAS-based pruning. "
156177
"Leave as None to use default divisors."
157178
),
158179
)
@@ -162,14 +183,14 @@ def get_args() -> argparse.Namespace:
162183
default=0.4,
163184
help=(
164185
f"Maximum width pruning percentage ({mtp.mcore_minitron.SUPPORTED_HPARAMS - {'num_layers'}}) "
165-
"for NAS-based pruning (--prune_target_params)"
186+
"for NAS-based pruning"
166187
),
167188
)
168189
parser.add_argument(
169190
"--max_depth_pruning",
170191
type=float,
171192
default=0.2,
172-
help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning (--prune_target_params)",
193+
help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning",
173194
)
174195
parser.add_argument(
175196
"--hparams_to_skip",
@@ -178,7 +199,7 @@ def get_args() -> argparse.Namespace:
178199
default=[],
179200
choices=mtp.mcore_minitron.SUPPORTED_HPARAMS,
180201
help=(
181-
"Space-separated list of hparams to skip for NAS-based pruning (--prune_target_params) "
202+
"Space-separated list of hparams to skip for NAS-based pruning "
182203
"e.g. dont prune 'num_attention_heads'"
183204
),
184205
)
@@ -187,13 +208,27 @@ def get_args() -> argparse.Namespace:
187208
type=int,
188209
default=10,
189210
help=(
190-
"Number of top candidates to consider for NAS-based pruning (--prune_target_params). "
211+
"Number of top candidates to consider for NAS-based pruning. "
191212
"Higher values will take longer to prune but may find a better model."
192213
),
193214
)
194215

195216
args = parser.parse_args()
196217

218+
# Validate pruning target arguments
219+
_nas_targets = [
220+
args.prune_target_params,
221+
args.prune_target_active_params,
222+
args.prune_target_memory_mb,
223+
]
224+
if args.prune_export_config and any(t is not None for t in _nas_targets):
225+
parser.error("--prune_export_config cannot be combined with NAS-based targets.")
226+
if not args.prune_export_config and not any(t is not None for t in _nas_targets):
227+
parser.error(
228+
"At least one of --prune_export_config, --prune_target_params,"
229+
" --prune_target_active_params, or --prune_target_memory_mb is required."
230+
)
231+
197232
# Post-process arguments
198233
if args.prune_intermediate_ckpt is None:
199234
if args.output_megatron_path:
@@ -250,11 +285,6 @@ def main(args: argparse.Namespace):
250285
init_model_parallel=True,
251286
moe_grouped_gemm=False,
252287
)
253-
print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}")
254-
print_rank_0(
255-
f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}"
256-
)
257-
258288
forward_loop = get_hf_mbridge_calibration_loop(
259289
model=model,
260290
provider=provider,
@@ -271,10 +301,20 @@ def main(args: argparse.Namespace):
271301
"forward_loop": forward_loop,
272302
"checkpoint": args.prune_intermediate_ckpt,
273303
}
274-
if args.prune_target_params is not None:
275-
# Restrict search space to a smaller set of candidates
276-
# Allow more choices for MoE FFN as they are generally smaller
277-
# NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model.
304+
if args.prune_export_config is not None:
305+
# Less restrictive search space for manual pruning
306+
ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
307+
hidden_size_divisor=64,
308+
ffn_hidden_size_divisor=64,
309+
mamba_head_dim_divisor=8,
310+
num_moe_experts_divisor=8,
311+
num_layers_divisor=1,
312+
)
313+
pruning_constraints = {"export_config": args.prune_export_config}
314+
else:
315+
# NAS-based pruning: restrict search space to a smaller set of candidates.
316+
# Allow more choices for MoE FFN as they are generally smaller.
317+
# NOTE: Reduce divisors and increase config['top_k'] to potentially find a better model.
278318
hidden_size_divisor = args.ss_channel_divisor if args.ss_channel_divisor else 256
279319
ffn_hidden_size_divisor = (
280320
args.ss_channel_divisor
@@ -290,7 +330,14 @@ def main(args: argparse.Namespace):
290330
)
291331
print_rank_0(f"Using search space config: {ss_config}")
292332

293-
pruning_constraints = {"params": args.prune_target_params}
333+
pruning_constraints = {}
334+
if args.prune_target_params is not None:
335+
pruning_constraints["params"] = args.prune_target_params
336+
if args.prune_target_active_params is not None:
337+
pruning_constraints["active_params"] = args.prune_target_active_params
338+
if args.prune_target_memory_mb is not None:
339+
pruning_constraints["memory_mb"] = args.prune_target_memory_mb
340+
294341
print_rank_0(
295342
f"Using NAS-based automatic pruning with score function: {args.prune_score_func}. "
296343
"You can change this to be any other metric you want to maximize (e.g. negative validation loss)."
@@ -313,17 +360,9 @@ def score_func(m):
313360
pruning_config["max_depth_pruning"] = args.max_depth_pruning
314361
pruning_config["hparams_to_skip"] = args.hparams_to_skip
315362
pruning_config["top_k"] = args.top_k
316-
elif args.prune_export_config is not None:
317-
# Less restrictive search space for manual pruning
318-
ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
319-
hidden_size_divisor=64,
320-
ffn_hidden_size_divisor=64,
321-
mamba_head_dim_divisor=8,
322-
num_moe_experts_divisor=8,
323-
num_layers_divisor=1,
324-
)
325-
326-
pruning_constraints = {"export_config": args.prune_export_config}
363+
# memory_mb constraint requires batch_size and seq_length
364+
pruning_config["batch_size"] = args.calib_mbs
365+
pruning_config["seq_length"] = args.seq_length
327366
print_rank_0(f"Pruning constraints: {pruning_constraints}")
328367

329368
unwrapped_model, pruning_scores = mtp.prune( # in-place pruning
@@ -343,10 +382,6 @@ def score_func(m):
343382
else "hybrid_layer_pattern"
344383
)
345384
setattr(provider, hybrid_key, getattr(unwrapped_model, hybrid_key))
346-
print_rank_0(f"\nPruned model (showing PP rank0): {unwrapped_model}")
347-
print_rank_0(
348-
f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}"
349-
)
350385

351386
if args.output_megatron_path is not None:
352387
print_rank_0(

examples/pruning/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ If your model parameters are already sorted and you just want to prune the weigh
179179

180180
| **Algorithm** | **Model** | **Pruning Constraints** |
181181
| :---: | :---: | :---: |
182-
| Minitron | Megatron-core (M-LM, M-Bridge) based GPT / Mamba / MoE / Hybrid LLM Models<sup>1</sup> | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values<br>**Auto:** `params` (requires `score_func` in config) |
182+
| Minitron | Megatron-core (M-LM, M-Bridge) based GPT / Mamba / MoE / Hybrid LLM Models<sup>1</sup> | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values<br>**Auto:** one or more of `params`, `active_params`, `memory_mb` (requires `score_func` in config) |
183183
| FastNAS | Computer Vision models | `flops`, `params` |
184184
| GradNAS | HuggingFace BERT, GPT-J | `flops`, `params` |
185185

modelopt/torch/nas/plugins/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121

2222
with import_plugin("megatron"):
2323
from .megatron import *
24-
25-
with import_plugin("transformer engine"):
26-
from .transformer_engine import *
24+
from .megatron_model_stats import *
2725

2826
with import_plugin("transformers"):
2927
from .transformers import *

0 commit comments

Comments
 (0)