This repository was archived by the owner on Sep 18, 2024. It is now read-only.
Pruning: Add support for torch.split() operation #5143
Description
If there is a torch.split() operation anywhere in the network, ModelSpeedup throws an AttributeError:
AttributeError Traceback (most recent call last)
Cell In [92], line 2
1 pruner._unwrap_model()
----> 2 ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()
File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:543, in ModelSpeedup.speedup_model(self)
540 fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
542 _logger.info("infer module masks...")
--> 543 self.infer_modules_masks()
544 _logger.info('resolve the mask conflict')
546 # load the original stat dict before replace the model
File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:380, in ModelSpeedup.infer_modules_masks(self)
378 curnode = visit_queue.get()
379 # forward mask inference for curnode
--> 380 self.update_direct_sparsity(curnode)
381 successors = self.torch_graph.find_successors(curnode.unique_name)
382 for successor in successors:
File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:216, in ModelSpeedup.update_direct_sparsity(self, node)
214 _logger.info('Update mask for %s', module_name)
215 unique_name = node.unique_name
--> 216 dummy_input, input_debugname = self._prepare_dummy_input(node)
217 # get the input mask from self.masks
218 # Note: the input mask of the successor nodes are
219 # already created by the predecessor node
220 in_masks = [self.masks[debugname] for debugname in input_debugname]
File ~\.virtualenvs\...\lib\site-packages\nni\compression\pytorch\speedup\compressor.py:200, in ModelSpeedup._prepare_dummy_input(self, node)
197 continue
198 # The detach operation here is for the in-place operation. We cannot
199 # directly can the backward on the output tensor of an in-place operator.
--> 200 dummy_input.append(self.internal_result[_input].detach())
202 debugnames.append(_input)
204 return dummy_input, debugnames
AttributeError: 'tuple' object has no attribute 'detach'
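The root cause seems to be that torch.split() produces a tuple of tensors, which ModelSpeedup stores as a single entry in internal_result; _prepare_dummy_input then calls .detach() on that tuple as if it were a tensor. A minimal sketch of the kind of tuple-aware handling that would avoid the crash (detach_internal_result is a hypothetical helper, not part of NNI):

import torch

def detach_internal_result(value):
    # Hypothetical helper: detach a plain tensor, or every tensor inside a
    # tuple/list produced by multi-output ops such as torch.split().
    if isinstance(value, (tuple, list)):
        return type(value)(v.detach() for v in value)
    return value.detach()

x = torch.rand(1, 32, 8, 8)
halves = torch.split(x, 16, dim=1)   # returns a tuple of two tensors
print(type(halves))                  # <class 'tuple'>, so calling .detach() on it fails
print([t.shape for t in detach_internal_result(halves)])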
Environment:
- NNI version: 2.9
- Training service (local|remote|pai|aml|etc): local
- Client OS: Windows 10
- Python version: 3.10.6
- PyTorch version: 1.12
- Is conda/virtualenv/venv used?: pipenv
- Is running in Docker?: No
How to reproduce it?:
Use the following code:

import torch
import torch.nn as nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.conv_block = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        x = self.stem(x)
        x0, x1 = torch.split(x, 16, dim=1)
        x1 = self.conv_block(x1)
        x = torch.cat([x0, x1], dim=1)
        return self.classifier(x)

net = Network()
config_list = [{
    'total_sparsity': 0.2,
    'op_types': ['Conv2d']
}]
pruner = L1NormPruner(net, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
# Throws an error
ModelSpeedup(net, torch.rand(1, 3, 32, 32, device='cpu'), masks).speedup_model()
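A possible workaround (an assumption on my part, not an official fix, and not verified against every NNI version): express the split as plain channel slicing so that no tuple-valued node appears in the traced graph.

# Hypothetical variant of the Network above that avoids torch.split()
# by slicing the channel dimension directly.
class NetworkNoSplit(Network):
    def forward(self, x):
        x = self.stem(x)
        x0 = x[:, :16]            # first 16 channels
        x1 = x[:, 16:]            # remaining 16 channels
        x1 = self.conv_block(x1)
        x = torch.cat([x0, x1], dim=1)
        return self.classifier(x)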