# Enable active-param and memory-based Minitron pruning constraint (#1377)
Commit f0eaa19
### What does this PR do?
Type of change: New feature, new tests, documentation.
OMNIML-4108: Extends the Minitron NAS pruner to support pruning by
**active parameter count** (`active_params`) and **memory footprint**
(`memory_mb`) in addition to the existing total parameter count
(`params`) constraint. Also adds standalone utilities for analytical
model stats.
#### Changes
**New pruning constraint keys**
- `active_params`: prune to a target number of active (routed) params —
useful for MoE models where total ≫ active; when present,
`active_params` is the **primary sort/display metric** for candidates
(priority: `active_params` > `params` > `memory_mb`)
- `memory_mb`: prune to fit a memory budget (BF16 weights + KV-cache +
Mamba state at a given sequence length and batch size)
- Constraints can be combined with AND logic, e.g. `{"params": 6e9,
"memory_mb": 12288}`; all given constraints must be satisfied (see the
sketch below)
**New standalone utilities**
(`modelopt.torch.nas.plugins.megatron_model_stats`)
- `mcore_param_count`: analytically computes total and active parameter
counts for GPT and Mamba/hybrid MCore models
- `mcore_memory_footprint_mb`: estimates memory in MB (weights +
KV-cache + Mamba state)
- `print_mcore_model_stats`: rich-formatted model stats panel
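A rough usage sketch of these utilities. The exact signatures and return
shapes are assumptions inferred from the descriptions above, not verified
against the implementation:
```python
from modelopt.torch.nas.plugins.megatron_model_stats import (
    mcore_memory_footprint_mb,
    mcore_param_count,
    print_mcore_model_stats,
)

# Assumed signatures: analytical counts, no forward pass required.
total_params, active_params = mcore_param_count(model)
memory_mb = mcore_memory_footprint_mb(model, seq_length=8192, batch_size=1)
print_mcore_model_stats(model)  # rich-formatted stats panel (rank 0 only)
```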
**Rich-formatted pruning logs** — search space, top-k candidate tables,
and best subnet panel printed on rank 0
**`prune_score_func` format update** — now `mmlu_<N>pct_bs<bs>` (e.g.
`mmlu_10pct_bs32`) to explicitly control the batch size for MMLU evaluation;
the old `mmlu_<N>pct` format is removed
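For reference, the new name encodes both the MMLU subsample percentage and
the evaluation batch size. A minimal parsing sketch; the `parse_score_func`
helper below is hypothetical and not part of this PR:
```python
import re

def parse_score_func(name: str) -> tuple[int, int]:
    """Parse 'mmlu_<N>pct_bs<bs>' into (subsample percent, batch size)."""
    match = re.fullmatch(r"mmlu_(\d+)pct_bs(\d+)", name)
    if match is None:
        raise ValueError(f"Unsupported prune_score_func: {name!r}")
    return int(match.group(1)), int(match.group(2))

assert parse_score_func("mmlu_10pct_bs32") == (10, 32)
```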
**Infrastructure**
- NeMo container bumped to `nvcr.io/nvidia/nemo:26.04` in CI and docs
- Added `examples/megatron_bridge/requirements.txt` with
`transformers<5.0` (required for saving some Nemotron-3-Nano models)
### Usage
```python
import modelopt.torch.prune as mtp

# Prune to 3B active params (MoE-aware) — active_params is the primary sort metric
mtp.prune(model, mode=[("mcore_minitron", ss_config)], constraints={"active_params": 3e9}, config=pruning_config)

# Prune to fit a 12 GiB (12288 MB) memory budget
mtp.prune(model, mode=[("mcore_minitron", ss_config)], constraints={"memory_mb": 12288}, config=pruning_config)
```
### Testing
Pruned Nemotron-3-Nano-30B-A3B (31.6B total, 3.6B active) down to 3.0B
active params. Takes <1 hr on 8x H100 (more details in #1376).
```bash
torchrun --nproc_per_node 8 examples/megatron_bridge/prune_minitron.py \
--pp_size 8 \
--hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
--trust_remote_code \
--prune_target_params 28e9 \
--prune_target_active_params 3e9 \
--hparams_to_skip num_attention_heads \
--seq_length 8192 \
--output_hf_path pruned/Nemotron-3-Nano-30B-A3B-Pruned-28B-A3B-top20-max15depth-max30width-mmlu_10pct_bs32 \
--top_k 20 \
--max_depth_pruning 0.15 \
--max_width_pruning 0.30 \
--prune_score_func mmlu_10pct_bs32 \
--num_layers_in_first_pipeline_stage 5 \
--num_layers_in_last_pipeline_stage 5
```
```
╭──────────────────────────────────────────────────── Original Model Stats ─────────────────────────────────────────────────────╮
│ Total Parameters 31.58B │
│ Active Parameters 3.58B │
│ Memory (BF16, seq_length=8192, batch_size=1) weights: 60230.1 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 60301.9 MB │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Top 20 Candidates with Scores
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┓
┃ # ┃ export_config ┃ active_params ┃ params ┃ score ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━┩
│ 1 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 120, │ 3.00B │ 27.06B │ 0.3399 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 2 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.4650 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 3 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2343 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 4 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2552 │
│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 5 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 21.61B │ 0.2601 │
│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 6 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 19.28B │ 0.3762 │
│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │
│ 7 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 22.28B │ 0.4783 │
│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 8 │ {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 96, │ 3.00B │ 21.99B │ 0.2420 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │
│ 9 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2399 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │
│ 10 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 112, │ 3.00B │ 26.17B │ 0.2601 │
│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │
│ 11 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 112, │ 3.00B │ 25.37B │ 0.2503 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 12 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.4329 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 13 │ {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 128, │ 3.00B │ 26.17B │ 0.2587 │
│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 2816} │ │ │ │
│ 14 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 64, 'mamba_head_dim': 56, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2336 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 15 │ {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'mamba_head_dim': 56, 'num_moe_experts': 96, │ 3.00B │ 20.09B │ 0.2559 │
│ │ 'moe_ffn_hidden_size': 1536, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 16 │ {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 96, │ 3.00B │ 20.70B │ 0.4608 │
│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
│ 17 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2455 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │
│ 18 │ {'num_layers': 50, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 104, │ 3.00B │ 24.42B │ 0.2503 │
│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} │ │ │ │
│ 19 │ {'num_layers': 48, 'hidden_size': 2560, 'mamba_num_heads': 48, 'mamba_head_dim': 48, 'num_moe_experts': 120, │ 3.00B │ 27.92B │ 0.2587 │
│ │ 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} │ │ │ │
│ 20 │ {'num_layers': 46, 'hidden_size': 2560, 'mamba_num_heads': 56, 'mamba_head_dim': 64, 'num_moe_experts': 104, │ 3.00B │ 23.68B │ 0.2469 │
│ │ 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3072} │ │ │ │
└────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴───────────────┴────────┴────────┘
╭──────────────────────────────────────────────────────────────────────── Best Subnet ─────────────────────────────────────────────────────────────────────────╮
│ export_config {'num_layers': 52, 'hidden_size': 2304, 'mamba_num_heads': 64, 'mamba_head_dim': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1856, │
│ 'moe_shared_expert_intermediate_size': 3072} │
│ active_params 3.00B │
│ params 22.28B │
│ score 0.4783 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭───────────────────────────────────────────────────── Pruned Model Stats ──────────────────────────────────────────────────────╮
│ Total Parameters 22.28B │
│ Active Parameters 3.00B │
│ Memory (BF16, seq_length=8192, batch_size=1) weights: 42489.7 MB, kv_cache: 48.0 MB, mamba_state: 23.8 MB, Total: 42561.6 MB │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```
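The reported weight memory is consistent with a plain BF16 estimate of 2
bytes per parameter, with MB meaning MiB (2^20 bytes). A quick sanity check
(small differences come from rounding the displayed parameter counts):
```python
def bf16_weight_mb(num_params: float) -> float:
    """BF16 weights: 2 bytes per parameter, reported in MiB."""
    return num_params * 2 / 2**20

print(f"{bf16_weight_mb(31.58e9):.1f}")  # ~60234.1 vs. 60230.1 MB reported
print(f"{bf16_weight_mb(22.28e9):.1f}")  # ~42495.7 vs. 42489.7 MB reported
```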
### Before your PR is "*Ready for review*"
- Is this change backward compatible?: ✅
- Did you write any new necessary tests?: ✅
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅
---------
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>