Skip to content

Commit e0183dc

Browse files
committed
Fix
1 parent 10ee5ee commit e0183dc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/brevitas/graph/equalize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1744,11 +1744,13 @@ def fuse_parametrizations(model: nn.Module) -> nn.Module:
17441744
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
17451745
# Check if the module has any quantization-related children
17461746
state_dict = None
1747+
is_proxy_compiled = False
17471748
for submodule in module.modules():
17481749
if isinstance(
17491750
submodule,
17501751
(WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)):
17511752
state_dict = submodule.state_dict()
1753+
is_proxy_compiled = submodule.is_proxy_compiled
17521754
break
17531755
# The rotated tensor is saved by setting leave_parametrized=True
17541756
parametrize.remove_parametrizations(
@@ -1757,7 +1759,7 @@ def fuse_parametrizations(model: nn.Module) -> nn.Module:
17571759
# when registering the parametrized parameter
17581760
if state_dict is not None:
17591761
submodule.load_state_dict(state_dict)
1760-
if submodule.is_proxy_compiled:
1762+
if is_proxy_compiled:
17611763
submodule.compile_quant()
17621764
return model
17631765

0 commit comments

Comments
 (0)