This repository was archived by the owner on Sep 18, 2024. It is now read-only.
Has not supported replacing the module: InstanceNorm2d
#4387
Open
Description
Describe the issue:
Below, SimpleModel is a simple model that uses nn.InstanceNorm2d. With use_inorm=False it can be pruned normally, but with use_inorm=True the following error occurs:
Has not supported replacing the module: InstanceNorm2d
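For quicker triage, here is a minimal sketch that isolates the failure (my own untested reduction of the full script under Example below; the file names m.pth / mask.pth are placeholders):

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.pruning import FPGMPruner
from nni.compression.pytorch.speedup import ModelSpeedup

# A pruned Conv2d feeding an InstanceNorm2d is enough to trigger the error.
net = nn.Sequential(nn.Conv2d(3, 32, 3), nn.InstanceNorm2d(32, affine=False), nn.ReLU())
dummy_input = torch.randn(1, 3, 128, 128)
pruner = FPGMPruner(net, [{'sparsity': 0.2, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='m.pth', mask_path='mask.pth')  # placeholder paths
pruner._unwrap_model()
# Raises: Has not supported replacing the module: InstanceNorm2d
ModelSpeedup(net, dummy_input=dummy_input, masks_file='mask.pth').speedup_model()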
Example:
# -*- coding: utf-8 -*-
import os
import copy
import torch
import torch.nn as nn
import torch.onnx
import torch.nn.functional as F
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.pytorch import pruning
from nni.compression.pytorch import apply_compression_results


def model_pruning(model: nn.Module,
                  input_size=[1, 3, 128, 128],
                  sparsity=0.2,
                  prune_mod="FPGM",
                  output_prune="pruning_output",
                  mask_file="",
                  dependency_aware=True,
                  device="cpu",
                  verbose=False,
                  **kwargs):
    info = ""
    model = model.to(device)
    if not os.path.exists(output_prune):
        os.makedirs(output_prune)
    prune_file = os.path.join(output_prune, 'pruned_naive_{}filter.pth'.format(prune_mod))
    onnx_file = os.path.join(output_prune, 'pruned_naive_{}filter.onnx'.format(prune_mod))
    mask_file = os.path.join(output_prune, 'mask_naive_{}filter.pth'.format(prune_mod)) if not mask_file else mask_file
    dummy_input = torch.randn(input_size).to(device)
    # FLOPs and parameter count of the original model
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"origin-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    # Prune the model; this generates a mask file (e.g. mask_naive_l1filter.pth).
    if prune_mod.lower() == "level":
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.LevelPruner(model, config)
    elif prune_mod.lower() == "l1":
        # op_types: only Conv2d is supported in L1FilterPruner.
        # config = [{'sparsity': sparsity, 'op_types': ['Conv2d'], "exclude": False}]
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        # dependency_aware is passed by keyword: the third positional
        # parameter of these pruners is `optimizer`, not `dependency_aware`.
        pruner = pruning.L1FilterPruner(model, config, dependency_aware=dependency_aware,
                                        dummy_input=dummy_input)
    elif prune_mod.lower() == "l2":
        # op_types: only Conv2d is supported in L2FilterPruner.
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.L2FilterPruner(model, config, dependency_aware=dependency_aware,
                                        dummy_input=dummy_input)
    elif prune_mod.lower() == "fpgm":
        # op_types: only Conv2d is supported in FPGMPruner.
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.FPGMPruner(model, config, dependency_aware=dependency_aware,
                                    dummy_input=dummy_input)
    elif prune_mod.lower() == "slim":
        # SlimPruner prunes channels via BatchNorm2d scaling factors.
        config = [{'sparsity': sparsity, 'op_types': ['BatchNorm2d']}]
        pruner = pruning.SlimPruner(model, config)
    else:
        raise Exception("Error prune_mod:{}".format(prune_mod))
    # Compress the model; the masks will be updated.
    pruner.compress()
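    # NOTE: compress() only computes the masks and wraps the target layers;
    # masked weights are zeroed in place, not removed, so the FLOPs/params
    # counted below are unchanged. The structural shrinking happens later,
    # in ModelSpeedup.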
    # pruner.get_pruned_weights()
    # Run a dummy input to apply the sparsification.
    out = model(dummy_input)
    # FLOPs and parameter count of the masked (pruned) model
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"pruner-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    # Export the sparsified model and its masks.
    pruner.export_model(model_path=prune_file, mask_path=mask_file,
                        onnx_path=onnx_file, input_shape=dummy_input.shape,
                        device=device, opset_version=11)
    # Speed up the model with the exported weight masks.
    # If you use a wrapped model, don't forget to unwrap it.
    pruner._unwrap_model()
    # Applying the masks makes the model smaller and reduces inference latency.
    # apply_compression_results(model, mask_file, device)
    if not os.path.exists(mask_file):
        raise Exception("not found mask file:{}".format(mask_file))
    print("load mask file to speed up:{}".format(mask_file))
    speed_up = ModelSpeedup(model, dummy_input=dummy_input, masks_file=mask_file)
    speed_up.speedup_model()
    out = model(dummy_input)
    # FLOPs and parameter count after speedup
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"speedup-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    print(info)
    # Finetune the model to recover accuracy.
    return model
class SimpleModel(nn.Module):
    def __init__(self, num_classes, use_inorm=True):
        super(SimpleModel, self).__init__()
        self.use_inorm = use_inorm
        self.conv1 = nn.Conv2d(3, 32, 3)
        if self.use_inorm:
            self.inorm1 = nn.InstanceNorm2d(32, affine=False)
        self.conv2 = nn.Conv2d(32, 64, 3)
        if self.use_inorm:
            self.inorm2 = nn.InstanceNorm2d(64, affine=False)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc = nn.Linear(128, 256)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        if self.use_inorm:
            x = self.inorm1(x)
        x = F.relu(self.conv2(x))
        if self.use_inorm:
            x = self.inorm2(x)
        x = F.relu(self.conv3(x))
        x = F.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x
if __name__ == "__main__":
    device = "cuda:0"
    num_classes = 20
    input_size = [1, 3, 128, 128]
    # With use_inorm=False pruning works normally, but with use_inorm=True:
    # Has not supported replacing the module: `InstanceNorm2d`
    # model = SimpleModel(num_classes=num_classes, use_inorm=False)
    model = SimpleModel(num_classes=num_classes, use_inorm=True)
    model.eval()
    inputs = torch.randn(input_size)
    model = model.to(device)
    inputs = inputs.to(device)
    output = model(inputs)
    prune_model = copy.deepcopy(model)
    prune_model = model_pruning(prune_model, input_size=input_size, sparsity=0.2,
                                dependency_aware=True, device=device)
    print("inputs:", inputs.shape)
    print("output:", output.shape)