Commit f0eaa19

Enable active-param and memory-based Minitron pruning constraints (#1377)
### What does this PR do?

**Type of change:** New feature, new tests, documentation.

OMNIML-4108: Extends the Minitron NAS pruner to support pruning by **active parameter count** (`active_params`) and **memory footprint** (`memory_mb`) in addition to the existing total parameter count (`params`) constraint. Also adds standalone utilities for analytical model stats.

#### Changes

**New pruning constraint keys**

- `active_params`: prune to a target number of active (routed) params — useful for MoE models where total ≫ active; when present, `active_params` is the **primary sort/display metric** for candidates (priority: `active_params` > `params` > `memory_mb`)
- `memory_mb`: prune to fit a memory budget (BF16 weights + KV-cache + Mamba state at a given sequence length and batch size)
- Constraints can be combined (AND logic), e.g. `{"params": 6e9, "memory_mb": 12288}`

**New standalone utilities** (`modelopt.torch.nas.plugins.megatron_model_stats`)

- `mcore_param_count`: analytically computes total and active parameter counts for GPT and Mamba/hybrid MCore models
- `mcore_memory_footprint_mb`: estimates memory in MB (weights + KV-cache + Mamba state)
- `print_mcore_model_stats`: rich-formatted model stats panel

**Rich-formatted pruning logs** — search space, top-k candidate tables, and best subnet panel printed on rank 0.

**`prune_score_func` format update** — now `mmlu_<N>pct_bs<bs>` (e.g. `mmlu_10pct_bs32`) to explicitly control the batch size for MMLU evaluation; the old `mmlu_<N>pct` format is removed.

**Infrastructure**

- NeMo container bumped to `nvcr.io/nvidia/nemo:26.04` in CI and docs
- Added `examples/megatron_bridge/requirements.txt` with `transformers<5.0` (required for saving some Nemotron-3-Nano models)

### Usage

```python
# Prune to 3B active params (MoE-aware) — active_params is the primary sort metric
mtp.prune(model, mode=[("mcore_minitron", ss_config)], constraints={"active_params": 3e9}, config=pruning_config)

# Prune to fit a 12 GB memory budget
mtp.prune(model, mode=[("mcore_minitron", ss_config)], constraints={"memory_mb": 12288}, config=pruning_config)
```

### Testing

Pruned Nemotron-3-Nano-30B-A3B (31.6B total, 3.6B active) --> 3.0B active. Takes <1 hr on 8x H100 (more details in #1376).

```bash
torchrun --nproc_per_node 8 examples/megatron_bridge/prune_minitron.py \
    --pp_size 8 \
    --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --trust_remote_code \
    --prune_target_params 28e9 \
    --prune_target_active_params 3e9 \
    --hparams_to_skip num_attention_heads \
    --seq_length 8192 \
    --output_hf_path pruned/Nemotron-3-Nano-30B-A3B-Pruned-28B-A3B-top20-max15depth-max30width-mmlu_10pct_bs32 \
    --top_k 20 \
    --max_depth_pruning 0.15 \
    --max_width_pruning 0.30 \
    --prune_score_func mmlu_10pct_bs32 \
    --num_layers_in_first_pipeline_stage 5 \
    --num_layers_in_last_pipeline_stage 5
```

```
Original Model Stats
  Total Parameters                               31.58B
  Active Parameters                              3.58B
  Memory (BF16, seq_length=8192, batch_size=1)   weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB

Top 20 Candidates with Scores (export_config columns shown; all candidates have identical keys)
 #   num_layers  hidden_size  mamba_num_heads  mamba_head_dim  num_moe_experts  moe_ffn_hidden_size  moe_shared_expert_intermediate_size  active_params  params  score
 1   46          2560         56               64              120              1792                 3072                                 3.00B          27.06B  0.3399
 2   48          2560         56               56              112              1792                 3072                                 3.00B          25.37B  0.4650
 3   46          2560         64               56              112              1792                 3072                                 3.00B          25.37B  0.2343
 4   52          2688         56               48              96               1536                 3072                                 3.00B          20.09B  0.2552
 5   52          2688         48               56              104              1536                 3072                                 3.00B          21.61B  0.2601
 6   52          2560         48               64              96               1536                 3712                                 3.00B          19.28B  0.3762
 7   52          2304         64               64              104              1856                 3072                                 3.00B          22.28B  0.4783
 8   52          2560         48               48              96               1792                 3328                                 3.00B          21.99B  0.2420
 9   50          2560         48               48              112              1792                 3712                                 3.00B          25.37B  0.2399
 10  50          2560         48               48              112              1856                 3328                                 3.00B          26.17B  0.2601
 11  46          2560         56               64              112              1792                 3072                                 3.00B          25.37B  0.2503
 12  48          2560         56               56              104              1792                 3072                                 3.00B          23.68B  0.4329
 13  46          2688         64               64              128              1536                 2816                                 3.00B          26.17B  0.2587
 14  46          2560         64               56              104              1792                 3072                                 3.00B          23.68B  0.2336
 15  52          2688         48               56              96               1536                 3072                                 3.00B          20.09B  0.2559
 16  52          2304         64               64              96               1856                 3072                                 3.00B          20.70B  0.4608
 17  50          2560         48               48              104              1792                 3712                                 3.00B          23.68B  0.2455
 18  50          2560         48               48              104              1856                 3328                                 3.00B          24.42B  0.2503
 19  48          2560         48               48              120              1856                 3712                                 3.00B          27.92B  0.2587
 20  46          2560         56               64              104              1792                 3072                                 3.00B          23.68B  0.2469

Best Subnet
  export_config  {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104,
                  'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072}
  active_params  3.00B
  params         22.28B
  score          0.4783

Pruned Model Stats
  Total Parameters                               22.28B
  Active Parameters                              3.00B
  Memory (BF16, seq_length=8192, batch_size=1)   weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB
```

### Before your PR is "*Ready for review*"

- Is this change backward compatible?: ✅
- Did you write any new necessary tests?: ✅
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
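A quick note on the standalone utilities listed under Changes: the sketch below shows how they might be called outside of a pruning run. The import path comes from this PR, but the keyword arguments and return shapes (`seq_length`/`batch_size`, what `mcore_param_count` returns) are my assumptions, not confirmed by the diff.

```python
# Hedged usage sketch of the new standalone stats utilities (signatures assumed).
from modelopt.torch.nas.plugins.megatron_model_stats import (
    mcore_memory_footprint_mb,
    mcore_param_count,
    print_mcore_model_stats,
)

# `model` is assumed to be a Megatron-Core GPT or Mamba/hybrid model instance.
param_counts = mcore_param_count(model)  # analytical total/active parameter counts, no forward pass
mem_mb = mcore_memory_footprint_mb(model, seq_length=8192, batch_size=1)  # BF16 weights + KV-cache + Mamba state
print_mcore_model_stats(model)  # prints the rich-formatted stats panel shown under Testing
```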
1 parent 84fe91b commit f0eaa19

13 files changed

Lines changed: 1697 additions & 282 deletions


.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ jobs:
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
-      docker_image: "nvcr.io/nvidia/nemo:26.02"
+      docker_image: "nvcr.io/nvidia/nemo:26.04"
       example: megatron_bridge
       timeout_minutes: 30
       pip_install_extras: "[hf,puzzletron,dev-test]"
```

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@ Changelog
 **New Features**
 
 - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.
+- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model.
 - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
 - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
```

examples/megatron_bridge/README.md

Lines changed: 43 additions & 5 deletions
````diff
@@ -16,7 +16,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br
 
 ## Pre-Requisites
 
-Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed.
+Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.04`) which has all the dependencies installed.
 
 To get the ModelOpt examples scripts, mount your Model-Optimizer repo to the container as follows:
 
@@ -26,7 +26,7 @@ if [ ! -d "${MODELOPT_DIR}" ]; then
   git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR}
 fi
 
-export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02
+export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.04
 docker run \
   --gpus all \
   --shm-size=16GB \
@@ -49,11 +49,28 @@ hf auth login --token <your token>
 > [!WARNING]
 > Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. You may also refer to this [doc](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docker/common/README.md#installing-packages-inside-the-container) on how to correctly install packages in the NeMo containers without breaking existing torch installation.
 
+Also install additional dependencies from the [requirements.txt](./requirements.txt) file.
+
+```bash
+python -m pip install -r requirements.txt
+```
+
 ## Pruning
 
 This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md).
 
-Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults:
+The script supports three NAS-based pruning targets and one manual export mode:
+
+| Mode | Flag | Description |
+| :---: | :---: | :--- |
+| NAS | `--prune_target_params` | Prune to a target total parameter count |
+| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. |
+| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision |
+| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. |
+
+Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision.
+
+**Prune by total parameter count** — prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults:
 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration,
 at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...),
 top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
@@ -67,8 +84,29 @@ torchrun --nproc_per_node 2 prune_minitron.py \
   --output_hf_path /tmp/Qwen3-8B-Pruned-6B
 ```
 
-Example usage for manually pruning to a specific architecture using following defaults:
-1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration.
+**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params):
+
+```bash
+torchrun --nproc_per_node 2 prune_minitron.py \
+  --pp_size 2 \
+  --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
+  --prune_target_active_params 3e9 \
+  --output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active
+```
+
+**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16):
+
+```bash
+torchrun --nproc_per_node 2 prune_minitron.py \
+  --pp_size 2 \
+  --hf_model_name_or_path Qwen/Qwen3-8B \
+  --prune_target_memory_mb 12288 \
+  --seq_length 4096 \
+  --calib_mbs 1 \
+  --output_hf_path /tmp/Qwen3-8B-Pruned-12GB
+```
+
+**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation):
 
 ```bash
 torchrun --nproc_per_node 2 prune_minitron.py \
````
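For intuition on what `--prune_target_memory_mb` budgets: BF16 weights dominate at 2 bytes per parameter. The following is my own back-of-envelope arithmetic under assumed Qwen3-8B-like shapes, not the pruner's exact accounting (which also tracks Mamba state for hybrid models):

```python
# Back-of-envelope BF16 memory estimate (weights + KV-cache), mirroring the
# README description above. The library's exact formula may differ.
def approx_memory_mb(
    n_params: float,    # total parameter count (weights dominate)
    num_layers: int,    # attention layers contributing KV-cache
    num_kv_heads: int,  # KV heads (grouped-query attention)
    head_dim: int,
    seq_length: int,
    batch_size: int,
) -> float:
    bytes_per_el = 2  # BF16
    weight_bytes = n_params * bytes_per_el
    # K and V caches: one [batch, seq, num_kv_heads, head_dim] tensor each, per layer
    kv_bytes = 2 * num_layers * batch_size * seq_length * num_kv_heads * head_dim * bytes_per_el
    return (weight_bytes + kv_bytes) / 2**20

# Assumed Qwen3-8B-like shape (36 layers, 8 KV heads, head_dim 128) pruned to ~6B
# params at the README defaults: ~11444 MB weights + ~576 MB KV-cache ≈ 12020 MB,
# which just fits the 12288 MB budget from the example above.
print(f"{approx_memory_mb(6e9, 36, 8, 128, 4096, 1):.1f} MB")
```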

examples/megatron_bridge/prune_minitron.py

Lines changed: 107 additions & 47 deletions
```diff
@@ -14,6 +14,11 @@
 # limitations under the License.
 """Example script for pruning a GPT / Mamba model using Minitron algorithm on a Megatron-Bridge model (load from HF).
 
+Supports three NAS-based pruning targets (can be combined):
+    --prune_target_params          Total parameter count (e.g. 6e9 for 6B total params)
+    --prune_target_active_params   Active parameter count for MoE models (e.g. 3e9 for 3B active params)
+    --prune_target_memory_mb       Memory footprint in MB (uses --seq_length for KV-cache estimate, assumes BF16)
+
 Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2)
 while skipping pruning of num_attention_heads using following defaults:
 1024 samples from nemotron-post-training-dataset-v2 for calibration,
@@ -47,7 +52,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.prune as mtp
 import modelopt.torch.utils.distributed as dist
-from modelopt.torch.utils import get_supported_datasets, num2hrb, print_rank_0, warn_rank_0
+from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
 from modelopt.torch.utils.plugins.mbridge import (
     get_hf_mbridge_calibration_loop,
     load_mbridge_model_from_hf,
@@ -105,7 +110,6 @@ def get_args() -> argparse.Namespace:
     )
     parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
     parser.add_argument("--seq_length", type=int, default=4096)
-
     # Pruning parameters
     parser.add_argument(
         "--prune_intermediate_ckpt",
@@ -117,42 +121,70 @@
         ),
     )
 
-    target_group = parser.add_mutually_exclusive_group(required=True)
-    target_group.add_argument(
+    parser.add_argument(
         "--prune_export_config",
         type=str,
         help=(
             'Target pruned config as JSON e.g., \'{"hidden_size": 512, "ffn_hidden_size": 2048}\'. '
             f"Supported hyperparameters: {mtp.mcore_minitron.SUPPORTED_HPARAMS}. "
-            "Cannot be used with --prune_target_params."
+            "Cannot be combined with NAS-based targets."
         ),
     )
-    target_group.add_argument(
+    parser.add_argument(
         "--prune_target_params",
         type=float,
         help=(
-            "Target parameter count for pruning e.g., 6e9 for pruning to 6B params (total params, not active params). "
-            "Uses Neural Architecture Search (NAS) to find the best pruned model that maximizes the --prune_score_func."
-            "Cannot be used with --prune_export_config."
+            "Target total parameter count e.g., 6e9 for 6B params. "
+            "Uses NAS to find the best pruned model that maximizes --prune_score_func. "
+            "Can be combined with --prune_target_active_params and/or --prune_target_memory_mb."
+        ),
+    )
+    parser.add_argument(
+        "--prune_target_active_params",
+        type=float,
+        help=(
+            "Target active parameter count e.g., 3e9 for 3B active params (useful for MoE models). "
+            "Uses NAS to find the best pruned model that maximizes --prune_score_func. "
+            "Can be combined with --prune_target_params and/or --prune_target_memory_mb."
+        ),
+    )
+    parser.add_argument(
+        "--prune_target_memory_mb",
+        type=float,
+        help=(
+            "Target memory footprint in MB (weights + KV-cache estimated via seq_length and "
+            "--inference_batch_size; assumes BF16). "
+            "Uses NAS to find the best pruned model that maximizes --prune_score_func. "
+            "Can be combined with --prune_target_params and/or --prune_target_active_params."
+        ),
+    )
+    parser.add_argument(
+        "--inference_batch_size",
+        type=int,
+        default=None,
+        help=(
+            "Batch size used only for KV-cache sizing in --prune_target_memory_mb. "
+            "Defaults to --calib_mbs when not set. "
+            "Use this to target an inference batch size that differs from the calibration micro-batch size."
         ),
     )
 
     parser.add_argument(
         "--prune_score_func",
         type=str,
-        default="mmlu_10pct",
+        default="mmlu_10pct_bs1",
         help=(
-            "Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
-            "Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
-            "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
+            "Score function to use for NAS-based pruning. Only supports MMLU at the moment. "
+            "Format: mmlu_<N>pct_<bs> where <N> is the percentage of MMLU data to sample per subject and <bs> is "
+            "batch size for fast evaluation (default is mmlu_10pct_bs1)."
        ),
     )
     parser.add_argument(
         "--ss_channel_divisor",
         type=int,
         default=None,
         help=(
-            "hidden_size / ffn_hidden_size divisor for NAS-based pruning (--prune_target_params). "
+            "hidden_size / ffn_hidden_size divisor for NAS-based pruning. "
             "Leave as None to use default divisors."
         ),
     )
@@ -162,14 +194,14 @@ def get_args() -> argparse.Namespace:
         default=0.4,
         help=(
             f"Maximum width pruning percentage ({mtp.mcore_minitron.SUPPORTED_HPARAMS - {'num_layers'}}) "
-            "for NAS-based pruning (--prune_target_params)"
+            "for NAS-based pruning"
         ),
     )
     parser.add_argument(
         "--max_depth_pruning",
         type=float,
         default=0.2,
-        help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning (--prune_target_params)",
+        help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning",
     )
     parser.add_argument(
         "--hparams_to_skip",
@@ -178,7 +210,7 @@ def get_args() -> argparse.Namespace:
         default=[],
         choices=mtp.mcore_minitron.SUPPORTED_HPARAMS,
         help=(
-            "Space-separated list of hparams to skip for NAS-based pruning (--prune_target_params) "
+            "Space-separated list of hparams to skip for NAS-based pruning "
             "e.g. dont prune 'num_attention_heads'"
         ),
     )
@@ -187,13 +219,27 @@
         type=int,
         default=10,
         help=(
-            "Number of top candidates to consider for NAS-based pruning (--prune_target_params). "
+            "Number of top candidates to consider for NAS-based pruning. "
             "Higher values will take longer to prune but may find a better model."
         ),
     )
 
     args = parser.parse_args()
 
+    # Validate pruning target arguments
+    _nas_targets = [
+        args.prune_target_params,
+        args.prune_target_active_params,
+        args.prune_target_memory_mb,
+    ]
+    if args.prune_export_config and any(t is not None for t in _nas_targets):
+        parser.error("--prune_export_config cannot be combined with NAS-based targets.")
+    if not args.prune_export_config and not any(t is not None for t in _nas_targets):
+        parser.error(
+            "At least one of --prune_export_config, --prune_target_params,"
+            " --prune_target_active_params, or --prune_target_memory_mb is required."
+        )
+
     # Post-process arguments
     if args.prune_intermediate_ckpt is None:
         if args.output_megatron_path:
@@ -250,11 +296,6 @@ def main(args: argparse.Namespace):
         init_model_parallel=True,
         moe_grouped_gemm=False,
     )
-    print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}")
-    print_rank_0(
-        f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}"
-    )
-
     forward_loop = get_hf_mbridge_calibration_loop(
         model=model,
         provider=provider,
@@ -271,10 +312,20 @@
         "forward_loop": forward_loop,
         "checkpoint": args.prune_intermediate_ckpt,
     }
-    if args.prune_target_params is not None:
-        # Restrict search space to a smaller set of candidates
-        # Allow more choices for MoE FFN as they are generally smaller
-        # NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model.
+    if args.prune_export_config is not None:
+        # Less restrictive search space for manual pruning
+        ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
+            hidden_size_divisor=64,
+            ffn_hidden_size_divisor=64,
+            mamba_head_dim_divisor=8,
+            num_moe_experts_divisor=8,
+            num_layers_divisor=1,
+        )
+        pruning_constraints = {"export_config": args.prune_export_config}
+    else:
+        # NAS-based pruning: restrict search space to a smaller set of candidates.
+        # Allow more choices for MoE FFN as they are generally smaller.
+        # NOTE: Reduce divisors and increase config['top_k'] to potentially find a better model.
         hidden_size_divisor = args.ss_channel_divisor if args.ss_channel_divisor else 256
         ffn_hidden_size_divisor = (
             args.ss_channel_divisor
@@ -290,40 +341,53 @@
         )
         print_rank_0(f"Using search space config: {ss_config}")
 
-        pruning_constraints = {"params": args.prune_target_params}
+        pruning_constraints = {}
+        if args.prune_target_params is not None:
+            pruning_constraints["params"] = args.prune_target_params
+        if args.prune_target_active_params is not None:
+            pruning_constraints["active_params"] = args.prune_target_active_params
+        if args.prune_target_memory_mb is not None:
+            pruning_constraints["memory_mb"] = args.prune_target_memory_mb
+
         print_rank_0(
             f"Using NAS-based automatic pruning with score function: {args.prune_score_func}. "
             "You can change this to be any other metric you want to maximize (e.g. negative validation loss)."
         )
 
-        match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func)
-        if not match:
+        match = re.fullmatch(r"mmlu_(\d+)pct_bs(\d+)", args.prune_score_func)
+        legacy_match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func)
+        if match:
+            mmlu_frac = float(match.group(1)) / 100.0
+            batch_size = int(match.group(2))
+        elif legacy_match:
+            warn_rank_0(
+                f"Score function '{args.prune_score_func}' uses the deprecated format "
+                "'mmlu_<N>pct'. Use 'mmlu_<N>pct_bs<bs>' to specify the evaluation batch size. "
+                "Falling back to batch_size=1."
+            )
+            mmlu_frac = float(legacy_match.group(1)) / 100.0
+            batch_size = 1
+        else:
             raise ValueError(
-                f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_10pct)"
+                f"Invalid score function: {args.prune_score_func}. "
+                "Expected format: mmlu_<N>pct_bs<bs> (e.g. mmlu_10pct_bs1)"
             )
-        mmlu_frac = float(match.group(1)) / 100.0
 
         def score_func(m):
             return megatron_mmlu(
-                m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs
+                m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=batch_size
            )
 
         pruning_config["score_func"] = score_func
         pruning_config["max_width_pruning"] = args.max_width_pruning
         pruning_config["max_depth_pruning"] = args.max_depth_pruning
         pruning_config["hparams_to_skip"] = args.hparams_to_skip
         pruning_config["top_k"] = args.top_k
-    elif args.prune_export_config is not None:
-        # Less restrictive search space for manual pruning
-        ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
-            hidden_size_divisor=64,
-            ffn_hidden_size_divisor=64,
-            mamba_head_dim_divisor=8,
-            num_moe_experts_divisor=8,
-            num_layers_divisor=1,
+        # memory_mb constraint requires batch_size and seq_length
+        pruning_config["batch_size"] = (
+            args.inference_batch_size if args.inference_batch_size is not None else args.calib_mbs
         )
-
-        pruning_constraints = {"export_config": args.prune_export_config}
+        pruning_config["seq_length"] = args.seq_length
     print_rank_0(f"Pruning constraints: {pruning_constraints}")
 
     unwrapped_model, pruning_scores = mtp.prune(  # in-place pruning
@@ -343,10 +407,6 @@ def score_func(m):
         else "hybrid_layer_pattern"
     )
     setattr(provider, hybrid_key, getattr(unwrapped_model, hybrid_key))
-    print_rank_0(f"\nPruned model (showing PP rank0): {unwrapped_model}")
-    print_rank_0(
-        f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}"
-    )
 
     if args.output_megatron_path is not None:
         print_rank_0(
```
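The help text above notes the score function "can be any other metric you want to maximize (e.g. negative validation loss)". Below is a minimal sketch of that swap; `evaluate_validation_loss` is a hypothetical user-supplied helper, and the only interface assumed from the diff is that `pruning_config["score_func"]` accepts a callable on the candidate model:

```python
# Sketch: a custom score function in place of MMLU. Assumes only what the diff
# shows: pruning_config["score_func"] takes a callable that receives the
# candidate model and returns a score to MAXIMIZE.
def evaluate_validation_loss(m) -> float:
    """Hypothetical user-provided eval: mean loss of candidate subnet `m` on held-out data."""
    raise NotImplementedError

def score_func(m):
    return -evaluate_validation_loss(m)  # negate loss so lower loss => higher score

pruning_config["score_func"] = score_func
```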
