🐛 [Bug] torch_tensorrt.compile does not work with nn.ConvTranspose2d and output_padding #3352
Bug Description
Trying to use torch_tensorrt.compile to compile a model that uses nn.ConvTranspose2d with output_padding=1 raises the following error:
RuntimeError: Target aten.convolution.default does not support `transposed=True`
While executing %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %deconv_weight, %deconv_bias, [2, 2], [0, 0], [1, 1], True, [1, 1], 1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x72fa0452d6b0>: ((4, 256, 296, 296), torch.float32, False, (22429696, 87616, 296, 1), torch.contiguous_format, False, {})}})
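For context, a minimal sketch of how the layer ends up as this node (my own check, not part of the original report; the decomposition step is an assumption about the IR the Dynamo converter sees):

```python
import torch
import torch.nn as nn

# Hypothetical standalone check: export a bare ConvTranspose2d and decompose to
# core ATen, which is roughly the representation the converter operates on.
deconv = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, output_padding=1).eval()
ep = torch.export.export(deconv, (torch.randn(1, 256, 8, 8),)).run_decompositions()
# The printed graph is expected to contain a call to torch.ops.aten.convolution.default
# with transposed=True and output_padding=[1, 1] -- the node rejected above.
print(ep.graph_module.code)
```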
Note: using the default value output_padding=0 works fine. I have not tried other values.
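A possible workaround for this specific configuration, as an untested sketch of my own (not from the original report): because padding=0 here, the extra bottom row and right column that output_padding=1 adds receive no contribution from the input and only contain the layer bias, so the layer can be emulated with output_padding=0 plus an explicit pad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModelWorkaround(nn.Module):
    """Hypothetical rewrite of ToyModel that avoids output_padding.

    Only valid because padding=0 (and bias=True) in this model: the row/column
    that output_padding=1 would add is outside the receptive field of the input,
    so it holds exactly the bias.
    """

    def __init__(self) -> None:
        super().__init__()
        # Same layer, but with the default output_padding=0, which converts fine.
        self.deconv = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.deconv(x)
        b = self.deconv.bias.view(1, -1, 1, 1)
        # Remove the bias, append one zero row/column at the bottom/right, then add
        # the bias back, so the appended positions end up holding exactly the bias,
        # matching what output_padding=1 would produce for this configuration.
        return F.pad(y - b, (0, 1, 0, 1)) + b
```

Whether this variant then converts cleanly with torch_tensorrt.compile is untested; it only sidesteps the output_padding argument.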
To Reproduce
The following code reproduces the error (a possible fallback via op exclusion is sketched after the script):
import torch
import torch.nn as nn
import torch_tensorrt
class ToyModel(nn.Module):
    def __init__(self) -> None:
        super(ToyModel, self).__init__()
        self.deconv = nn.ConvTranspose2d(
            256,
            128,
            kernel_size=3,
            output_padding=1,
            stride=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.deconv(x)
        return x


def main():
    # Create a toy model instance
    model = ToyModel().eval().cuda()
    # Create dummy input
    input_tensor = torch.randn(4, 256, 296, 296).cuda()
    # Compile the model with torch_tensorrt
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[input_tensor],
        enabled_precisions={torch.float32},
        min_block_size=1,
    )


if __name__ == "__main__":
    main()
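Until the converter supports this case, one possible interim fallback (an assumption on my side, not verified against this exact version) is to exclude the op from TensorRT conversion via torch_executed_ops, so it runs in PyTorch:

```python
# Hypothetical fallback: keep aten.convolution in PyTorch instead of converting it.
# In this toy model that leaves nothing for TensorRT, but in a larger model the
# remaining ops would still be compiled.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[input_tensor],
    enabled_precisions={torch.float32},
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.convolution.default"},
)
```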
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0
- PyTorch Version (e.g. 1.0): 2.5.1+cu124
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 22.04
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: No
- Python version: 3.11.11
- CUDA version: 12.6
- GPU models and configuration: NVIDIA GeForce RTX 4090
- Any other relevant information: