
Commit 87237e7

Update on the QuantModule & DynamicModule to accept external forward (#824)
## What does this PR do?

**Type of change:** new feature

**Overview:** This PR improves robustness when `forward()` is monkey-patched (replaced at runtime) on modules that later get wrapped/converted by ModelOpt (DynamicModule + quant wrappers). It addresses two concrete failure modes introduced/exposed by supporting "patched forward" modules:

1. Forward "leakage" after export: a dynamic wrapper forward could remain bound on an instance even after `export()` restores the original (non-dynamic) class, causing runtime errors in unrelated codepaths (e.g. KD export/save/restore chains).
2. Infinite recursion in quant wrappers: `_forward_pre_dm` can sometimes point to a wrapper forward that already participates in the class chain, causing a recursion loop when quant wrappers call `_forward_pre_dm` directly.

## Usage

```python
lin = torch.nn.Linear(4, 4)

def upcast_forward(x):
    # external closure: NOT part of any class MRO
    return torch.nn.functional.linear(x, lin.weight.to(x.dtype), lin.bias.to(x.dtype))

lin.forward = upcast_forward  # framework/user patches forward

# Later, ModelOpt converts/wraps the module.
# It stashes the patched function as `_forward_pre_dm` and binds the wrapper forward on the class.
# During quantization, QuantInputBase.forward sees `_forward_pre_dm` is NOT in the MRO -> calls it.
```

```python
# Imagine a module already wrapped by quant classes:
#   QuantLinearConvBase.forward -> super().forward -> QuantInputBase.forward -> ...
# If `_forward_pre_dm` accidentally points to QuantLinearConvBase.forward (which IS in the MRO),
# and QuantInputBase.forward calls it directly, you get:
#   QuantInputBase.forward -> _forward_pre_dm (QuantLinearConvBase.forward)
#     -> super().forward -> QuantInputBase.forward -> ...   # infinite recursion
# The fix: if `_forward_pre_dm` is a forward already in the MRO, ignore it and use super().forward.
```

## Testing

New regression tests are included (see the test diffs below).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: No
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved forward method restoration during module export to prevent state leakage
  * Enhanced quantization behavior when using chained optimization modes
* **Tests**
  * Added regression tests for quantization with runtime forward patching
  * Added validation tests for sparse quantization combined with distillation workflows

---------

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
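The instance-level `forward` shadowing behind failure mode 1 can be sketched in plain Python. This is an illustrative stand-in, not ModelOpt code: `Original`, `DynamicWrapper`, and `patched_forward` are made-up names, and the convert/export steps are simulated by hand:

```python
class Original:
    def forward(self, x):
        return x + 1

class DynamicWrapper(Original):
    # stand-in for a dynamic/KD wrapper forward
    def forward(self, x):
        return Original.forward(self, x) * 10

m = Original()

def patched_forward(x):          # external closure: not in any class MRO
    return x - 1

m.forward = patched_forward      # user/framework patches forward at runtime

# simulate convert(): stash the patched forward, then bind the wrapper forward on the instance
m._forward_pre_dm = m.forward
m.__class__ = DynamicWrapper
m.forward = DynamicWrapper.forward.__get__(m)

# simulate export() restoring the original class...
m.__class__ = Original
# ...without the fix, the instance-bound wrapper forward still shadows Original.forward ("leakage")
assert m.forward(2) == 30        # DynamicWrapper.forward, not Original.forward

# the fix: restore the stashed forward and drop the marker attribute
m.forward = m._forward_pre_dm
del m._forward_pre_dm
assert m.forward(2) == 1         # the original patched forward is back
```

The key point is that an attribute set on the instance always wins over the class-level `forward`, so restoring the class alone is not enough; the instance attribute must be restored too.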
1 parent 9de9b8f commit 87237e7

File tree

4 files changed: +235 -1 lines changed


modelopt/torch/opt/dynamic.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -584,6 +584,15 @@ def export(self) -> nn.Module:
         assert not is_dynamic, "Exported module must not be a DynamicModule anymore!"
         delattr(self, "_dm_attribute_manager")
 
+        # If this module had a monkey-patched forward before DynamicModule.convert(), we may have
+        # overridden it by binding the dynamic forward onto the instance (to follow the MRO).
+        # On final export, restore the original forward to avoid leaking a dynamic forward
+        # (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance.
+        # please see: https://github.com/NVIDIA/Model-Optimizer/pull/824
+        if hasattr(self, "_forward_pre_dm"):
+            setattr(self, "forward", getattr(self, "_forward_pre_dm"))
+            delattr(self, "_forward_pre_dm")
+
         return self
 
     @classmethod
@@ -621,6 +630,10 @@ def bind_forward_method_if_needed(self):
             # accelerate patched module
             bind_forward_method(self, self.__class__.forward)
         else:
+            if not hasattr(self, "_forward_pre_dm"):
+                # Keep the patched forward for downstream modules that want to call it.
+                self._forward_pre_dm = self.forward
+            bind_forward_method(self, self.__class__.forward)
             warnings.warn(
                 "Received a module with monkey patched forward method. Dynamic converted module"
                 " might not work."
```
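What the `else` branch above does (stash the patch, then bind the dynamic forward) can be sketched without ModelOpt. Here `DynamicForward` is an illustrative stand-in for a dynamic class, and the `types.MethodType` call stands in for ModelOpt's `bind_forward_method` helper:

```python
import types

class DynamicForward:
    # stand-in for a DynamicModule-provided forward
    def forward(self, x):
        return x * 2

m = DynamicForward()

def patched(x):                  # pre-existing monkey patch (closure, no `self`)
    return x + 100

m.forward = patched

# keep the patched forward for downstream callers, then bind the class forward
# onto the instance so calls follow the class's MRO again
if not hasattr(m, "_forward_pre_dm"):
    m._forward_pre_dm = m.forward
m.forward = types.MethodType(type(m).forward, m)

assert m.forward(1) == 2             # dynamic forward is now active
assert m._forward_pre_dm(1) == 101   # the original patch is still reachable
```

The `hasattr` guard matters for repeated conversions: only the first stash captures the user's patch, so later re-binds cannot overwrite it with a wrapper forward.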

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -110,7 +110,26 @@ class QuantInputBase(QuantModule):
     def forward(self, input, *args, **kwargs):
         """Quantize the input before calling the original forward method."""
         input = self.input_quantizer(input)
-        output = super().forward(input, *args, **kwargs)
+        # Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824
+        if hasattr(self, "_forward_pre_dm"):
+            pre_fwd = getattr(self, "_forward_pre_dm")
+
+            def _is_forward_in_mro(bound_or_func) -> bool:
+                # If this is a bound method, compare its underlying function to any `forward`
+                # implementation in the current MRO. If it matches, it's not an external monkey-patch.
+                if hasattr(bound_or_func, "__func__"):
+                    fn = bound_or_func.__func__
+                    for cls in type(self).mro():
+                        if cls.__dict__.get("forward") is fn:
+                            return True
+                return False
+
+            if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd):
+                output = super().forward(input, *args, **kwargs)
+            else:
+                output = pre_fwd(input, *args, **kwargs)
+        else:
+            output = super().forward(input, *args, **kwargs)
         if isinstance(output, tuple):
             return (self.output_quantizer(output[0]), *output[1:])
         return self.output_quantizer(output)
```
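The MRO-membership test added above can be exercised standalone. The sketch below reimplements the same check in plain Python; `is_forward_in_mro`, `Base`, `Wrapper`, and `external` are illustrative names, not ModelOpt APIs:

```python
def is_forward_in_mro(obj, bound_or_func) -> bool:
    # A bound method exposes its underlying function as `__func__`;
    # a plain function/closure does not, so it can never be an MRO forward.
    fn = getattr(bound_or_func, "__func__", None)
    if fn is None:
        return False
    # Compare against every `forward` defined along the object's MRO.
    return any(cls.__dict__.get("forward") is fn for cls in type(obj).mro())

class Base:
    def forward(self, x):
        return x

class Wrapper(Base):
    def forward(self, x):
        return super().forward(x)

w = Wrapper()

# bound wrapper/base forwards are in the MRO -> must NOT be called directly
assert is_forward_in_mro(w, Wrapper.forward.__get__(w))
assert is_forward_in_mro(w, Base.forward.__get__(w))

def external(x):     # closure-style monkey patch
    return x

# an external closure is not in the MRO -> safe to call directly
assert not is_forward_in_mro(w, external)
```

Falling back to `super().forward` whenever the stashed callable is itself an MRO forward is what breaks the recursion cycle described in the PR overview.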

tests/unit/torch/opt/test_chaining.py

Lines changed: 109 additions & 0 deletions

```diff
@@ -15,17 +15,28 @@
 
 import pytest
 import torch
+import torch.nn.functional as F
 from _test_utils.torch.misc import compare_outputs
 from _test_utils.torch.opt.utils import apply_mode_with_sampling
 from torchvision.models.mobilenetv2 import InvertedResidual
 
 import modelopt.torch.distill as mtd
 import modelopt.torch.nas as mtn
 import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
 import modelopt.torch.sparsity as mts
 from modelopt.torch.utils.distributed import _serialize
 
 
+class SimpleLinearModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 def get_model():
     return InvertedResidual(16, 32, 1, 6)
 
@@ -228,3 +239,101 @@ def test_sparse_quantized_module():
     model = mtn.export(model)
     assert torch.equal(conv.weight, weight_expected)
     assert torch.equal(conv._parameters["weight"], weight_expected), "Weight should be overwritten!"
+
+
+def test_sparse_quantize_kd_linear_forward_backward():
+    """Ensure sparse + quantize + distill works for linear forward/backward."""
+    model = SimpleLinearModel()
+    teacher_model = SimpleLinearModel()
+
+    called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0}
+
+    def _make_patched_forward(linear):
+        def patched_forward(x):
+            called["patched_forward"] += 1
+            w = linear.weight
+            b = linear.bias if linear.bias is not None else None
+            return F.linear(x, w, b)
+
+        return patched_forward
+
+    model.linear.forward = _make_patched_forward(model.linear)
+    teacher_model.linear.forward = _make_patched_forward(teacher_model.linear)
+
+    def _get_linear_kd_mode():
+        config = {
+            "teacher_model": teacher_model,
+            "criterion": {("linear", "linear"): mtd.LogitsDistillationLoss()},
+            "loss_balancer": mtd.StaticLossBalancer(),
+        }
+        return [("kd_loss", config)]
+
+    model = mto.apply_mode(model, mode="sparse_magnitude", init_state=True)
+    model = mto.apply_mode(model, mode="quantize")
+    model = mto.apply_mode(model, mode=_get_linear_kd_mode())
+
+    def _count_quant_input(_m, _inp, _out):
+        called["input_q"] += 1
+
+    def _count_quant_weight(_m, _inp, _out):
+        called["weight_q"] += 1
+
+    model.linear.input_quantizer.register_forward_hook(_count_quant_input)
+    model.linear.weight_quantizer.register_forward_hook(_count_quant_weight)
+
+    model.train()
+    x = torch.randn(2, 4)
+    target = torch.randn(2, 4)
+    output = model(x)
+    loss = F.mse_loss(output, target)
+    loss.backward()
+
+    assert output.shape == target.shape
+    assert any(p.grad is not None for p in model.parameters() if p.requires_grad), (
+        "Expected gradients on student parameters."
+    )
+    assert called["patched_forward"] == 2
+    assert called["input_q"] == 1
+    assert called["weight_q"] == 1
+
+
+def test_chained_modes_preserve_forward_patching_during_quantize():
+    """Ensure chained modes do not break runtime forward patching during quantize."""
+    model = InvertedResidual(16, 32, 1, 6).to(torch.float16)
+    model = mto.apply_mode(model, mode="fastnas", init_state=True)
+    model = mto.apply_mode(model, mode="export_nas")
+
+    conv = model.conv[0][0]
+    called = {"patched_forward": 0, "input_q": 0, "weight_q": 0}
+
+    def patched_forward(x):
+        called["patched_forward"] += 1
+        return F.conv2d(
+            x,
+            conv.weight,
+            conv.bias,
+            conv.stride,
+            conv.padding,
+            conv.dilation,
+            conv.groups,
+        )
+
+    conv.forward = patched_forward
+
+    def _count_input(_m, _inp, _out):
+        called["input_q"] += 1
+
+    def _count_weight(_m, _inp, _out):
+        called["weight_q"] += 1
+
+    def forward_loop(model):
+        conv.input_quantizer.register_forward_hook(_count_input)
+        conv.weight_quantizer.register_forward_hook(_count_weight)
+        x = torch.randn(1, 16, 8, 8, dtype=torch.float16)
+        model(x)
+
+    mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
+
+    assert called["patched_forward"] == 1
+    assert called["input_q"] == 1
+    assert called["weight_q"] == 1
```
Lines changed: 93 additions & 0 deletions

```diff
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization import QuantModuleRegistry
+from modelopt.torch.quantization.nn.modules.quant_module import QuantLinearConvBase
+
+
+def test_quant_input_base_ignores_forward_pre_dm_in_mro():
+    """Regression test for recursion when `_forward_pre_dm` points to a wrapper forward in the MRO.
+
+    In complex wrapper stacks, `_forward_pre_dm` may accidentally end up referencing a `forward`
+    method already present in the quant wrapper MRO (e.g. QuantLinearConvBase.forward). If
+    QuantInputBase.forward calls that directly, it can recurse indefinitely:
+
+        QuantLinearConvBase.forward -> super().forward (QuantInputBase.forward)
+            -> _forward_pre_dm (QuantLinearConvBase.forward) -> ...
+
+    The fix is to detect this case and fall back to `super().forward` instead.
+    """
+    lin = nn.Linear(8, 8, bias=False)
+    QuantModuleRegistry.convert(lin)
+
+    # Force the problematic state: `_forward_pre_dm` points to a wrapper forward already in MRO.
+    lin._forward_pre_dm = types.MethodType(QuantLinearConvBase.forward, lin)
+
+    x = torch.randn(2, 8)
+    y = lin(x)
+    assert isinstance(y, torch.Tensor)
+    assert y.shape == (2, 8)
+
+
+def test_quantize_calibration_calls_quantizers_with_runtime_forward_patch():
+    """Regression test for on-the-fly forward patching during mtq.quantize calibration.
+
+    Some frameworks replace `module.forward` on-the-fly with a closure just before a forward pass.
+    During mtq.quantize calibration, quantizers must still run (input + weight at minimum).
+    """
+    lin = nn.Linear(8, 8, bias=True).to(torch.float32)
+
+    called = {"patched_forward": 0, "input_q": 0, "weight_q": 0}
+
+    # Monkey patch instance-level forward (closure-style, no `self` argument).
+    def patched_forward(x):
+        called["patched_forward"] += 1
+        # Use module parameters directly; if quantization wrappers are active, weight access
+        # should still be routed through the quantized path.
+        w = lin.weight.to(dtype=x.dtype)
+        b = lin.bias.to(dtype=x.dtype) if lin.bias is not None else None
+        return F.linear(x, w, b)
+
+    def _count_input(_m, _inp, _out):
+        called["input_q"] += 1
+
+    def _count_weight(_m, _inp, _out):
+        called["weight_q"] += 1
+
+    lin.forward = patched_forward
+    x = torch.randn(2, 8, dtype=torch.float16)
+
+    def forward_loop(model):
+        # Patch forward on-the-fly (after conversion, right before calibration forward).
+
+        # Count quantizer executions during calibration.
+        model.input_quantizer.register_forward_hook(_count_input)
+        model.weight_quantizer.register_forward_hook(_count_weight)
+
+        model(x)
+
+    mtq.quantize(lin, mtq.INT8_DEFAULT_CFG, forward_loop)
+    lin(x)
+
+    assert called["patched_forward"] == 2
+    assert called["input_q"] == 2
+    assert called["weight_q"] == 2
```
