Skip to content

Commit 04d1ffb

Browse files
committed
[quantization] Fix QuantGELU to preserve nn.GELU approximate mode
QuantGELU previously called _gelu(x) without forwarding the original nn.GELU.approximate setting, so nn.GELU(approximate="tanh") was executed as exact GELU. Store the original approximate mode and use it in forward. Add NO_QUANT parity coverage for tanh GELU. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent fa941ea commit 04d1ffb

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

test/quantization/wrapq/wrappers/test_quant_elementwise.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@
5454
(torch.nn.Tanh(), torch.tanh, QuantTanh),
5555
(torch.nn.ReLU(), torch.relu, QuantReLU),
5656
(torch.nn.GELU(), torch.nn.functional.gelu, QuantGELU),
57+
(
58+
torch.nn.GELU(approximate="tanh"),
59+
partial(torch.nn.functional.gelu, approximate="tanh"),
60+
QuantGELU,
61+
),
5762
]
5863

5964
try:
@@ -77,6 +82,22 @@ def _calibrate(self, qw, x):
7782
_ = qw(x)
7883
qw.freeze_qparams()
7984

85+
# ------------------------------------------------------------------
86+
def test_gelu_approximate_tanh_no_quant_parity(self):
87+
x = torch.linspace(-6.0, 6.0, steps=257).reshape(-1, 1)
88+
fp32_mod = torch.nn.GELU(approximate="tanh")
89+
qw = PTQWrapper(fp32_mod)
90+
91+
self.assertIs(qw._mode, Mode.NO_QUANT)
92+
93+
with torch.no_grad():
94+
q_out = qw(x)
95+
fp_out = fp32_mod(x)
96+
wrong_out = torch.nn.functional.gelu(x, approximate="none")
97+
98+
torch.testing.assert_close(q_out, fp_out, rtol=0, atol=0)
99+
self.assertGreater((wrong_out - fp_out).abs().max().item(), 1e-6)
100+
80101
# ------------------------------------------------------------------
81102
def test_registry_and_factory(self):
82103
for fp32_mod, _, quant_cls in ACTIVATIONS:

tico/quantization/wrapq/wrappers/quant_elementwise.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,26 @@ def FUNC(x: torch.Tensor) -> torch.Tensor:
146146

147147
@register(nn.GELU)
148148
class QuantGELU(QuantElementwise):
149+
def __init__(
150+
self,
151+
fp_module: nn.Module,
152+
*,
153+
qcfg: Optional[PTQConfig] = None,
154+
fp_name: Optional[str] = None,
155+
):
156+
super().__init__(fp_module, qcfg=qcfg, fp_name=fp_name)
157+
self.approximate = getattr(fp_module, "approximate", "none")
158+
149159
@staticmethod
150160
def FUNC(x: torch.Tensor) -> torch.Tensor:
151161
return _gelu(x)
152162

163+
def forward(self, x: torch.Tensor) -> torch.Tensor:
164+
x_q = self._fq(x, self.act_in_obs)
165+
y = _gelu(x_q, approximate=self.approximate)
166+
y_q = self._fq(y, self.act_out_obs)
167+
return y_q
168+
153169

154170
@try_register("transformers.activations.GELUTanh")
155171
class QuantGELUTanh(QuantElementwise):

0 commit comments

Comments
 (0)