src/brevitas/graph/equalize.py (4 additions, 0 deletions)

@@ -1744,11 +1744,13 @@ def fuse_parametrizations(model: nn.Module) -> nn.Module:
         if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
             # Check if the module has any quantization-related children
             state_dict = None
+            is_proxy_compiled = False
             for submodule in module.modules():
                 if isinstance(
                         submodule,
                         (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)):
                     state_dict = submodule.state_dict()
+                    is_proxy_compiled = submodule.is_proxy_compiled
                     break
             # The rotated tensor is saved by setting leave_parametrized=True
             parametrize.remove_parametrizations(
@@ -1757,6 +1759,8 @@ def fuse_parametrizations(model: nn.Module) -> nn.Module:
             # when registering the parametrized parameter
             if state_dict is not None:
                 submodule.load_state_dict(state_dict)
+                if is_proxy_compiled:
+                    submodule.compile_quant()
     return model


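What the equalize.py hunks do, as a hedged, self-contained sketch (ToyProxy and its attributes are hypothetical stand-ins, not Brevitas API): fusing a parametrization rebuilds the proxy's tensor_quant, which silently discards the torch.compile wrapper, so the change records whether the proxy was compiled beforehand and re-invokes compile_quant afterwards.

import torch
import torch.nn as nn

class ToyProxy(nn.Module):
    # Hypothetical stand-in for a Brevitas quant proxy, for illustration only.
    def __init__(self):
        super().__init__()
        self.tensor_quant = nn.Identity()

    def compile_quant(self):
        # Mirrors the real proxies: wrap tensor_quant with torch.compile.
        self.tensor_quant = torch.compile(self.tensor_quant, dynamic=True)

    @property
    def is_proxy_compiled(self):
        # Same type-name check the PR adds.
        return 'OptimizedModule' in str(type(self.tensor_quant))

proxy = ToyProxy()
proxy.compile_quant()
was_compiled = proxy.is_proxy_compiled  # remembered before fusing, like is_proxy_compiled above
proxy.tensor_quant = nn.Identity()      # fusing rebuilds tensor_quant, dropping the compiled wrapper
if was_compiled:
    proxy.compile_quant()               # the recompilation step added in equalize.py
assert proxy.is_proxy_compiled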
src/brevitas/proxy/parameter_quant.py (8 additions, 0 deletions)

@@ -109,6 +109,10 @@ def compile_quant(self, compile_export=False):
         fullgraph = not self.is_groupwise
         self.tensor_quant = torch.compile(self.tensor_quant, dynamic=True, fullgraph=fullgraph)

+    @property
+    def is_proxy_compiled(self):
+        return 'OptimizedModule' in str(type(self.tensor_quant))
+
     @property
     def input_view_impl(self):
         if self.tensor_quant is not None:
@@ -187,6 +191,10 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]

+    @property
+    def is_proxy_compiled(self):
+        return False
+
     def get_cached(self, attr):
         if self._cached_bias is None:
             if not is_dynamo_compiling():
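Why the string check works, as a small hedged sketch: torch.compile called on an nn.Module returns an OptimizedModule wrapper, so its type name contains 'OptimizedModule' (nn.Linear below is a stand-in for tensor_quant).

import torch
import torch.nn as nn

inner = nn.Linear(4, 4)
compiled = torch.compile(inner, dynamic=True)
print(type(compiled).__name__)                   # OptimizedModule
print('OptimizedModule' in str(type(compiled)))  # True, as in is_proxy_compiled
print('OptimizedModule' in str(type(inner)))     # False for the uncompiled module

An isinstance check against torch._dynamo.eval_frame.OptimizedModule would be stricter, but the string comparison avoids importing from a private module; both hinge on the same wrapper type.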
src/brevitas/proxy/runtime_quant.py (4 additions, 0 deletions)

@@ -112,6 +112,10 @@ def compile_quant(self, compile_export=False):
         self.fused_activation_quant_proxy.tensor_quant = torch.compile(
             self.fused_activation_quant_proxy.tensor_quant, dynamic=True, fullgraph=fullgraph)

+    @property
+    def is_proxy_compiled(self):
+        return 'OptimizedModule' in str(type(self.fused_activation_quant_proxy.tensor_quant))
+
     @property
     def input_view_impl(self):
         if self.fused_activation_quant_proxy.tensor_quant is not None and not isinstance(