This repository was archived by the owner on Sep 18, 2024. It is now read-only.

ModelSpeedup error: assert len(set(num_channels_list)) == 1, possible incorrect layers in dependency set #5736

Open
@saravanabalagi

Description

ModelSpeedup fails with an AssertionError while resolving channel mask conflicts, instead of producing a sped-up model, when the network contains three successive Conv2d + BatchNorm2d blocks.

Environment:

  • NNI version: 3.0
  • Python version: 3.8.16
  • PyTorch version: 1.13.0
  • CPU or CUDA version: CUDA 11.6

Reproduce the problem

  • Create a model and a config_list with the desired sparsity_ratio
  • Obtain pruning masks using L1NormPruner
  • Call ModelSpeedup with the dummy input and the masks, then speedup_model()
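
For reference, L1-norm filter pruning ranks each layer's output channels by the sum of the absolute values of their weights and masks the lowest-ranked ones. The sketch below is a minimal pure-Python illustration of that idea, not NNI's L1NormPruner implementation; the helper name l1_channel_mask and the floor rounding of the number of pruned channels are assumptions. Under that assumed rounding, a layer with a single output channel (like conv3) would keep its only channel at sparse_ratio 0.5.

```python
# Minimal sketch of L1-norm output-channel selection (NOT NNI's implementation).
# Assumption: the number of channels to prune is floor(sparse_ratio * num_channels).
import math
from typing import List


def l1_channel_mask(weight: List[List[float]], sparse_ratio: float) -> List[int]:
    """Return a 0/1 keep-mask over output channels (hypothetical helper).

    weight[i] holds the flattened weights of output channel i
    (for a Conv2d: in_channels * kernel_h * kernel_w values).
    """
    norms = [sum(abs(w) for w in channel) for channel in weight]
    num_pruned = math.floor(sparse_ratio * len(weight))  # assumed rounding rule
    # The channels with the smallest L1 norms are masked out (0), the rest kept (1).
    order = sorted(range(len(weight)), key=lambda i: norms[i])
    pruned = set(order[:num_pruned])
    return [0 if i in pruned else 1 for i in range(len(weight))]


# Four channels with L1 norms 2.0, 0.2, 2.5, 0.2: the two smallest are masked.
print(l1_channel_mask([[1.0, -1.0], [0.1, 0.1], [2.0, 0.5], [-0.2, 0.0]], 0.5))  # [1, 0, 1, 0]
# A single-channel layer keeps its channel under the assumed floor rounding.
print(l1_channel_mask([[0.3, -0.4]], 0.5))  # [1]
```
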
Minimal Code
# %%
import torch
import torch.nn as nn

from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

# %%
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(80, 1, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
    
model = ConvNet()
num_params_unpruned = sum(p.numel() for p in model.parameters())
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output = model(dummy_input)
print(dummy_output.shape)

# %%
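sparsity_ratio = 0.5
# Prune 50% of the output channels of every Conv2d (conv1, conv2 and conv3)
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
}]
# Group layers whose channels must be pruned together ("denpendency" is the NNI 3.0 spelling)
config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()  # only the masks are used below
pruner.unwrap_model()  # remove the pruner wrappers before speedup
# Raises the AssertionError in fix_channel_mask_conflict shown in the trace below
model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()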

# %%
num_params_pruned = sum(p.numel() for p in model.parameters())
print(f'Number of parameters before pruning: {num_params_unpruned}')
print(f'Number of parameters after pruning: {num_params_pruned}')

num_params_diff = num_params_unpruned - num_params_pruned
prune_ratio = num_params_diff / num_params_unpruned
print(f'Number of parameters pruned: {num_params_diff}')
print(f'Parameter ratio: {(1-prune_ratio)*100:.2f}%')
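
For comparison with what the script prints, the counts can be worked out by hand: a Conv2d with bias has c_in * c_out * k * k + c_out parameters, and an affine BatchNorm2d has 2 * c (its running statistics are buffers, not parameters). The helpers below are hypothetical and only do that arithmetic. The "pruned" case is an assumption about what a successful speedup would produce (conv1 and conv2 outputs halved, conv3's single output channel untouched), since speedup currently fails before reaching that point.

```python
# Hand arithmetic for the ConvNet parameter counts (hypothetical helpers).
# Assumption: after a successful speedup, conv1/conv2 outputs are halved and
# conv3 keeps its single output channel.


def conv2d_params(c_in: int, c_out: int, k: int = 3) -> int:
    return c_in * c_out * k * k + c_out  # weight + bias (nn.Conv2d default bias=True)


def bn2d_params(c: int) -> int:
    return 2 * c  # affine weight + bias; running_mean/running_var are buffers


def convnet_params(c1: int, c2: int, c3: int = 1) -> int:
    return (conv2d_params(3, c1) + bn2d_params(c1)
            + conv2d_params(c1, c2) + bn2d_params(c2)
            + conv2d_params(c2, c3) + bn2d_params(c3))


unpruned = convnet_params(40, 80)  # 30963, should match num_params_unpruned
pruned = convnet_params(20, 40)    # 8283 under the stated assumption
print(unpruned, pruned)  # 30963 8283
print(f'Remaining parameter ratio: {pruned / unpruned * 100:.2f}%')
```
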

Error:

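AssertionError in fix_channel_mask_conflict. The assertion carries no message; the comment next to it in the NNI source says that the number of channels in the same dependency set should be identical.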

Error Trace
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[108], line 1
----> 1 model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:429, in ModelSpeedup.speedup_model(self)
    427 self.logger.info('Resolve the mask conflict before mask propagate...')
    428 # fix_mask_conflict(self.masks, self.graph_module, self.dummy_input)
--> 429 self.fix_mask_conflict()
    430 self.logger.info('Infer module masks...')
    431 self.initialize_propagate(self.dummy_input)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:243, in ModelSpeedup.fix_mask_conflict(self)
    241 def fix_mask_conflict(self):
    242     fix_group_mask_conflict(self.graph_module, self.masks)
--> 243     fix_channel_mask_conflict(self.graph_module, self.masks)
    244     fix_weight_sharing_mask_conflict(self.graph_module, self.masks)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/mask_conflict.py:296, in fix_channel_mask_conflict(graph_module, masks)
    294 num_channels_list = [len(x) for x in channel_masks if x is not None]
    295 # number of channels in same set should be identical
--> 296 assert len(set(num_channels_list)) == 1
    297 num_channels = num_channels_list[0]
    299 for i, dim_mask in enumerate(channel_masks):

AssertionError: 
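
The failing assertion enforces one invariant: all layers in a channel-dependency set must report the same number of channels, because they share one channel mask. A minimal sketch of that check follows; check_dependency_set is a hypothetical helper that mirrors lines 294 to 296 of mask_conflict.py as they appear in the trace, not NNI's actual code. A set that mixed, say, conv2 (80 output channels) with conv3 (1 output channel) would fail it, which is consistent with the suspicion in the issue title, although the trace alone does not show which layers were grouped.

```python
# Minimal sketch of the invariant behind the failing assertion (NOT NNI's code).
from typing import List, Optional


def check_dependency_set(channel_masks: List[Optional[List[int]]]) -> int:
    """Hypothetical helper: return the shared channel count of one dependency set.

    Layers without a channel mask (None) are skipped, as in the traced code.
    """
    num_channels_list = [len(m) for m in channel_masks if m is not None]
    if len(set(num_channels_list)) != 1:
        raise AssertionError(
            f'channel counts differ within one dependency set: {num_channels_list}')
    return num_channels_list[0]


# Consistent set: every masked layer has 40 channels.
print(check_dependency_set([[1] * 40, None, [1] * 40]))  # 40
# Hypothetical inconsistent set: an 80-channel layer grouped with a 1-channel layer.
try:
    check_dependency_set([[1] * 80, [1]])
except AssertionError as err:
    print(err)  # channel counts differ within one dependency set: [80, 1]
```
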

The same code runs without error when self.conv3 and self.bn3 (and their calls in forward) are removed.
