Commit 87237e7
authored
Update on the QuantModule & DynamicModule to accept external forward (#824)
## What does this PR do?
**Type of change:** new feature <!-- Use one of the following: Bug fix,
new feature, new example, new tests, documentation. -->
**Overview:**
This MR improves robustness when `forward()` is monkey‑patched (replaced
at runtime) on modules that later get wrapped/converted by ModelOpt
(DynamicModule + quant wrappers).
It addresses two concrete failure modes introduced/exposed by supporting
“patched forward” modules:
1. Forward “leakage” after export: a dynamic wrapper forward could
remain bound on an instance even after export() restores the original
(non‑dynamic) class, causing runtime errors in unrelated codepaths (e.g.
KD export/save/restore chains).
1. Infinite recursion in quant wrappers: _forward_pre_dm can sometimes
point to a wrapper forward that already participates in the class chain,
causing a recursion loop when quant wrappers call _forward_pre_dm
directly.
## Usage
<!-- You can potentially add a usage example below. -->
```python
lin = torch.nn.Linear(4, 4)
def upcast_forward(x):
# external closure: NOT part of any class MRO
return torch.nn.functional.linear(x, lin.weight.to(x.dtype), lin.bias.to(x.dtype))
lin.forward = upcast_forward # framework/user patches forward
# Later, ModelOpt converts/wraps the module.
# It stashes the patched function as `_forward_pre_dm` and binds the wrapper forward on the class.
# During quantization, QuantInputBase.forward sees `_forward_pre_dm` is NOT in MRO -> calls it.
```
```
# Imagine a module already wrapped by quant classes:
# QuantLinearConvBase.forward -> super().forward -> QuantInputBase.forward -> ...
# If `_forward_pre_dm` accidentally points to QuantLinearConvBase.forward (which IS in MRO),
# and QuantInputBase.forward calls it directly, you get:
# QuantInputBase.forward -> _forward_pre_dm (QuantLinearConvBase.forward)
# -> super().forward -> QuantInputBase.forward -> ...
# infinite recursion
# The fix: if `_forward_pre_dm` is a forward already in MRO, ignore it and use super().forward.
```
## Testing
<!-- Mention how have you tested your change if applicable. -->
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: No <!--- If No, explain why.
-->
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:No
<!--- Only for new features, API changes, critical bug fixes or bw
breaking changes. -->
## Additional Information
<!-- E.g. related issue. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **Bug Fixes**
* Improved forward method restoration during module export to prevent
state leakage
* Enhanced quantization behavior when using chained optimization modes
* **Tests**
* Added regression tests for quantization with runtime forward patching
* Added validation tests for sparse quantization combined with
distillation workflows
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>1 parent 9de9b8f commit 87237e7
File tree
4 files changed
+235
-1
lines changed- modelopt/torch
- opt
- quantization/nn/modules
- tests/unit/torch
- opt
- quantization
4 files changed
+235
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
584 | 584 | | |
585 | 585 | | |
586 | 586 | | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
587 | 596 | | |
588 | 597 | | |
589 | 598 | | |
| |||
621 | 630 | | |
622 | 631 | | |
623 | 632 | | |
| 633 | + | |
| 634 | + | |
| 635 | + | |
| 636 | + | |
624 | 637 | | |
625 | 638 | | |
626 | 639 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
110 | 110 | | |
111 | 111 | | |
112 | 112 | | |
113 | | - | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
114 | 133 | | |
115 | 134 | | |
116 | 135 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
15 | 15 | | |
16 | 16 | | |
17 | 17 | | |
| 18 | + | |
18 | 19 | | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
22 | 23 | | |
23 | 24 | | |
24 | 25 | | |
| 26 | + | |
25 | 27 | | |
26 | 28 | | |
27 | 29 | | |
28 | 30 | | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
29 | 40 | | |
30 | 41 | | |
31 | 42 | | |
| |||
228 | 239 | | |
229 | 240 | | |
230 | 241 | | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
0 commit comments