✨[Feature] Support QAT export in dynamo  #2483

Closed
@peri044

Description

Is your feature request related to a problem? Please describe.
I tried the nvidia-ammo 0.5 release.
Running torch.compile or torch.export on a model that contains QAT nodes fails with the following error:

assert isinstance(
AssertionError: expected FunctionType found method <bound method QuantLinearConvBase.quantized_forward of <class 'ammo.torch.quantization.nn.modules.quant_conv.QuantConv2d'>>
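
For readers unfamiliar with the distinction the assertion draws, here is a minimal sketch in plain Python, independent of ammo and dynamo. The class `QuantLike` and its methods are hypothetical stand-ins, and using `classmethod` is only an assumption made to reproduce the shape of the repr in the message ("bound method ... of <class ...>"); this issue does not show how ammo actually defines `quantized_forward`.

```python
import types


class QuantLike:
    """Hypothetical stand-in for a quantized module; not the ammo class."""

    def plain(self, x):
        return x

    @classmethod
    def bound_to_class(cls, x):
        return x


# Accessed on the class, an ordinary method is a plain function object.
plain_attr = QuantLike.plain
# Accessed on the class, a classmethod is a method bound to the class itself.
class_attr = QuantLike.bound_to_class

plain_is_function = isinstance(plain_attr, types.FunctionType)
bound_is_function = isinstance(class_attr, types.FunctionType)
bound_is_method = isinstance(class_attr, types.MethodType)

# The underlying plain function is still reachable through __func__.
underlying = class_attr.__func__

print(plain_is_function, bound_is_function, bound_is_method)
print(repr(class_attr))
```

A check of the form `isinstance(obj, types.FunctionType)` accepts `plain_attr` and rejects `class_attr`, which is the kind of mismatch the traceback reports. Whether unwrapping via `__func__` would be an appropriate fix inside the tracer is not established here.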

**To reproduce**

import torch
import ammo.torch.quantization as atq
from vgg16 import vgg16
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch_tensorrt

# Select quantization config
config = atq.INT8_DEFAULT_CFG

model = vgg16(num_classes=10, init_weights=False).eval().cuda()
calib_set = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

# Define forward loop for calibration
def forward_loop():
    for data in calib_set:
        model(data[0].unsqueeze(0).cuda())

# QAT with in-place replacement to quantized modules
atq.quantize(model, config, forward_loop)

x = torch.randn((1, 3, 32, 32), dtype=torch.float32).cuda()
# exp_program = torch.export.export(model, (x,))

compile_spec = {
    "inputs": [x],
    "debug": True,
    "min_block_size": 1,
}
# Torch compile path
trt_model = torch.compile(model, backend="tensorrt", dynamic=None, options=compile_spec)
# torch.compile is lazy: tracing only happens on the first call
trt_model(x)

Labels

bug: tracer (Bugs that arise from torch.export or torch.compile tracer)
feature request (New feature or request)
