diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 68bbcc31d0..a5b3ee845c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import numpy as np import tensorrt as trt @@ -11,11 +11,102 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, get_positive_dim, - get_trt_tensor, set_layer_name, ) +def unify_and_concat_trt_tensors( + ctx: ConversionContext, + target: Target, + name: str, + inputs: Sequence[Union[int, np.ndarray, torch.Tensor, TRTTensor]], + concat_axis: int, + cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None, + force_trt_output: bool = False, +) -> Union[TRTTensor, List[int]]: + """ + Normalize all inputs to TRT tensors if needed, optionally cast, and concat if any dynamic. + + Args: + ctx: TensorRT conversion context. + target: Operation Target. + name: Operation Name. + inputs: Sequence of ints / numpy arrays / torch tensors / TRT tensors. + concat_axis: Axis along which to concatenate tensors if dynamic. + cast_dtype: Optional target dtype for casting TRT tensors. + force_trt_output: If True, return TRT tensor even if all inputs are static ints. (True for concat operations) + """ + has_dynamic = any(not isinstance(x, int) for x in inputs) + trt_tensors = [] + + for i, x in enumerate(inputs): + # convert to TRTTensor + if isinstance(x, TRTTensor): + t = x + elif isinstance(x, int) and not has_dynamic and not force_trt_output: + t = x # pure static path + else: + const_arr = np.array([x], dtype=np.int32) + shape = (1,) + if not isinstance(x, int): + const_arr = np.array(x, dtype=np.int32) + shape = (x.numel(),) + + layer = ctx.net.add_constant(shape, const_arr) + set_layer_name(layer, target, f"{name}_dim{i}_const") + t = layer.get_output(0) + trt_tensors.append(t) + + if not has_dynamic and not force_trt_output: + return trt_tensors # all ints + + final_dtype = None + if cast_dtype: + # Explicit cast requested + if isinstance(cast_dtype, _enums.dtype): + final_dtype = cast_dtype.to(trt.DataType) + elif isinstance(cast_dtype, (np.dtype, trt.dtype)): + final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType) + else: + final_dtype = cast_dtype # already trt.DataType + else: + # Automatic promotion + promoted_type = None + for t in trt_tensors: + if isinstance(t, TRTTensor): + if promoted_type is None: + promoted_type = t.dtype + else: + promoted_type = _enums.dtype._from( + torch.promote_types( + _enums.dtype._from(promoted_type).to(torch.dtype), + _enums.dtype._from(t.dtype).to(torch.dtype), + ) + ).to(trt.DataType) + final_dtype = promoted_type + + # promote remaining ints to TRT consts before concat + for i, t in enumerate(trt_tensors): + if isinstance(t, int): + const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32)) + set_layer_name(const, target, f"{name}_static_{i}_const") + trt_tensors[i] = const.get_output(0) + + # final cast + if final_dtype is not None: + casted = [] + for i, t in enumerate(trt_tensors): + if isinstance(t, TRTTensor): + t = cast_trt_tensor(ctx, t, final_dtype, f"{name}_cast_{i}") + casted.append(t) + trt_tensors = casted + + concat = ctx.net.add_concatenation(trt_tensors) + concat.axis = concat_axis + set_layer_name(concat, target, f"{name}_concat") + return concat.get_output(0) + + def cat( ctx: ConversionContext, target: Target, @@ -25,38 +116,17 @@ def cat( dim: int, cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - trt_inputs = [] - for i, each_input in enumerate(input): - if not isinstance(each_input, TRTTensor): - each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}") - if cast_dtype: - each_input = cast_trt_tensor( - ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}" - ) - trt_inputs.append(each_input) - - if len(trt_inputs) > 1: - # Cast to promoted type for all inputs - promoted_type = trt_inputs[0].dtype - for each_input in trt_inputs[1:]: - promoted_type = _enums.dtype._from( - torch.promote_types( - _enums.dtype._from(promoted_type).to(torch.dtype), - _enums.dtype._from(each_input.dtype).to(torch.dtype), - ) - ) - trt_promoted_type = promoted_type.to(trt.DataType) - - trt_casted_inputs = [] - for i, each_input in enumerate(trt_inputs): - casted_input = cast_trt_tensor( - ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}" - ) - trt_casted_inputs.append(casted_input) - trt_inputs = trt_casted_inputs - - concat_layer = ctx.net.add_concatenation(trt_inputs) - dim = get_positive_dim(dim, len(trt_inputs[0].shape)) - concat_layer.axis = dim - set_layer_name(concat_layer, target, f"{name}_gather", source_ir) - return concat_layer.get_output(0) + # int is only when cat called in other ops like pad + if not isinstance(input[0], int): + dim = get_positive_dim(dim, len(input[0].shape)) + else: + dim = 0 + return unify_and_concat_trt_tensors( + ctx, + target, + name, + input, + concat_axis=dim, + cast_dtype=cast_dtype, + force_trt_output=True, + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py index 4b47ca5dec..ac54e18f3a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py @@ -9,7 +9,12 @@ has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.dynamo.conversion.impl.cat import ( + unify_and_concat_trt_tensors as unify_trt_shape_tensors, +) +from torch_tensorrt.dynamo.conversion.impl.shape import ( + get_shape_with_dynamic_shape, +) def upsample( @@ -28,14 +33,22 @@ def upsample( if scale_factor is not None: layer.scales = [1.0, 1.0] + list(scale_factor) else: - shape = list(input.shape)[:2] + list(size) + shape = list(input.shape)[:2] + if size is not None: + shape += list(size) if has_dynamic_shape(shape): shape = get_shape_with_dynamic_shape( ctx, target, source_ir, name, shape, input ) layer.set_input(1, shape) else: - layer.shape = shape + trt_shape = unify_trt_shape_tensors( + ctx, target, name, shape, concat_axis=0, force_trt_output=False + ) + if isinstance(trt_shape, list): + layer.shape = trt_shape + else: + layer.set_input(1, trt_shape) if mode == "nearest": layer.resize_mode = trt.InterpolationMode.NEAREST diff --git a/tests/py/dynamo/conversion/test_upsample_aten.py b/tests/py/dynamo/conversion/test_upsample_aten.py index 44c4af2a92..6646cfa63e 100644 --- a/tests/py/dynamo/conversion/test_upsample_aten.py +++ b/tests/py/dynamo/conversion/test_upsample_aten.py @@ -296,6 +296,50 @@ def forward(self, x): ] self.run_test_with_dynamic_shape(TestModule(), input_specs) + @parameterized.expand( + [ + ([torch.tensor(3), 3], None), + (None, [torch.tensor(0.5), 1.5]), + ] + ) + def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors): + class TestModule(torch.nn.Module): + def forward(self, x): + out_size = output_size + scale = scale_factors + + return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale) + + input_specs = [ + Input( + min_shape=(1, 1, 1, 1), + opt_shape=(5, 5, 5, 5), + max_shape=(9, 9, 9, 9), + dtype=torch.float32, + ) + ] + self.run_test_with_dynamic_shape(TestModule(), input_specs) + + @parameterized.expand( + [ + # Mix of Tensor and int in output_size + ([torch.tensor(3), 3], None), + # Mix of Tensor and float in scale_factors + (None, [torch.tensor(0.5), 1.5]), + ] + ) + def test_nearest2d_mixed_static_input(self, output_size, scale_factors): + class TestModule(torch.nn.Module): + def forward(self, x): + out_size = output_size + scale = scale_factors + return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale) + + input_size = [7, 7] # H, W + inputs = [torch.randn([1, 1] + input_size)] # shape [1, 1, 7, 7] + + self.run_test(TestModule(), inputs) + if __name__ == "__main__": run_tests()