Skip to content

DeformConv2d Cannot be Quantized #2794

Open
@anzr299

Description

@anzr299

🐛 Describe the bug

The operator metatype `DeformConv2dOp` in `nncf/torch/graph/operator_metatypes.py` is defined under the `torch.nn.functional` namespace. However, the `deform_conv2d` function actually lives in `torchvision.ops` (`torchvision.ops.deform_conv2d`). As the output below shows, neither deformable convolution layer in the model is quantized: quantizers are inserted only for the model input, the two offset `Conv2d` layers and the `Linear` layer.

Environment

about-time==4.2.1
absl-py==2.1.0
accelerate==0.28.0
accuracy_checker @ git+https://github.com/openvinotoolkit/open_model_zoo.git@37f60eb7fe1dcdedc552b2fb184d646723ed5e80#subdirectory=tools/accuracy_checker
addict==2.4.0
aiohttp==3.9.5
aiosignal==1.3.1
alive-progress==3.1.5
async-timeout==4.0.3
attrs==23.2.0
autograd==1.6.2
certifi==2024.7.4
cfgv==3.4.0
charset-normalizer==3.3.2
cma==3.2.2
coloredlogs==15.0.1
contourpy==1.2.1
coverage==7.5.4
cycler==0.12.1
datasets==2.14.7
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.7
distlib==0.3.8
efficientnet-pytorch==0.7.1
evaluate==0.3.0
exceptiongroup==1.2.1
execnet==2.1.1
fastcore==1.5.48
fastdownload==0.0.7
fastprogress==1.0.3
filelock==3.15.4
flatbuffers==24.3.25
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2023.10.0
future==1.0.0
grapheme==0.6.0
grpcio==1.64.1
huggingface-hub==0.23.4
humanfriendly==10.0
identify==2.6.0
idna==3.7
iniconfig==2.0.0
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jstyleson==0.0.2
kiwisolver==1.4.5
lightning-utilities==0.11.3.post0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.15
natsort==8.4.0
networkx==3.3
ninja==1.11.1.1
-e git+https://github.com/anzr299/nncf.git@bfc94b9d1078024b04246f8fc41106582c227f7b#egg=nncf
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
onnx==1.16.0
onnxruntime==1.17.1
opencv-python==4.10.0.84
openvino==2024.2.0
openvino-telemetry==2024.1.0
packaging==24.1
pandas==2.2.2
pillow==10.4.0
platformdirs==4.2.2
pluggy==1.5.0
pre-commit==3.2.2
protobuf==4.25.3
psutil==6.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycocotools==2.0.7
pydot==2.0.0
Pygments==2.18.0
pymoo==0.6.1.1
pyparsing==3.1.2
pytest==8.0.2
pytest-cov==4.1.0
pytest-dependency==0.6.0
pytest-mock==3.12.0
pytest-xdist==3.5.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
responses==0.18.0
rich==13.7.1
rpds-py==0.18.1
safetensors==0.4.3
scikit-learn==1.5.1
scipy==1.14.0
six==1.16.0
sympy==1.12.1
tabulate==0.9.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
timm==0.9.2
tokenizers==0.15.2
tomli==2.0.1
torch==2.3.0
torchmetrics==1.0.1
torchvision==0.18.0
tqdm==4.66.4
transformers==4.38.2
triton==2.3.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
virtualenv==20.26.3
Werkzeug==3.0.3
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4

Minimal Reproducible Example


# Sample deformable convolution network built on torchvision.ops.DeformConv2d.
# NOTE: this snippet relies on `nn` (torch.nn) and `DeformConv2d`
# (torchvision.ops) being imported by the surrounding script.
class DCNV2(nn.Module):
    """Two deformable-convolution stages followed by a linear classifier.

    Each stage predicts sampling offsets with a regular ``nn.Conv2d`` and
    passes them to a ``DeformConv2d``, then applies ReLU and 2x2 max pooling.
    The classifier expects a 64 x 8 x 8 feature map, i.e. a 32 x 32 input
    (two stride-2 pools: 32 -> 16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # torchvision expects 2 * kernel_h * kernel_w offset channels for a
        # single offset group: 2 * 3 * 3 = 18.
        self.deform_conv1 = DeformConv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.offset_conv1 = nn.Conv2d(3, 18, kernel_size=3, stride=1, padding=1)

        self.deform_conv2 = DeformConv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.offset_conv2 = nn.Conv2d(32, 18, kernel_size=3, stride=1, padding=1)

        # 64 channels * 8 * 8 spatial positions after two 2x2 pools of 32x32.
        self.fc = nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        """Run both deformable stages and classify.

        ``x`` is expected to be (N, 3, 32, 32) so that the flattened
        features match ``self.fc``.
        """
        offset1 = self.offset_conv1(x)
        x = self.deform_conv1(x, offset1)
        # Use the functional forms instead of instantiating nn.ReLU /
        # nn.MaxPool2d modules on every forward call; the math is identical.
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)

        offset2 = self.offset_conv2(x)
        x = self.deform_conv2(x, offset2)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)

        # flatten(1) keeps the batch dim and, unlike view(N, -1), also
        # works for an empty batch.
        x = x.flatten(1)
        x = self.fc(x)
        return x


model = DCNV2()

# Dummy data for our model: random "images" paired with a constant label.
class RandomDataset(torch.utils.data.Dataset):
    """Synthetic map-style dataset used here to feed NNCF calibration.

    Reports a length of 1000 and, for any index, returns a freshly drawn
    standard-normal tensor of shape (3, 32, 32) together with label 0.
    """

    def __len__(self):
        return 1000

    def __getitem__(self, index):
        # The index is unused: every access samples new random data.
        sample = torch.randn(3, 32, 32)
        target = torch.tensor(0)
        return sample, target


data_loader = torch.utils.data.DataLoader(dataset=RandomDataset(), batch_size=32)

# Transform function for the calibration dataset.
def transform_fn(data_item):
    """Return the model input from one ``(images, labels)`` loader batch.

    Passed to ``nncf.Dataset`` together with ``data_loader``; the labels
    are discarded. A batch that is not exactly a pair raises ValueError.
    """
    model_input, _discarded_labels = data_item
    return model_input

# Switch the model to inference mode, wrap the loader for NNCF, then run
# post-training quantization and show the resulting (wrapped) model.
model.eval()
calibration_dataset = nncf.Dataset(data_loader, transform_fn)
quantized_model = nncf.quantize(model, calibration_dataset)
print(quantized_model)
# Observed output: quantizers appear only for the model input, the two
# offset Conv2d layers and the Linear layer -- none for either DeformConv2d.
'''
OUTPUT: 

DCNV2(
  (deform_conv1): DeformConv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (offset_conv1): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (deform_conv2): DeformConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (offset_conv2): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
  (_nncf): NNCFNetworkInterface(
    (external_quantizers): ModuleDict(
      (/nncf_model_input_0|OUTPUT): SymmetricQuantizer(bit=8, ch=False)
      (DCNV2/Conv2d[offset_conv1]/conv2d_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
      (DCNV2/Conv2d[offset_conv2]/conv2d_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
      (DCNV2/Linear[fc]/linear_0|INPUT1): SymmetricQuantizer(bit=8, ch=True)
    )
  )
)
'''


Are you going to submit a PR?

  • Yes I'd like to help by submitting a PR!

Metadata

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions