From 933e3aca016d901a85c6f1bd0fc6692e2026b038 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 20 Jan 2025 20:05:58 +0900 Subject: [PATCH 1/4] fix: remove legacy conv converter --- .../dynamo/conversion/aten_ops_converters.py | 3 + .../fx/converters/aten_ops_converters.py | 108 +++++++++--------- run_aot_moria_2dunet.py | 43 +++++++ 3 files changed, 100 insertions(+), 54 deletions(-) create mode 100644 run_aot_moria_2dunet.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5254c6a0ac..42f1184fc1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2449,6 +2449,8 @@ def aten_ops_le( def conv_param_validator( conv_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: + + # return True return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -2500,6 +2502,7 @@ def aten_ops_convolution( stride=args[3], padding=args[4], dilation=args[5], + # output_padding=args[7], groups=args[8], ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index f11e40a6db..b46d4cfa40 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -103,60 +103,60 @@ def aten_ops_batch_norm( ) -@tensorrt_converter(torch.ops.aten.convolution.default) -def aten_ops_convolution( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - "weight": args[1], - "bias": args[2], - "stride": args[3], - "padding": args[4], - "dilation": args[5], - "groups": args[8], - } - # we do not handle transposed. - if args[6] is True: - raise RuntimeError(f"Target {target} does not support `transposed=True` ") - # we do not handle output_padding. - if args[7] not in ([0], [0, 0], [0, 0, 0]): - raise RuntimeError(f"Target {target} has non-0 output_padding") - - if len(kwargs_new["stride"]) == 1: - return convolution.convNd( - network, - target, - source_ir=SourceIR.ATEN, - name=name, - is_conv1d=True, - input_val=kwargs_new["input"], - weight=kwargs_new["weight"], - bias=kwargs_new["bias"], - stride=kwargs_new["stride"], - padding=kwargs_new["padding"], - dilation=kwargs_new["dilation"], - groups=kwargs_new["groups"], - ) - else: - return convolution.convNd( - network, - target, - source_ir=SourceIR.ATEN, - name=name, - is_conv1d=False, - input_val=kwargs_new["input"], - weight=kwargs_new["weight"], - bias=kwargs_new["bias"], - stride=kwargs_new["stride"], - padding=kwargs_new["padding"], - dilation=kwargs_new["dilation"], - groups=kwargs_new["groups"], - ) +# @tensorrt_converter(torch.ops.aten.convolution.default) +# def aten_ops_convolution( +# network: TRTNetwork, +# target: Target, +# args: Tuple[Argument, ...], +# kwargs: Dict[str, Argument], +# name: str, +# ) -> Union[TRTTensor, Sequence[TRTTensor]]: +# kwargs_new = { +# "input": args[0], +# "weight": args[1], +# "bias": args[2], +# "stride": args[3], +# "padding": args[4], +# "dilation": args[5], +# "groups": args[8], +# } +# # we do not handle transposed. +# if args[6] is True: +# raise RuntimeError(f"Target {target} does not support `transposed=True` ") +# # we do not handle output_padding. +# if args[7] not in ([0], [0, 0], [0, 0, 0]): +# raise RuntimeError(f"Target {target} has non-0 output_padding") + +# if len(kwargs_new["stride"]) == 1: +# return convolution.convNd( +# network, +# target, +# source_ir=SourceIR.ATEN, +# name=name, +# is_conv1d=True, +# input_val=kwargs_new["input"], +# weight=kwargs_new["weight"], +# bias=kwargs_new["bias"], +# stride=kwargs_new["stride"], +# padding=kwargs_new["padding"], +# dilation=kwargs_new["dilation"], +# groups=kwargs_new["groups"], +# ) +# else: +# return convolution.convNd( +# network, +# target, +# source_ir=SourceIR.ATEN, +# name=name, +# is_conv1d=False, +# input_val=kwargs_new["input"], +# weight=kwargs_new["weight"], +# bias=kwargs_new["bias"], +# stride=kwargs_new["stride"], +# padding=kwargs_new["padding"], +# dilation=kwargs_new["dilation"], +# groups=kwargs_new["groups"], +# ) @tensorrt_converter(torch.ops.aten.div.default) diff --git a/run_aot_moria_2dunet.py b/run_aot_moria_2dunet.py new file mode 100644 index 0000000000..b0b6ea8353 --- /dev/null +++ b/run_aot_moria_2dunet.py @@ -0,0 +1,43 @@ +from monai.networks.nets import UNet +import torch +import torch_tensorrt + +device = "cuda:0" + +# Define the 2D U-Net model +model = UNet( + spatial_dims=2, + in_channels=3, + out_channels=2, + channels=(16, 32, 64, 128), + strides=(2, 2, 2), + num_res_units=2, + act="relu", + norm="batch", + dropout=0.1, +).to(device).half().eval() + +# (batch size, channels, height, width) +input_tensor = torch.randn(1, 3, 256, 256, device=device).half() + +backend = "torch_tensorrt" + +# Compile the model with Torch-TensorRT backend +model = torch.compile( + model, + backend=backend, + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float16}, + "truncate_double": True, + "debug": True, + "min_block_size": 1, + }, + dynamic=False, +) + +# Perform inference with the compiled model +with torch.no_grad(): + output = model(input_tensor) + +print(output) From f48f0403f67c58a66cb8f066fa98d817a146231d Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 20 Jan 2025 20:05:58 +0900 Subject: [PATCH 2/4] chore: remove comments and linting --- .../dynamo/conversion/aten_ops_converters.py | 4 +- .../fx/converters/aten_ops_converters.py | 59 +------------------ run_aot_moria_2dunet.py | 43 -------------- 3 files changed, 4 insertions(+), 102 deletions(-) delete mode 100644 run_aot_moria_2dunet.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 42f1184fc1..5fd2f3ff37 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -7,6 +7,7 @@ import numpy as np import torch from torch.fx.node import Argument, Node, Target + from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl @@ -2449,8 +2450,7 @@ def aten_ops_le( def conv_param_validator( conv_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: - - # return True + return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index b46d4cfa40..795ae7c4d9 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -10,9 +10,10 @@ # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch.fx.immutable_collections import immutable_list from torch.fx.node import Argument, Target + +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.converters.impl import activation, convolution @@ -103,62 +104,6 @@ def aten_ops_batch_norm( ) -# @tensorrt_converter(torch.ops.aten.convolution.default) -# def aten_ops_convolution( -# network: TRTNetwork, -# target: Target, -# args: Tuple[Argument, ...], -# kwargs: Dict[str, Argument], -# name: str, -# ) -> Union[TRTTensor, Sequence[TRTTensor]]: -# kwargs_new = { -# "input": args[0], -# "weight": args[1], -# "bias": args[2], -# "stride": args[3], -# "padding": args[4], -# "dilation": args[5], -# "groups": args[8], -# } -# # we do not handle transposed. -# if args[6] is True: -# raise RuntimeError(f"Target {target} does not support `transposed=True` ") -# # we do not handle output_padding. -# if args[7] not in ([0], [0, 0], [0, 0, 0]): -# raise RuntimeError(f"Target {target} has non-0 output_padding") - -# if len(kwargs_new["stride"]) == 1: -# return convolution.convNd( -# network, -# target, -# source_ir=SourceIR.ATEN, -# name=name, -# is_conv1d=True, -# input_val=kwargs_new["input"], -# weight=kwargs_new["weight"], -# bias=kwargs_new["bias"], -# stride=kwargs_new["stride"], -# padding=kwargs_new["padding"], -# dilation=kwargs_new["dilation"], -# groups=kwargs_new["groups"], -# ) -# else: -# return convolution.convNd( -# network, -# target, -# source_ir=SourceIR.ATEN, -# name=name, -# is_conv1d=False, -# input_val=kwargs_new["input"], -# weight=kwargs_new["weight"], -# bias=kwargs_new["bias"], -# stride=kwargs_new["stride"], -# padding=kwargs_new["padding"], -# dilation=kwargs_new["dilation"], -# groups=kwargs_new["groups"], -# ) - - @tensorrt_converter(torch.ops.aten.div.default) @tensorrt_converter(torch.ops.aten.div.Tensor_mode) @tensorrt_converter(torch.ops.aten.div.Tensor) diff --git a/run_aot_moria_2dunet.py b/run_aot_moria_2dunet.py deleted file mode 100644 index b0b6ea8353..0000000000 --- a/run_aot_moria_2dunet.py +++ /dev/null @@ -1,43 +0,0 @@ -from monai.networks.nets import UNet -import torch -import torch_tensorrt - -device = "cuda:0" - -# Define the 2D U-Net model -model = UNet( - spatial_dims=2, - in_channels=3, - out_channels=2, - channels=(16, 32, 64, 128), - strides=(2, 2, 2), - num_res_units=2, - act="relu", - norm="batch", - dropout=0.1, -).to(device).half().eval() - -# (batch size, channels, height, width) -input_tensor = torch.randn(1, 3, 256, 256, device=device).half() - -backend = "torch_tensorrt" - -# Compile the model with Torch-TensorRT backend -model = torch.compile( - model, - backend=backend, - options={ - "use_python_runtime": False, - "enabled_precisions": {torch.float16}, - "truncate_double": True, - "debug": True, - "min_block_size": 1, - }, - dynamic=False, -) - -# Perform inference with the compiled model -with torch.no_grad(): - output = model(input_tensor) - -print(output) From 151340bb2fea04ccb73bb9fd35e9393ea098a1ee Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 20 Jan 2025 20:05:58 +0900 Subject: [PATCH 3/4] feat: support output_padding argument in deconv converter --- .../dynamo/conversion/aten_ops_converters.py | 10 +-- .../dynamo/conversion/impl/deconv.py | 23 +++++++ .../fx/converters/aten_ops_converters.py | 56 +++++++++++++++++ .../conversion/test_deconvolution_aten.py | 61 +++++++++++++++++-- 4 files changed, 136 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5fd2f3ff37..eac8a4b70c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2447,16 +2447,8 @@ def aten_ops_le( ) -def conv_param_validator( - conv_node: Node, settings: Optional[CompilationSettings] = None -) -> bool: - - return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) - - @dynamo_tensorrt_converter( torch.ops.aten.convolution.default, - capability_validator=conv_param_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( @@ -2502,7 +2494,7 @@ def aten_ops_convolution( stride=args[3], padding=args[4], dilation=args[5], - # output_padding=args[7], + output_padding=args[7], groups=args[8], ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index 03a209e2a5..d19a92e646 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch from torch.fx.node import Target + from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( @@ -105,6 +106,9 @@ def deconvNd( padding = (padding,) if isinstance(padding, int) else padding stride = (stride,) if isinstance(stride, int) else stride dilation = (dilation,) if isinstance(dilation, int) else dilation + output_padding = ( + (output_padding,) if isinstance(output_padding, int) else output_padding + ) # Expand parameters manually for Conv1D computations if is_deconv1d: @@ -113,6 +117,11 @@ def deconvNd( dilation = ( extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation ) + output_padding = ( + (tuple(output_padding) + (0,)) + if output_padding is not None + else output_padding + ) set_layer_name(deconv_layer, target, name, source_ir) @@ -126,6 +135,20 @@ def deconvNd( if groups is not None: deconv_layer.num_groups = groups + ndims = len(padding) + pre_padding_values = [] + post_padding_values = [] + + for dim in range(ndims): + pre_padding = padding[dim] + post_padding = padding[dim] - output_padding[dim] + + pre_padding_values.append(pre_padding) + post_padding_values.append(post_padding) + + deconv_layer.pre_padding = tuple(pre_padding_values) + deconv_layer.post_padding = tuple(post_padding_values) + # Handle quantization cases if scale is not None and zero_point is not None: # Assume the dtype of activation is torch.quint8 diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 795ae7c4d9..a725ce8aa3 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -104,6 +104,62 @@ def aten_ops_batch_norm( ) +@tensorrt_converter(torch.ops.aten.convolution.default) +def aten_ops_convolution( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "weight": args[1], + "bias": args[2], + "stride": args[3], + "padding": args[4], + "dilation": args[5], + "groups": args[8], + } + # we do not handle transposed. + if args[6] is True: + raise RuntimeError(f"Target {target} does not support `transposed=True` ") + # we do not handle output_padding. + if args[7] not in ([0], [0, 0], [0, 0, 0]): + raise RuntimeError(f"Target {target} has non-0 output_padding") + + if len(kwargs_new["stride"]) == 1: + return convolution.convNd( + network, + target, + source_ir=SourceIR.ATEN, + name=name, + is_conv1d=True, + input_val=kwargs_new["input"], + weight=kwargs_new["weight"], + bias=kwargs_new["bias"], + stride=kwargs_new["stride"], + padding=kwargs_new["padding"], + dilation=kwargs_new["dilation"], + groups=kwargs_new["groups"], + ) + else: + return convolution.convNd( + network, + target, + source_ir=SourceIR.ATEN, + name=name, + is_conv1d=False, + input_val=kwargs_new["input"], + weight=kwargs_new["weight"], + bias=kwargs_new["bias"], + stride=kwargs_new["stride"], + padding=kwargs_new["padding"], + dilation=kwargs_new["dilation"], + groups=kwargs_new["groups"], + ) + + @tensorrt_converter(torch.ops.aten.div.default) @tensorrt_converter(torch.ops.aten.div.Tensor_mode) @tensorrt_converter(torch.ops.aten.div.Tensor) diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index d6cbc0579f..046c646871 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -1,6 +1,7 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests + from torch_tensorrt import Input from .harness import DispatchTestCase @@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase): param("non_zero_padding", 1, padding=1), param("dilation", 1, dilation=2), param("groups", 1, groups=3), + param("output_padding_1", 3, stride=2, padding=1, output_padding=1), + param("output_padding_2", 3, stride=2, padding=2, output_padding=1), + param("output_padding_3", 3, stride=2, padding=3, output_padding=1), + param("output_padding_4", 3, stride=3, padding=2, output_padding=1), + param("output_padding_5", 3, stride=3, padding=3, output_padding=1), + param("output_padding_6", 3, stride=3, padding=3, output_padding=2), + param( + "combined_params", + 3, + stride=3, + padding=3, + dilation=2, + groups=3, + output_padding=2, + ), ] ) def test_deconv1d( @@ -26,6 +42,7 @@ def test_deconv1d( dilation=1, groups=1, bias=True, + output_padding=0, ): class TestModule(torch.nn.Module): def __init__(self): @@ -36,9 +53,10 @@ def __init__(self): kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, + output_padding=output_padding, groups=groups, bias=bias, + dilation=dilation, ) def forward(self, x): @@ -101,6 +119,22 @@ def forward(self, x): param("non_zero_padding", 1, padding=1), param("dilation", 1, dilation=2), param("groups", 1, groups=3), + param("output_padding_1", 3, stride=2, padding=1, output_padding=1), + param("output_padding_2", 3, stride=2, padding=1, output_padding=1), + param("output_padding_3", 3, stride=2, padding=2, output_padding=1), + param("output_padding_4", 3, stride=2, padding=3, output_padding=1), + param("output_padding_5", 3, stride=3, padding=2, output_padding=1), + param("output_padding_6", 3, stride=3, padding=3, output_padding=1), + param("output_padding_7", 3, stride=3, padding=3, output_padding=2), + param( + "combined_params", + 3, + stride=3, + padding=3, + dilation=2, + groups=3, + output_padding=2, + ), ] ) def test_deconv2d( @@ -112,6 +146,7 @@ def test_deconv2d( dilation=1, groups=1, bias=True, + output_padding=0, ): class TestModule(torch.nn.Module): def __init__(self): @@ -122,9 +157,10 @@ def __init__(self): kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, + output_padding=output_padding, groups=groups, bias=bias, + dilation=dilation, ) def forward(self, x): @@ -172,6 +208,19 @@ def forward(self, x): param("non_zero_padding", 1, padding=1), param("dilation", 1, dilation=2), param("groups", 1, groups=3), + param("output_padding_1", 3, stride=2, padding=1, output_padding=1), + param("output_padding_2", 3, stride=2, padding=2, output_padding=1), + param("output_padding_3", 3, stride=3, padding=3, output_padding=1), + param("output_padding_4", 3, stride=3, padding=3, output_padding=2), + param( + "combined_params", + 3, + stride=3, + padding=3, + dilation=2, + groups=3, + output_padding=2, + ), ] ) def test_deconv3d( @@ -183,6 +232,7 @@ def test_deconv3d( dilation=1, groups=1, bias=True, + output_padding=0, ): class TestModule(torch.nn.Module): def __init__(self): @@ -193,9 +243,10 @@ def __init__(self): kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, + output_padding=output_padding, groups=groups, bias=bias, + dilation=dilation, ) def forward(self, x): @@ -209,8 +260,8 @@ def forward(self, x): enable_passes=True, ) - # Testing with (-1, -1, -1, -1, -1) results into Error: - # AssertionError: Channel dim can't be dynamic for deconvolution. + # # Testing with (-1, -1, -1, -1, -1) results into Error: + # # AssertionError: Channel dim can't be dynamic for deconvolution. def test_deconv3d_with_dynamic_shape(self): class TestModule(torch.nn.Module): From c399fdde65ad09548f7afd3dfdf00dc523fd4ed4 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 20 Jan 2025 20:05:58 +0900 Subject: [PATCH 4/4] chore: minor lint issue --- tests/py/dynamo/conversion/test_deconvolution_aten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index 046c646871..1909cb8fbb 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -260,8 +260,8 @@ def forward(self, x): enable_passes=True, ) - # # Testing with (-1, -1, -1, -1, -1) results into Error: - # # AssertionError: Channel dim can't be dynamic for deconvolution. + # Testing with (-1, -1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for deconvolution. def test_deconv3d_with_dynamic_shape(self): class TestModule(torch.nn.Module):