Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Runtime Error during L1FilterPruner (and other one-shot pruners)  #3944

Open
@nikhil153

Description

@nikhil153

Describe the issue:

RuntimeError: Has not supported infering output shape from input shape for module/function: `aten::unsqueeze`, .aten::unsqueeze.90

(Complete log below)

Environment:

  • NNI version: Release 2.3 (GitHub master branch)
  • Training service (local|remote|pai|aml|etc): local
  • Client OS: Ubuntu 18.04
  • Server OS (for remote mode only):
  • Python version: 3.7.9
  • PyTorch/TensorFlow version: 1.7.0
  • Is conda/virtualenv/venv used?: conda
  • Is running in Docker?: no

Configuration:
The network architecture is based on this paper (link not preserved in this copy). Here is a figure showing the details:
[image: network architecture figure — not preserved]

Below is my test script, which uses the model definition and pretrained weights from the model repository:

# IMPORTS
import argparse
import nibabel as nib
import numpy as np
from datetime import datetime
import time
import sys
import os
import glob
import os.path as op
import logging
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, utils

from scipy.ndimage.filters import median_filter, gaussian_filter
from skimage.measure import label, regionprops
from skimage.measure import label

from collections import OrderedDict
from os import makedirs

from models.networks import FastSurferCNN

# Compute costs
import pandas as pd

# nni (autoML model compression)
sys.path.append('../../nni')
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, L2FilterPruner, LevelPruner, FPGMPruner
from nni.compression.pytorch import ModelSpeedup


def options_parse():
    """
    Parse the command-line options that describe the network hyper-parameters.

    Every option is an integer. Only change these if the model was trained with
    different settings.

    Returns:
        argparse.Namespace with the attributes num_filters, num_classes_ax_cor,
        num_classes_sag, num_channels, kernel_height, kernel_width, stride,
        stride_pool and pool.
    """
    # (flag, default, help text). Order matches the order options appear in --help.
    option_specs = (
        ('--num_filters', 64,
         'Filter dimensions for DenseNet (all layers same). Default=64'),
        ('--num_classes_ax_cor', 79,
         'Number of classes to predict in axial and coronal net, including background. Default=79'),
        ('--num_classes_sag', 51,
         'Number of classes to predict in sagittal net, including background. Default=51'),
        ('--num_channels', 7,
         'Number of input channels. Default=7 (thick slices)'),
        ('--kernel_height', 5, 'Height of Kernel (Default 5)'),
        ('--kernel_width', 5, 'Width of Kernel (Default 5)'),
        ('--stride', 1, "Stride during convolution (Default 1)"),
        ('--stride_pool', 2, "Stride during pooling (Default 2)"),
        ('--pool', 2, 'Size of pooling filter (Default 2)'),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    return cli.parse_args()

    
### Compress model (prune + speed up from autoML)
def compress_model(model, dummy_data, prune_type, prune_percent, device, prune_save_dir='./pruned_models/'):
    """Prune ``model`` with an NNI pruner, export weights/masks, then apply ModelSpeedup.

    Args:
        model: the PyTorch model to prune.
        dummy_data: example input tensor. It is used by the dependency-aware pruners
            and for tracing in ``ModelSpeedup``.
        prune_type: one of ``'L1'``, ``'L2'``, ``'level'`` or ``'FPGM'``.
        prune_percent: target sparsity applied to every ``Conv2d`` layer.
        device: device passed to ``ModelSpeedup``.
        prune_save_dir: directory where the pruned state_dict and the mask file are
            written. It is created if it does not exist.

    Returns:
        The model after pruning and speedup.

    Raises:
        ValueError: if ``prune_type`` is not one of the supported names.
    """
    print(f'compressing model with prune type: {prune_type}, sparsity: {prune_percent} and saving it here:\n{prune_save_dir}')
    config_list = [{
        'sparsity': prune_percent,
        'op_types': ['Conv2d'],
    }]

    print(f'shape of dummy data: {dummy_data.shape}')
    if prune_type == 'L1':
        # The original line omitted the class name and built a plain tuple instead of a pruner.
        pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_data)
    elif prune_type == 'L2':
        pruner = L2FilterPruner(model, config_list)
    elif prune_type == 'level':
        pruner = LevelPruner(model, config_list)
    elif prune_type == 'FPGM':
        pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_data)
    else:
        # Fail fast: otherwise `pruner` is unbound and the next line raises a confusing error.
        raise ValueError(f'Unknown pruner type: {prune_type}')

    model = pruner.compress()
    pruner.get_pruned_weights()

    # Export the pruned weights and masks; ModelSpeedup below reads the mask file back.
    os.makedirs(prune_save_dir, exist_ok=True)
    model_path = os.path.join(prune_save_dir, f'pruned_{prune_type}_{prune_percent}.pth')
    mask_path = os.path.join(prune_save_dir, f'mask_{prune_type}_{prune_percent}.pth')
    pruner.export_model(model_path=model_path, mask_path=mask_path)

    print('speeding up the model by truly pruning')
    # Unwrap all modules to their normal state before speedup.
    pruner._unwrap_model()
    # NOTE(review): with NNI 2.3, speedup_model() failed on this network with
    # "Has not supported infering output shape ... aten::unsqueeze". Shape inference for
    # that op appears unsupported in that release. TODO confirm whether a newer NNI
    # speedup handles it.
    m_speedup = ModelSpeedup(model, dummy_data, mask_path, device)
    m_speedup.speedup_model()

    return model


def load_pretrained(pretrained_ckpt, params_model, model):
    """Load a checkpoint's ``"model_state_dict"`` into ``model`` and put it in eval mode.

    Key names are adjusted to ``params_model["model_parallel"]``:
    - a leading ``"module."`` is stripped when the model is not parallel;
    - a leading ``"module."`` is added when it is.
    The prefix is presumably the one added by ``nn.DataParallel`` wrappers (TODO confirm).

    Args:
        pretrained_ckpt: path to the checkpoint file.
        params_model: dict providing at least ``"device"`` (used as ``map_location``)
            and ``"model_parallel"``.
        model: module that receives the weights. It is modified in place.

    Returns:
        The same ``model`` object, switched to eval mode.
    """
    checkpoint = torch.load(pretrained_ckpt, map_location=params_model["device"])
    wrapper_prefix = "module."

    remapped = OrderedDict()
    for name, tensor in checkpoint["model_state_dict"].items():
        prefixed = name.startswith(wrapper_prefix)
        if prefixed and not params_model["model_parallel"]:
            remapped[name[len(wrapper_prefix):]] = tensor
        elif not prefixed and params_model["model_parallel"]:
            remapped[wrapper_prefix + name] = tensor
        else:
            remapped[name] = tensor

    model.load_state_dict(remapped)
    model.eval()
    return model


if __name__ == "__main__":

    args = options_parse()

    plane = "Axial"
    pretrained_ckpt = f'../checkpoints/{plane}_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl'

    # Put it onto the GPU or CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Set up model for axial and coronal networks
    params_model = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                    'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                    'stride_conv': args.stride, 'pool': args.pool,
                    'stride_pool': args.stride_pool, 'num_classes': args.num_classes_ax_cor,
                    'kernel_c': 1, 'kernel_d': 1,
                    'model_parallel': False,
                    'device': device,
                    }

    # Select the model
    model = FastSurferCNN(params_model)
    model.to(device)

    # Load pretrained weights
    model = load_pretrained(pretrained_ckpt, params_model, model)

    # Prune model.
    # The dummy input must be on the same device as the model, otherwise tracing fails
    # on GPU. Its channel count must match the network's configured input channels.
    # The 256x256 spatial size is kept from the original script.
    dummy_data = torch.ones(1, args.num_channels, 256, 256, device=device)
    prune_type = 'L1'
    prune_percent = 0.5

    prune_save_dir = './pruned_models/'

    model = compress_model(model, dummy_data, prune_type, prune_percent, device, prune_save_dir)

Log message:

compressing model with prune type: L1, sparsity: 0.5 and saving it here:
./pruned_models/
shape of dummy data: torch.Size([1, 7, 256, 256])
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: encode1.conv1,encode1.conv0
../../nni/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py:234: UserWarning: This overload of nonzero is deprecated:
        nonzero()
Consider using one of the following signatures instead:
        nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729006826/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  channel_masks == False).nonzero().squeeze(1).tolist()
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 4,5,8,9,13,14,16,17,19,23,24,26,28,30,31,32,35,36,37,38,39,40,42,43,46,48,49,50,51,53,56,58 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: encode1.conv2,decode1.conv1,decode1.conv0,encode2.conv1,decode2.conv2,encode2.conv0
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 0,2,6,10,14,17,19,20,21,22,27,28,30,32,35,37,39,40,43,44,45,46,47,48,50,52,55,56,58,59,60,61 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: decode3.conv2,encode3.conv0,decode2.conv0,encode3.conv1,encode2.conv2,decode2.conv1
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 0,4,5,7,10,12,17,18,19,22,23,24,25,26,27,29,34,38,39,40,42,46,48,49,50,52,53,55,56,58,61,62 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: decode3.conv1,encode3.conv2,encode4.conv1,encode4.conv0,decode3.conv0,decode4.conv2
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 0,1,6,8,10,14,15,16,17,18,21,22,24,28,35,37,38,39,41,42,43,45,46,50,54,55,56,57,60,61,62,63 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: bottleneck.conv2,decode4.conv1,decode4.conv0,encode4.conv2,bottleneck.conv1,bottleneck.conv0
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 0,1,3,4,5,7,11,12,13,14,15,18,19,22,25,26,27,29,34,35,40,41,45,46,47,50,52,53,55,58,61,62 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: decode1.conv2
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 1,2,4,6,7,8,10,12,13,15,18,20,21,22,23,30,32,36,39,40,42,44,45,48,49,50,52,55,57,60,62,63 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner/MainThread) Pruning the dependent layers: classifier.conv
[2021-07-15 17:38:58] INFO (torch filter pruners/MainThread) Prune the 6,8,9,11,13,14,25,26,27,28,29,34,35,37,41,43,47,48,51,53,54,55,56,57,58,59,60,65,67,68,69,70,71,72,74,75,76,77,78 channels for all dependent
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode1.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode1.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode1.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode2.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode2.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode2.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode3.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode3.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode3.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode4.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode4.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune encode4.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune bottleneck.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune bottleneck.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune bottleneck.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode4.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode4.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode4.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode3.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode3.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode3.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode2.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode2.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode2.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode1.conv0 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode1.conv1 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune decode1.conv2 remain/total: 32/64
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) simulated prune classifier.conv remain/total: 40/79
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to ./pruned_models/pruned_L1_0.5.pth
[2021-07-15 17:38:58] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to ./pruned_models/mask_L1_0.5.pth
speeding up the model by truly pruning
[2021-07-15 17:39:11] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model
[2021-07-15 17:39:11] INFO (nni.compression.pytorch.speedup.compressor/MainThread) fix the mask conflict of the interdependent layers
[2021-07-15 17:39:24] INFO (nni.compression.pytorch.utils.mask_conflict/MainThread) {'encode1.conv0': 1, 'encode1.conv1': 1, 'encode1.conv2': 1, 'encode2.conv0': 1, 'encode2.conv1': 1, 'encode2.conv2': 1, 'encode3.conv0': 1, 'encode3.conv1': 1, 'encode3.conv2': 1, 'encode4.conv0': 1, 'encode4.conv1': 1, 'encode4.conv2': 1, 'bottleneck.conv0': 1, 'bottleneck.conv1': 1, 'bottleneck.conv2': 1, 'decode4.conv0': 1, 'decode4.conv1': 1, 'decode4.conv2': 1, 'decode3.conv0': 1, 'decode3.conv1': 1, 'decode3.conv2': 1, 'decode2.conv0': 1, 'decode2.conv1': 1, 'decode2.conv2': 1, 'decode1.conv0': 1, 'decode1.conv1': 1, 'decode1.conv2': 1, 'classifier.conv': 1}
[2021-07-15 17:39:24] INFO (nni.compression.pytorch.utils.mask_conflict/MainThread) dim0 sparsity: 0.499723
[2021-07-15 17:39:24] INFO (nni.compression.pytorch.utils.mask_conflict/MainThread) dim1 sparsity: 0.000000
[2021-07-15 17:39:24] INFO (nni.compression.pytorch.utils.mask_conflict/MainThread) detected conv prune dim: 0
[2021-07-15 17:39:24] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...
Traceback (most recent call last):
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/prune_model.py", line 161, in <module>
    model = compress_model(model, dummy_data, prune_type, prune_percent, device, prune_save_dir)
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/prune_model.py", line 97, in compress_model
    m_speedup.speedup_model()
  File "../../nni/nni/compression/pytorch/speedup/compressor.py", line 183, in speedup_model
    self.infer_modules_masks()
  File "../../nni/nni/compression/pytorch/speedup/compressor.py", line 140, in infer_modules_masks
    self.infer_module_mask(module_name, None, mask=mask)
  File "../../nni/nni/compression/pytorch/speedup/compressor.py", line 124, in infer_module_mask
    self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
  File "../../nni/nni/compression/pytorch/speedup/compressor.py", line 124, in infer_module_mask
    self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
  File "../../nni/nni/compression/pytorch/speedup/compressor.py", line 92, in infer_module_mask
    .format(m_type, module_name))
RuntimeError: Has not supported infering output shape from input shape for module/function: `aten::unsqueeze`, .aten::unsqueeze.90

I would appreciate any help or suggestions. Thanks!

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions