This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Pruning: Add support for torch.split() operation #5143

Status: Open
Author: @AL3708

Description
If there is a torch.split() operation anywhere in the network, an AttributeError is thrown:

AttributeError                            Traceback (most recent call last)
Cell In [92], line 2
      1 pruner._unwrap_model()
----> 2 ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:543, in ModelSpeedup.speedup_model(self)
    540 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
    542 _logger.info("infer module masks...")
--> 543 self.infer_modules_masks()
    544 _logger.info('resolve the mask conflict')
    546 # load the original stat dict before replace the model

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:380, in ModelSpeedup.infer_modules_masks(self)
    378 curnode = visit_queue.get()
    379 # forward mask inference for curnode
--> 380 self.update_direct_sparsity(curnode)
    381 successors = self.torch_graph.find_successors(curnode.unique_name)
    382 for successor in successors:

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:216, in ModelSpeedup.update_direct_sparsity(self, node)
    214 _logger.info('Update mask for %s', module_name)
    215 unique_name = node.unique_name
--> 216 dummy_input, input_debugname = self._prepare_dummy_input(node)
    217 # get the input mask from self.masks
    218 # Note: the input mask of the successor nodes are
    219 # already created by the predecessor node
    220 in_masks = [self.masks[debugname] for debugname in input_debugname]

File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:200, in ModelSpeedup._prepare_dummy_input(self, node)
    197         continue
    198     # The detach operation here is for the in-place operation. We cannot
    199     # directly can the backward on the output tensor of an in-place operator.
--> 200     dummy_input.append(self.internal_result[_input].detach())
    202     debugnames.append(_input)
    204 return dummy_input, debugnames

AttributeError: 'tuple' object has no attribute 'detach'
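The failure can be reproduced in isolation: torch.split() returns a tuple of tensors, and _prepare_dummy_input calls .detach() on whatever intermediate result it retrieves, which fails when that result is a tuple. A minimal sketch of the behavior:

```python
import torch

x = torch.rand(1, 32, 8, 8)
parts = torch.split(x, 16, dim=1)   # tuple of two (1, 16, 8, 8) tensors

print(type(parts))                  # <class 'tuple'>

# This mirrors what _prepare_dummy_input attempts, and why it fails:
try:
    parts.detach()
except AttributeError as e:
    print(e)                        # 'tuple' object has no attribute 'detach'

# Detaching per element works, which suggests a fix: handle tuple
# outputs element-wise instead of assuming a single tensor.
detached = tuple(t.detach() for t in parts)
```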

Environment:

  • NNI version: 2.9
  • Training service (local|remote|pai|aml|etc): local
  • Client OS: Windows 10
  • Python version: 3.10.6
  • PyTorch version: 1.12
  • Is conda/virtualenv/venv used?: pipenv
  • Is running in Docker?: No

How to reproduce it?
Use the following code (imports included for completeness):

import torch
import torch.nn as nn
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.pruning import L1NormPruner

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.conv_block = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        x = self.stem(x)
        x0, x1 = torch.split(x, 16, dim=1)
        x1 = self.conv_block(x1)
        x = torch.cat([x0, x1], dim=1)
        return self.classifier(x)

net = Network()
config_list = [{
    'total_sparsity': 0.2,
    'op_types': ['Conv2d']
}]
pruner = L1NormPruner(net, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
# Throws an error
ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()
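As a possible interim workaround (untested against ModelSpeedup; whether the tracer handles slice operations better than split is an assumption on my part), the split in forward() could be rewritten as channel slicing, which yields plain tensors instead of a tuple and is numerically identical:

```python
import torch

x = torch.rand(1, 32, 8, 8)

# Original: torch.split returns a tuple, which trips up _prepare_dummy_input
x0_split, x1_split = torch.split(x, 16, dim=1)

# Workaround candidate: slicing returns ordinary tensors
x0_slice, x1_slice = x[:, :16], x[:, 16:]

# The two forms produce identical values
assert torch.equal(x0_split, x0_slice)
assert torch.equal(x1_split, x1_slice)
```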
