From 840de6bab81a5917a2aa7efe7a25585c3237bf4a Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 10 Mar 2023 10:39:14 -0800 Subject: [PATCH 01/39] fx2trt converters - reorg of the existing converters, addition of new converters to aten and nn --- .../fx/converters/acc_ops_converters.py | 805 +--------------- py/torch_tensorrt/fx/converters/activation.py | 168 +++- .../fx/converters/aten_ops_converters.py | 200 +++- .../fx/converters/converter_utils.py | 198 ---- .../fx/converters/nn_ops_converters.py | 31 + py/torch_tensorrt/fx/converters/operator.py | 902 ++++++++++++++++++ .../aten_op/test_leaky_relu_aten.py | 53 + 7 files changed, 1350 insertions(+), 1007 deletions(-) create mode 100644 py/torch_tensorrt/fx/converters/nn_ops_converters.py create mode 100644 py/torch_tensorrt/fx/converters/operator.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 77a9b92dfe..689a52743e 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -26,6 +26,8 @@ trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous +import activation +import operator _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -77,7 +79,7 @@ def trt_transposed_linear_converter(network, target, args, kwargs, name): trt.MatrixOperation.NONE, ) set_layer_name(layer, target, f"{name}_mm") - return add_binary_elementwise_layer( + return operator.add_binary_elementwise_layer( network, layer.get_output(0), bias, @@ -676,160 +678,8 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) def acc_ops_layer_norm(network, target, args, kwargs, name): - input_val = kwargs["input"] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"LayerNorm received input {input_val} that is not part " - "of the 
TensorRT region!" - ) - - gamma = kwargs["weight"].detach().cpu().float().numpy() - gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) - beta = kwargs["bias"].detach().cpu().float().numpy() - beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) - eps_field = trt.PluginField( - "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 - ) - try: - normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) - except TypeError: - _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") - normalized_shape = np.array([], dtype=np.int32) - - normalized_shape_filed = trt.PluginField( - "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 - ) - field_collection = trt.PluginFieldCollection( - [gamma_field, beta_field, eps_field, normalized_shape_filed] - ) - - try: - if network.has_implicit_batch_dimension: - plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") - else: - plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") - except AssertionError: - _LOGGER.error( - "Unable to find layer norm plugin, fall back to TensorRT implementation." - ) - return layer_norm(network, target, args, kwargs, name) - layer = network.add_plugin_v2([input_val], plugin) - layer.name = name - return layer.get_output(0) - - -def layer_norm( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"LayerNorm received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - shape = kwargs["weight"].shape # type: ignore[union-attr] - broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape - gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr] - beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr] - eps = kwargs["eps"] - - axes = 0 - for d in range(len(shape)): - axes |= 1 << (len(input_val.shape) - d - 1) - - # E[x] - mean_expected_layer = network.add_reduce( - input_val, trt.ReduceOperation.AVG, axes, keep_dims=True - ) - set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") - - # X-E[x] - sub_trt = add_binary_elementwise_layer( - network, - input_val, - mean_expected_layer.get_output(0), - trt.ElementWiseOperation.SUB, - target, - f"{name}_sub", - ) - # Variance = mean(pow(x_sub_mean,2)) - pow_tensor = network.add_constant( - (1,) * len(input_val.shape), - trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), - ) - pow_tensor.name = f"{name}_power" - pow_var = add_binary_elementwise_layer( - network, - sub_trt, - pow_tensor.get_output(0), - trt.ElementWiseOperation.POW, - target, - f"{name}_pow_var", - ) - mean_trt_layer = network.add_reduce( - pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True - ) - set_layer_name(mean_trt_layer, target, f"{name}_mean") - # Variance + eps - eps_tensor = network.add_constant( - (1,) * len(input_val.shape), - trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), - ) - eps_tensor.name = f"{name}_eps" - add_trt = add_binary_elementwise_layer( - network, - mean_trt_layer.get_output(0), - eps_tensor.get_output(0), - trt.ElementWiseOperation.SUM, - target, - f"{name}_add", - ) - # SQRT((Var + eps)) - sqrt_trt = add_unary_layer( - network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" - ) - # (x - E[x]) / sqrt((var + eps)) - div_trt = add_binary_elementwise_layer( - network, - sub_trt, - sqrt_trt, - trt.ElementWiseOperation.DIV, - target, - f"{name}_div_trt", - ) - - assert gamma is not None - 
gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] - gamma_tensor.name = f"{name}_gamma" - assert beta is not None - beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] - beta_tensor.name = f"{name}_beta" - # y * gamma + beta - scale_layer = add_binary_elementwise_layer( - network, - div_trt, - gamma_tensor.get_output(0), - trt.ElementWiseOperation.PROD, - target, - f"{name}_scale", - ) - return add_binary_elementwise_layer( - network, - scale_layer, - beta_tensor.get_output(0), - trt.ElementWiseOperation.SUM, - target, - name, - ) - - + return operator.add_layer_norm(network, target, kwargs, name) + @tensorrt_converter(acc_ops.softmax) def acc_ops_softmax( network: TRTNetwork, @@ -879,105 +729,8 @@ def acc_ops_tile( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_t = kwargs["input"] - input_val = get_trt_tensor(network, input_t, f"{name}_input") - - dims = tuple(cast(Sequence[int], kwargs["dims"])) - n_input_dims = len(input_val.shape) + ( - 1 if network.has_implicit_batch_dimension else 0 - ) - - if len(dims) > n_input_dims: - assert not network.has_implicit_batch_dimension - layer = network.add_shuffle(input_val) - layer.name = f"{name}_reshape" - num_preceding_ones = len(dims) - n_input_dims - - if len(get_dynamic_dims(input_val.shape)) > 1: - input_shape_layer = network.add_shape(input_val) - input_shape_layer.name = f"{name}_input_shape" - preceding_ones = network.add_constant( - (num_preceding_ones,), - np.ascontiguousarray([1] * num_preceding_ones, np.int32), - ).get_output(0) - reshape_layer = network.add_concatenation( - [preceding_ones, input_shape_layer.get_output(0)] - ) - reshape_layer.axis = 0 - reshape_layer.name = f"{name}_reshape_dims" - layer.set_input(1, reshape_layer.get_output(0)) - else: - layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple( - 
input_val.shape - ) - input_val = layer.get_output(0) - else: - dims = (1,) * (n_input_dims - len(dims)) + dims - - if network.has_implicit_batch_dimension: - assert dims[0] == 1, "Can't tile the batch dim when it's implicit." - dims = dims[1:] - starts = [0] * len(dims) - shapes = [] - if all(isinstance(d, int) for d in dims): - shapes = [i * j for i, j in zip(input_val.shape, dims)] # type: ignore[union-attr] - else: - shape = [] - for i, (s, d) in enumerate(zip(input_val.shape, dims)): - if isinstance(d, TRTTensor) and len(d.shape) == 0: - d = prepend_ones(network, d, f"{name}_{i}", 1) - else: - d = get_trt_tensor(network, d, f"{name}_{i}") - shape.append(d) - mul = add_binary_elementwise_layer( - network, - s, - d, - trt.ElementWiseOperation.PROD, - target, - f"{name}_mul_{i}", - ) - shapes.append(mul) - dims = shape - # If there's dynmaic dim then there would be negative dims in shapes which is not allowed. - # Here we build a dummy shapes array. - if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] - shapes = [1] * len(dims) - strides = [1] * len(dims) - layer = network.add_slice(input_val, starts, shapes, strides) - layer.mode = trt.SliceMode.WRAP - set_layer_name(layer, target, name) - - if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] - starts_tensor = network.add_constant( - (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32) - ).get_output(0) - if all(isinstance(d, int) for d in dims): - dims_tensor = network.add_constant( - (len(dims),), np.ascontiguousarray(dims, np.int32) - ).get_output(0) - else: - assert all(isinstance(d, TRTTensor) for d in dims) - concat_dims_layer = network.add_concatenation(inputs=dims) - concat_dims_layer.axis = 0 - concat_dims_layer.name = f"{name}_tile_dim" - dims_tensor = concat_dims_layer.get_output(0) - input_shape_layer = network.add_shape(input_val) - input_shape_layer.name = f"{name}_slice_input_shape" - slice_shapes_tensor = add_binary_elementwise_layer( - network, - 
input_shape_layer.get_output(0), - dims_tensor, - trt.ElementWiseOperation.PROD, - target, - f"{name}_slice_shapes", - ) - layer.set_input(1, starts_tensor) - layer.set_input(2, slice_shapes_tensor) - - return layer.get_output(0) - - + return operator.add_tile(network, target, kwargs, name) + @tensorrt_converter(acc_ops.sign) def acc_ops_sign( network: TRTNetwork, @@ -1004,10 +757,7 @@ def acc_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.RELU - return add_activation_layer(network, input_val, operation_type, target, name) - + return activation.add_relu(network, target, kwargs, name) @tensorrt_converter(acc_ops.leaky_relu) def acc_ops_leaky_relu( @@ -1017,14 +767,8 @@ def acc_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - negative_slope = kwargs["negative_slope"] - operation_type = trt.ActivationType.LEAKY_RELU - return add_activation_layer( - network, input_val, operation_type, target, name, negative_slope - ) - - + return activation.add_leaky_relu(network, target, kwargs, name) + @tensorrt_converter(acc_ops.elu) def acc_ops_elu( network: TRTNetwork, @@ -1033,12 +777,8 @@ def acc_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - alpha = kwargs["alpha"] - operation_type = trt.ActivationType.ELU - return add_activation_layer(network, input_val, operation_type, target, name, alpha) - - + return activation.add_elu(network, target, kwargs, name) + @tensorrt_converter(acc_ops.selu) def acc_ops_selu( network: TRTNetwork, @@ -1047,10 +787,7 @@ def acc_ops_selu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.SELU - return add_activation_layer(network, input_val, operation_type, target, name) - + 
return activation.add_selu(network, target, kwargs, name) @tensorrt_converter(acc_ops.softsign) def acc_ops_softsign( @@ -1060,11 +797,8 @@ def acc_ops_softsign( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.SOFTSIGN - return add_activation_layer(network, input_val, operation_type, target, name) - - + return activation.add_softsign(network, target, kwargs, name) + @tensorrt_converter(acc_ops.sin) def acc_ops_sin( network: TRTNetwork, @@ -1138,10 +872,7 @@ def acc_ops_tanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.TANH - return add_activation_layer(network, input_val, operation_type, target, name) - + return activation.add_tanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.asin) def acc_ops_asin( @@ -1458,16 +1189,8 @@ def acc_ops_maximum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MAX, - target, - name, - ) - - + return operator.add_maximum(network, target, kwargs, name) + @tensorrt_converter(acc_ops.minimum) def acc_ops_minimum( network: TRTNetwork, @@ -1476,16 +1199,8 @@ def acc_ops_minimum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MIN, - target, - name, - ) - - + return operator.add_minimum(network, target, kwargs, name) + @tensorrt_converter(acc_ops.dtype) def acc_ops_dtype( network: TRTNetwork, @@ -1553,44 +1268,7 @@ def acc_ops_logical_and( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_and` function should be 
called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - # we only support both inputs are bool type - if target == acc_ops.bitwise_and: - - def check_is_bool(input_t): - if isinstance(input_t, TRTTensor): - assert ( - input_t.dtype == trt.bool - ), "We currently do not support input is non-bool" - elif isinstance(input_t, torch.Tensor): - assert ( - input_t.dtype == torch.bool - ), "We currently do not support input is non-bool" - else: - assert isinstance( - input_t.bool - ), "We currently do not support input is non-bool" - - check_is_bool(input_t) - check_is_bool(other_t) - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - if input_t.dtype != trt.bool: - input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool) - if other_t.dtype != trt.bool: - other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.AND, target, name - ) - + return operator.add_logical_and(network, target, kwargs, name) @tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True) def acc_ops_ne( @@ -1600,25 +1278,8 @@ def acc_ops_ne( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `ne` function should be called with explicit batch dimension." 
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - eq_t = add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name - ) - - return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) - - + return operator.add_ne(network, target, kwargs, name) + @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) def acc_ops_eq( network: TRTNetwork, @@ -1627,23 +1288,8 @@ def acc_ops_eq( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `eq` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name - ) - - + return operator.add_eq(network, target, kwargs, name) + @tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True) def acc_ops_gt( network: TRTNetwork, @@ -1652,23 +1298,8 @@ def acc_ops_gt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `gt` function should be called with explicit batch dimension." 
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name - ) - - + return operator.add_gt(network, target, kwargs, name) + @tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True) def acc_ops_lt( network: TRTNetwork, @@ -1677,22 +1308,8 @@ def acc_ops_lt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `le` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name - ) - + return operator.add_lt(network, target, kwargs, name) + @tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True) def acc_ops_logical_or( @@ -1702,35 +1319,8 @@ def acc_ops_logical_or( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_or` function should be called with explicit batch dimension." 
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - if isinstance(other_t, (torch.Tensor, bool)): - if isinstance(other_t, bool): - other_t = int(other_t) - elif other_t.dtype == torch.bool: - other_t = other_t.to(torch.int32) - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - if input_t.dtype != trt.bool: - layer_i = network.add_identity(input_t) - layer_i.set_output_type(0, trt.bool) - set_layer_name(layer_i, target, f"{name}_input_dtype_change") - input_t = layer_i.get_output(0) - if other_t.dtype != trt.bool: - layer_o = network.add_identity(other_t) - layer_o.set_output_type(0, trt.bool) - set_layer_name(layer_o, target, f"{name}_other_dtype_change") - other_t = layer_o.get_output(0) - - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.OR, target, name - ) - - + return operator.add_logical_or(network, target, kwargs, name) + @tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True) def acc_ops_logical_xor( network: TRTNetwork, @@ -1739,35 +1329,8 @@ def acc_ops_logical_xor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_xor` function should be called with explicit batch dimension." 
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - if isinstance(other_t, (torch.Tensor, bool)): - if isinstance(other_t, bool): - other_t = int(other_t) - elif other_t.dtype == torch.bool: - other_t = other_t.to(torch.int32) - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - if input_t.dtype != trt.bool: - layer_i = network.add_identity(input_t) - layer_i.set_output_type(0, trt.bool) - set_layer_name(layer_i, target, f"{name}_input_dtype_change") - input_t = layer_i.get_output(0) - if other_t.dtype != trt.bool: - layer_o = network.add_identity(other_t) - layer_o.set_output_type(0, trt.bool) - set_layer_name(layer_o, target, f"{name}_other_dtype_change") - other_t = layer_o.get_output(0) - - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name - ) - - + return operator.add_logical_xor(network, target, kwargs, name) + # T113156424 Have some accuracy problems in hf_T5. # [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights. 
# @tensorrt_converter(acc_ops.isinf) @@ -1859,28 +1422,7 @@ def acc_ops_fmod( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it - trunc_div_value = trunc_div( - kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" - ) - prod_value = add_binary_elementwise_layer( - network, - trunc_div_value, - kwargs["other"], - trt.ElementWiseOperation.PROD, - target, - name + "_prod", - ) - sub_value = add_binary_elementwise_layer( - network, - kwargs["input"], - prod_value, - trt.ElementWiseOperation.SUB, - target, - name + "_sub", - ) - return sub_value - + return operator.add_fmod(network, target, kwargs, name) # T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64. # if we cast to int32, it will create accuracy issues. We'd better leave it to future implementation. @@ -2108,15 +1650,7 @@ def acc_ops_add( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUM, - target, - name, - ) - + return operator.add_add(network, target, kwargs, name) @tensorrt_converter(acc_ops.sub) def acc_ops_sub( @@ -2126,15 +1660,7 @@ def acc_ops_sub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUB, - target, - name, - ) - + return operator.add_sub(network, target, kwargs, name) @tensorrt_converter(acc_ops.div) def acc_ops_div( @@ -2144,15 +1670,7 @@ def acc_ops_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.DIV, - target, - name, - ) - + return 
operator.add_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.floor_div) def acc_ops_floor_div( @@ -2162,16 +1680,8 @@ def acc_ops_floor_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.FLOOR_DIV, - target, - name, - ) - - + return operator.add_floor_div(network, target, kwargs, name) + @tensorrt_converter(acc_ops.trunc_div) def acc_ops_trunc_div( network: TRTNetwork, @@ -2180,9 +1690,8 @@ def acc_ops_trunc_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return trunc_div(kwargs["input"], kwargs["other"], network, target, name) - - + return operator.add_trunc_div(network, target, kwargs, name) + @tensorrt_converter(acc_ops.mul) def acc_ops_mul( network: TRTNetwork, @@ -2191,16 +1700,8 @@ def acc_ops_mul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.PROD, - target, - name, - ) - - + return operator.add_mul(network, target, kwargs, name) + @tensorrt_converter(acc_ops.pow) def acc_ops_pow( network: TRTNetwork, @@ -2209,15 +1710,7 @@ def acc_ops_pow( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["exponent"], - trt.ElementWiseOperation.POW, - target, - name, - ) - + return operator.add_pow(network, target, kwargs, name) @tensorrt_converter(acc_ops.unsqueeze) def acc_ops_unsqueeze( @@ -2786,60 +2279,8 @@ def acc_ops_linear( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Linear received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - dynamic_dims = get_dynamic_dims(input_val.shape) - assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, ( - "Currently we only support one dynmaic " - "dim for linear and it can't be the last dim." - ) - - if isinstance(kwargs["weight"], torch.Tensor): - weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") - if target not in (acc_ops.linear, torch.ops.aten.linear): - weight_op = trt.MatrixOperation.TRANSPOSE - else: - weight_op = trt.MatrixOperation.NONE - else: - assert isinstance( - kwargs["weight"], TRTTensor - ), f"Expect weight to be trt tensor but got {type(kwargs['weight'])}" - weight = kwargs["weight"] - weight_op = trt.MatrixOperation.TRANSPOSE - - preset_diff = 0 - if len(input_val.shape) == 1: - preset_diff -= 1 - input_op = trt.MatrixOperation.VECTOR - else: - input_op = trt.MatrixOperation.NONE - - input_val, weight = broadcast( - network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff - ) - matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op) - set_layer_name(matmul_layer, target, f"{name}_matmul") - res = matmul_layer.get_output(0) - - if kwargs["bias"] is not None: - bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias") # type: ignore[arg-type] - res = add_binary_elementwise_layer( - network, - matmul_layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, - target, - f"{name}_add", - ) - return res - - + return operator.add_linear(network, target, kwargs, name) + def add_clamp(network, input, val, op, name): if not len(input.shape): # clamping scalar @@ -3091,36 +2532,8 @@ def acc_ops_matmul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") - other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") - - for i in [input_val, other_val]: - if not isinstance(i, TRTTensor): - raise RuntimeError( - f"matmul received input {i} that is not 
part of the TensorRT region!" - ) - - input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE - preset_diff = 0 - - if len(input_val.shape) == 1: - preset_diff -= 1 - input_matrix_op = trt.MatrixOperation.VECTOR - - if len(other_val.shape) == 1: - preset_diff += 1 - other_matrix_op = trt.MatrixOperation.VECTOR - - input_val, other_val = broadcast( - network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff - ) - layer = network.add_matrix_multiply( - input_val, input_matrix_op, other_val, other_matrix_op - ) - set_layer_name(layer, target, name) - return layer.get_output(0) - - + return operator.add_matmul(network, target, kwargs, name) + @tensorrt_converter(acc_ops.hardsigmoid) def acc_ops_hard_sigmoid( network: TRTNetwork, @@ -3129,23 +2542,7 @@ def acc_ops_hard_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Hard sigmoid received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return add_activation_layer( - network, - input_val, - trt.ActivationType.HARD_SIGMOID, - target, - name, - alpha=1 / 6, - beta=0.5, - ) + return activation.add_hard_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.sigmoid) @@ -3156,17 +2553,7 @@ def acc_ops_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Sigmoid received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - return add_activation_layer( - network, input_val, trt.ActivationType.SIGMOID, target, name - ) + return activation.add_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.permute) @@ -3470,77 +2857,8 @@ def acc_ops_cumsum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - dim = cast(int, kwargs["dim"]) - input_shape = input_val.shape # type: ignore[union-attr] - input_dim_size = len(input_val.shape) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"cumsum received input {input_val} that is not part " - "of the TensorRT region!" - ) - if network.has_implicit_batch_dimension: - raise RuntimeError( - "cumsum converter currently doesn't support implicit batch dimension" - ) - dim = get_positive_dim(dim, input_dim_size) - loop = network.add_loop() - trip_limit = None - if input_shape[dim] > 0: - axis = torch.tensor(input_shape[dim], dtype=torch.int32) - trip_limit_layer = network.add_constant(axis.shape, to_numpy(axis)) - else: - input_shape = network.add_shape(input_val).get_output(0) - dim_value = torch.tensor(dim, dtype=torch.int32) - axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) - trip_limit_layer = network.add_gather(input_shape, axis, 0) - set_layer_name(trip_limit_layer, target, f"{name}_trip_limit") - trip_limit = trip_limit_layer.get_output(0) - - loop.add_trip_limit(trip_limit, trt.TripLimit(0)) - iterator = loop.add_iterator(input_val, dim, False) - data = iterator.get_output(0) - new_dims = tuple(data.shape) - zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype)) - zero_tensor = network.add_constant( - zero_tensor.shape, to_numpy(zero_tensor) - ).get_output(0) - - running_sum = loop.add_recurrence(zero_tensor) - set_layer_name(running_sum, target, f"{name}_running_sum_1") - running_sum_tensor = running_sum.get_output(0) - - current_sum = 
add_binary_elementwise_layer( - network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, - target, - f"{name}_sum_1", - ) - running_sum.set_input(1, current_sum) - - running_sum = loop.add_recurrence(zero_tensor) - set_layer_name(running_sum, target, f"{name}_running_sum_2") - running_sum_tensor = running_sum.get_output(0) - - current_sum = add_binary_elementwise_layer( - network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, - target, - f"{name}_sum_2", - ) - running_sum.set_input(1, current_sum) - - loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim) - set_layer_name(loop_output, target, f"{name}_loop_output") - loop_output.set_input(1, trip_limit) - return loop_output.get_output(0) - - + return operator.add_cumsum(network, target, kwargs, name) + @tensorrt_converter(acc_ops.hardtanh) def acc_ops_hardtanh( network: TRTNetwork, @@ -3549,24 +2867,7 @@ def acc_ops_hardtanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"hardtanh received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - return add_activation_layer( - network, - input_val, - trt.ActivationType.CLIP, - target, - name, - alpha=kwargs["min_val"], - beta=kwargs["max_val"], - ) - + return activation.add_hardtanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.interpolate) def acc_ops_interpolate( diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py index a7ab25152c..9d38dcbbc2 100644 --- a/py/torch_tensorrt/fx/converters/activation.py +++ b/py/torch_tensorrt/fx/converters/activation.py @@ -1,76 +1,164 @@ import numpy as np +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch +from torch.fx.node import Argument, Target -from ..converter_registry import tensorrt_converter +from ..utils import torch_dtype_from_trt from .converter_utils import mark_as_int8_layer +from .converter_utils import set_layer_name + +from ..types import ( + Shape, + TRTDataType, + TRTElementWiseOp, + TRTLayer, + TRTNetwork, + TRTPlugin, + TRTPluginFieldCollection, + TRTTensor, +) + +def add_activation_layer( + network: TRTNetwork, + input_val: TRTTensor, + operation_type: trt.ActivationType, + target: Target, + name: str, + alpha: Optional[Any] = None, + beta: Optional[Any] = None, + dyn_range_fn: Optional[Callable[Tuple[float, float]]] = None +) -> TRTTensor: + """ + Add a TensorRT Activation layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the activation op. + Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT activation + operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + alpha (Optional[Any]): If not None, we will use it to set the alpha + attribute of the created TensorRT activation layer. 
+ beta (Optional[Any]): If not None, we will use it to set the beta + attribute of the created TensorRT activation layer. + dyn_range_fn: Optional[Callable[Tuple[float, float]]]: A function which takes the dynamic range of a TensorRT Tensor and returns the output dynamic range + + + Returns: + The output of TensorRT Activation layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_activation(input_val, operation_type) + if alpha is not None: + layer.alpha = alpha + if beta is not None: + layer.beta = beta + set_layer_name(layer, target, name) + + if input_val.dynamic_range is not None: + dyn_range = dyn_range_fn(input_val.dynamic_range) + mark_as_int8_layer(layer, dyn_range) + return layer.get_output(0) +def add_relu(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.ActivationType.RELU + return add_activation_layer(network, input_val, operation_type, target, name) -def common_activation( - network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name -): - layer = network.add_activation(input=input_val, type=activation_type) - layer.name = layer_name +def add_leaky_relu(network, target, kwargs, name): + input_val = kwargs["input"] + negative_slope = kwargs["negative_slope"] + operation_type = trt.ActivationType.LEAKY_RELU + return add_activation_layer( + network, input_val, operation_type, target, name, negative_slope + ) - if input_val.dynamic_range: - dyn_range = activation_dyn_range_fn(input_val.dynamic_range) - mark_as_int8_layer(layer, dyn_range) +def add_elu(network, target, kwargs, name): + input_val = kwargs["input"] + alpha = kwargs["alpha"] + operation_type = trt.ActivationType.ELU + return add_activation_layer(network, input_val, operation_type, target, name, alpha) - return layer.get_output(0) +def add_selu(network, target, kwargs, name): + input_val = 
kwargs["input"] + operation_type = trt.ActivationType.SELU + return add_activation_layer(network, input_val, operation_type, target, name) + +def add_softsign(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.ActivationType.SOFTSIGN + return add_activation_layer(network, input_val, operation_type, target, name) +def add_tanh(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.ActivationType.TANH + return add_activation_layer(network, input_val, operation_type, target, name) -@tensorrt_converter(torch.nn.functional.relu) -@tensorrt_converter(torch.nn.modules.activation.ReLU) -def relu(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 +def add_hard_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] - if not isinstance(input_val, trt.tensorrt.ITensor): + if not isinstance(input_val, TRTTensor): raise RuntimeError( - f"ReLU received input {input_val} that is not part " + f"Hard sigmoid received input {input_val} that is not part " "of the TensorRT region!" ) - def activation_dyn_range_fn(dyn_range): - return max(0, dyn_range[0]), max(0, dyn_range[1]) - - return common_activation( + return add_activation_layer( network, - submod, input_val, - trt.ActivationType.RELU, - activation_dyn_range_fn, - layer_name, + trt.ActivationType.HARD_SIGMOID, + target, + name, + alpha=1 / 6, + beta=0.5, ) - -@tensorrt_converter(torch.nn.modules.activation.Sigmoid) -def sigmoid(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 +def add_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] - if not isinstance(input_val, trt.tensorrt.ITensor): + if not isinstance(input_val, TRTTensor): raise RuntimeError( f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!" 
         )
-
-    def activation_dyn_range_fn(dyn_range):
-        def sigmoid_fn(x):
-            return 1 / (1 + np.exp(-x))
+    return add_activation_layer(
+        network, input_val, trt.ActivationType.SIGMOID, target, name
+    )
 
-        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
+def add_hard_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.CLIP
+    return add_activation_layer(network, input_val, operation_type, target, name, alpha=kwargs.get("min_val", -1.0), beta=kwargs.get("max_val", 1.0))
 
-    return common_activation(
+def add_hard_sigmoid(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"Hard sigmoid received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    return add_activation_layer(
         network,
-        submod,
         input_val,
-        trt.ActivationType.SIGMOID,
-        activation_dyn_range_fn,
-        layer_name,
+        trt.ActivationType.HARD_SIGMOID,
+        target,
+        name,
+        alpha=1 / 6,
+        beta=0.5,
     )
+
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 943eb203b3..b83a338cfd 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -22,6 +22,8 @@
 from .converter_utils import *  # noqa: F403
 
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+import activation
+import operator
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -38,7 +40,7 @@ def aten_ops_add(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
+    return operator.add_add(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.mean.dim)
@@ -141,13 +143,13 @@ def aten_ops_div(
     }
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
-        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
+        return operator.add_div(network, target, None, kwargs_new, name)
     elif rounding_mode == 
"floor": - return acc_ops_converters.acc_ops_floor_div( + return operator.add_floor_div( network, target, None, kwargs_new, name ) elif rounding_mode == "trunc": - return acc_ops_converters.acc_ops_trunc_div( + return operator.add_trunc_div( network, target, None, kwargs_new, name ) else: @@ -168,7 +170,7 @@ def aten_ops_floor_div( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name) + return operator.add_floor_div(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.fmod.Scalar) @@ -184,7 +186,7 @@ def aten_ops_fmod( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name) + return operator.add_fmod(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.linear) @@ -201,7 +203,7 @@ def aten_ops_linear( "bias": args[2], } - return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name) + return operator.add_linear(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.max_pool3d) @@ -250,7 +252,36 @@ def aten_ops_mul( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name) + return operator.add_mul(network, target, None, kwargs_new, name) + +@tensorrt_converter(torch.ops.aten.mul.Tensor) +def aten_ops_mul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return operator.add_mul(network, target, None, kwargs_new, name) + +@tensorrt_converter(torch.ops.aten.matmul.Tensor) +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return 
operator.add_matmul(network, target, None, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @@ -266,7 +297,7 @@ def aten_ops_pow( "input": args[0], "exponent": args[1], } - return acc_ops_converters.acc_ops_pow(network, target, None, kwargs_new, name) + return operator.add_pow(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.relu.default) @@ -280,9 +311,8 @@ def aten_ops_relu( kwargs_new = { "input": args[0], } - return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name) - - + return activation.add_relu(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( network: TRTNetwork, @@ -295,7 +325,7 @@ def aten_ops_sub( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name) + return operator.add_sub(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.view.default) @@ -379,7 +409,7 @@ def aten_ops_operator_floordiv( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name) + return operator.add_floor_div(network, target, None, kwargs_new, name) @tensorrt_converter(operator.mul) @@ -394,7 +424,7 @@ def aten_ops_operator_mul( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name) + return operator.add_mul(network, target, None, kwargs_new, name) @tensorrt_converter(operator.add) @@ -409,7 +439,7 @@ def aten_ops_operator_add( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name) + return operator.add_add(network, target, None, kwargs_new, name) @tensorrt_converter(operator.sub) @@ -424,7 +454,7 @@ def aten_ops_operator_sub( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name) + return operator.add_sub(network, 
target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.sym_numel)
@@ -466,3 +496,139 @@ def aten_ops_sym_size(
     )
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
+
+@tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0], "negative_slope": args[1] if len(args) > 1 else 0.01,
+    }
+    return activation.add_leaky_relu(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0], "alpha": args[1] if len(args) > 1 else 1.0,
+    }
+    return activation.add_elu(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.selu.default)
+def aten_ops_selu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_selu(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.selu.default)
+def aten_ops_selu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_selu(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.softsign.default)
+def aten_ops_softsign(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_softsign(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.tanh.default)
+def aten_ops_tanh(
+    network: 
TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_tanh(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.softsign.default)
+def aten_ops_softsign(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_softsign(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+def aten_ops_hard_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_hard_sigmoid(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.hardtanh.default)
+def aten_ops_hard_tanh(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_hard_tanh(network, target, kwargs_new, name)
+
+@tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return activation.add_sigmoid(network, target, kwargs_new, name)
+
+
+
+
+
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 17a0cef456..48e8d0a301 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -409,112 +409,7 @@ def get_shape_with_dynamic_shape(
     return 
select_layer.get_output(0) -def add_binary_elementwise_layer( - network: TRTNetwork, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], - op_type: trt.ElementWiseOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - This function adds a TensorRT elementwise layer. We allow both operands to be - constant (not a trt tensor) because in implicit batch dimension mode, we could - introduce constant via .size() op. Other scenario should be const folded first. - If any operand is not a trt tensor, we make it a trt constant layer while preserve - its dtype. Then we broadcast these two inputs to have the same number of dimensions. - Limitation: - If we are using implicit batch dim mode, the operand that is not a trt - tensor are not allowed to have larger ranks than the trt tensor operand. - - Args: - network (TRTNetwork): TensorRT network object. - lhs_val (TRTTensor): Left operand of the binary operation. Could - be a TensorRT tensor, a PyTorch tensor or a simple value. - rhs_val (TRTTensor): Right operand of the binary operation. Similar - to lhs_val. - op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Elementwise layer. - """ - lhs_dtype = None - rhs_dtype = None - is_lhs_trt_tensor = False - is_rhs_trt_tensor = False - - if isinstance(lhs_val, TRTTensor): - lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) - is_lhs_trt_tensor = True - if isinstance(rhs_val, TRTTensor): - rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) - is_rhs_trt_tensor = True - - if not is_lhs_trt_tensor and not is_rhs_trt_tensor: - warnings.warn( - f"Both operands of the binary elementwise op {name} " - "are constant. In this case, please consider constant fold the model first." 
- ) - return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) - - # If the following conditions are true: - # 1. the network has implicit batch dimension, - # 2. one operand has shape [] (real shape is [batch_size]), - # 3. another operand is a scalar, - # then the result should also have shape [] (real shape is [batch_size]). - # - # In such case, we need to convert the scalar operand to tensor, because - # this way the shape will become [1], and then will be properly squeezed - # into [], meaning that the result will have shape [], which is what we - # expect. - # - # Note that the dtype here is supposed to be the same as the scalar - # dtype but we don't have a way to detect whether it makes sense for the - # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) - - # When lhs is scalar, and rhs has shape [1,], then currently the assert - # will fail because lhs shape has fewer dimensions than rhs shape. This - # happens when using implicit batch dimension, when we removed the 1st - # dimension from input tensor, causing it to have shape [] - a scalar. We - # fix it by reducing the rhs constant with a squeeze_left, so it becomes a - # scalar too. More generally, we squeeze_left on input if it's a constant - # tensor. This is safe because broadcast will pad dimensions on the left - # (prepend) to make lhs and rhs shape compatible. - if network.has_implicit_batch_dimension: - if isinstance(lhs_val, torch.Tensor): - lhs_val = squeeze_left(lhs_val) - if isinstance(rhs_val, torch.Tensor): - rhs_val = squeeze_left(rhs_val) - - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) - - # Check the limitation in the doc string. 
- if network.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - - lhs_val, rhs_val = broadcast( - network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" - ) - layer = network.add_elementwise(lhs_val, rhs_val, op_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return output def squeeze_left(const: torch.Tensor): @@ -527,81 +422,6 @@ def squeeze_left(const: torch.Tensor): const = const.squeeze(dim=0) return const - -def add_unary_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.UnaryOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - Add a TensorRT Unary layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Unary layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - layer = network.add_unary(input_val, operation_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return layer.get_output(0) - - -def add_activation_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.ActivationType, - target: Target, - name: str, - alpha: Optional[Any] = None, - beta: Optional[Any] = None, -) -> TRTTensor: - """ - Add a TensorRT Activation layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the activation op. - Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT activation - operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - alpha (Optional[Any]): If not None, we will use it to set the alpha - attribute of the created TensorRT activation layer. - beta (Optional[Any]): If not None, we will use it to set the beta - attribute of the created TensorRT activation layer. - - Returns: - The output of TensorRT Activation layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - layer = network.add_activation(input_val, operation_type) - if alpha is not None: - layer.alpha = alpha - if beta is not None: - layer.beta = beta - set_layer_name(layer, target, name) - return layer.get_output(0) - - def add_reduce_layer( network: TRTNetwork, target: Target, @@ -821,24 +641,6 @@ def trunc_div( return output - -def get_python_op_from_trt_elementwise_op( - trt_op: TRTElementWiseOp, -) -> Callable[[Any, Any], Any]: - if trt_op == trt.ElementWiseOperation.SUM: - return operator.add - elif trt_op == trt.ElementWiseOperation.PROD: - return operator.mul - elif trt_op == trt.ElementWiseOperation.SUB: - return operator.sub - elif trt_op == trt.ElementWiseOperation.DIV: - return operator.truediv - elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: - return operator.floordiv - else: - raise RuntimeError(f"{trt_op} is not supported yet!") - - def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py new file mode 100644 index 0000000000..3276285c86 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py @@ -0,0 +1,31 @@ +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch + +from ..converter_registry import tensorrt_converter + +from .converter_utils import mark_as_int8_layer +import activation + +@tensorrt_converter(torch.nn.functional.relu) +@tensorrt_converter(torch.nn.modules.activation.ReLU) +def relu(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + return activation.add_relu(network,"tensorrt", kwargs, layer_name) + +@tensorrt_converter(torch.nn.functional.leaky_relu) +@tensorrt_converter(torch.nn.modules.activation.leaky_relu) +def leaky_relu(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been 
normalized to kwargs + assert len(args) == 0 + return activation.add_leaky_relu(network,"tensorrt", kwargs, layer_name) + +@tensorrt_converter(torch.nn.modules.activation.Sigmoid) +def sigmoid(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + return activation.add_sigmoid(network,"tensorrt", kwargs, layer_name) + diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py new file mode 100644 index 0000000000..2dab0c6475 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -0,0 +1,902 @@ +import numpy as np +import operator +import warnings +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union + +import tensorrt as trt +import torch +from torch.fx.node import Argument, Target +from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt + +from ..tracer.acc_tracer import acc_ops + +from .converter_utils import get_trt_tensor +from .converter_utils import set_layer_name +from .converter_utils import get_trt_tensor +from .converter_utils import broadcast +from .converter_utils import squeeze_left +from .converter_utils import dtype_uniform +from .converter_utils import get_trt_plugin +from .converter_utils import get_positive_dim + +from ..types import ( + Shape, + TRTDataType, + TRTElementWiseOp, + TRTLayer, + TRTNetwork, + TRTPlugin, + TRTPluginFieldCollection, + TRTTensor, +) + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return 
operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + +def add_binary_elementwise_layer( + network: TRTNetwork, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], + op_type: trt.ElementWiseOperation, + target: Target, + name: str, +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. + rhs_val (TRTTensor): Right operand of the binary operation. Similar + to lhs_val. + op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Elementwise layer. + """ + lhs_dtype = None + rhs_dtype = None + is_lhs_trt_tensor = False + is_rhs_trt_tensor = False + + if isinstance(lhs_val, TRTTensor): + lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) + is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) + is_rhs_trt_tensor = True + + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: + warnings.warn( + f"Both operands of the binary elementwise op {name} " + "are constant. In this case, please consider constant fold the model first." 
+ ) + return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) + + # If the following conditions are true: + # 1. the network has implicit batch dimension, + # 2. one operand has shape [] (real shape is [batch_size]), + # 3. another operand is a scalar, + # then the result should also have shape [] (real shape is [batch_size]). + # + # In such case, we need to convert the scalar operand to tensor, because + # this way the shape will become [1], and then will be properly squeezed + # into [], meaning that the result will have shape [], which is what we + # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + + # When lhs is scalar, and rhs has shape [1,], then currently the assert + # will fail because lhs shape has fewer dimensions than rhs shape. This + # happens when using implicit batch dimension, when we removed the 1st + # dimension from input tensor, causing it to have shape [] - a scalar. We + # fix it by reducing the rhs constant with a squeeze_left, so it becomes a + # scalar too. More generally, we squeeze_left on input if it's a constant + # tensor. This is safe because broadcast will pad dimensions on the left + # (prepend) to make lhs and rhs shape compatible. + if network.has_implicit_batch_dimension: + if isinstance(lhs_val, torch.Tensor): + lhs_val = squeeze_left(lhs_val) + if isinstance(rhs_val, torch.Tensor): + rhs_val = squeeze_left(rhs_val) + + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + + # Check the limitation in the doc string. 
+ if network.has_implicit_batch_dimension: + if is_lhs_trt_tensor and not is_rhs_trt_tensor: + assert len(lhs_val.shape) >= len( + rhs_val.shape + ), f"{lhs_val.shape} >= {rhs_val.shape}" + elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + assert len(rhs_val.shape) >= len( + lhs_val.shape + ), f"{rhs_val.shape} >= {lhs_val.shape}" + + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + set_layer_name(layer, target, name) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return output + +def trunc_div( + input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str +) -> TRTTensor: + """ + Perform trunc divide on Tensor, result of divide will be round toward zero. + This means for positive number, it will be floor round; for negative number, + it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. + + Args: + input: divisor. + other: dividend. + network: INetworkDefinition. + target: node target. + name: namespace for the op + + Returns: + A TensorRT tensor represent the result of trunc divide. 
+ """ + prod_output = add_binary_elementwise_layer( + network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" + ) + sign_output = sign(network, prod_output, target, name) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) + ) + + abs_input_output = add_unary_layer( + network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" + ) + abs_other_output = add_unary_layer( + network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" + ) + abs_floor_output = add_binary_elementwise_layer( + network, + abs_input_output, + abs_other_output, + trt.ElementWiseOperation.FLOOR_DIV, + target, + f"{name}_floor_div", + ) + output = add_binary_elementwise_layer( + network, + abs_floor_output, + sign_output, + trt.ElementWiseOperation.PROD, + target, + f"{name}_output", + ) + + return output + +def add_tile(network, target, kwargs, name): + input_t = kwargs["input"] + input_val = get_trt_tensor(network, input_t, f"{name}_input") + + dims = tuple(cast(Sequence[int], kwargs["dims"])) + n_input_dims = len(input_val.shape) + ( + 1 if network.has_implicit_batch_dimension else 0 + ) + + if len(dims) > n_input_dims: + assert not network.has_implicit_batch_dimension + layer = network.add_shuffle(input_val) + layer.name = f"{name}_reshape" + num_preceding_ones = len(dims) - n_input_dims + + if len(get_dynamic_dims(input_val.shape)) > 1: + input_shape_layer = network.add_shape(input_val) + input_shape_layer.name = f"{name}_input_shape" + preceding_ones = network.add_constant( + (num_preceding_ones,), + np.ascontiguousarray([1] * num_preceding_ones, np.int32), + ).get_output(0) + reshape_layer = network.add_concatenation( + [preceding_ones, input_shape_layer.get_output(0)] + ) + 
reshape_layer.axis = 0 + reshape_layer.name = f"{name}_reshape_dims" + layer.set_input(1, reshape_layer.get_output(0)) + else: + layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple( + input_val.shape + ) + input_val = layer.get_output(0) + else: + dims = (1,) * (n_input_dims - len(dims)) + dims + + if network.has_implicit_batch_dimension: + assert dims[0] == 1, "Can't tile the batch dim when it's implicit." + dims = dims[1:] + starts = [0] * len(dims) + shapes = [] + if all(isinstance(d, int) for d in dims): + shapes = [i * j for i, j in zip(input_val.shape, dims)] # type: ignore[union-attr] + else: + shape = [] + for i, (s, d) in enumerate(zip(input_val.shape, dims)): + if isinstance(d, TRTTensor) and len(d.shape) == 0: + d = prepend_ones(network, d, f"{name}_{i}", 1) + else: + d = get_trt_tensor(network, d, f"{name}_{i}") + shape.append(d) + mul = add_binary_elementwise_layer( + network, + s, + d, + trt.ElementWiseOperation.PROD, + target, + f"{name}_mul_{i}", + ) + shapes.append(mul) + dims = shape + # If there's dynmaic dim then there would be negative dims in shapes which is not allowed. + # Here we build a dummy shapes array. 
+ if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] + shapes = [1] * len(dims) + strides = [1] * len(dims) + layer = network.add_slice(input_val, starts, shapes, strides) + layer.mode = trt.SliceMode.WRAP + set_layer_name(layer, target, name) + + if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] + starts_tensor = network.add_constant( + (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32) + ).get_output(0) + if all(isinstance(d, int) for d in dims): + dims_tensor = network.add_constant( + (len(dims),), np.ascontiguousarray(dims, np.int32) + ).get_output(0) + else: + assert all(isinstance(d, TRTTensor) for d in dims) + concat_dims_layer = network.add_concatenation(inputs=dims) + concat_dims_layer.axis = 0 + concat_dims_layer.name = f"{name}_tile_dim" + dims_tensor = concat_dims_layer.get_output(0) + input_shape_layer = network.add_shape(input_val) + input_shape_layer.name = f"{name}_slice_input_shape" + slice_shapes_tensor = add_binary_elementwise_layer( + network, + input_shape_layer.get_output(0), + dims_tensor, + trt.ElementWiseOperation.PROD, + target, + f"{name}_slice_shapes", + ) + layer.set_input(1, starts_tensor) + layer.set_input(2, slice_shapes_tensor) + + return layer.get_output(0) + +def add_linear(network, target, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"Linear received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dynamic_dims = get_dynamic_dims(input_val.shape) + assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, ( + "Currently we only support one dynmaic " + "dim for linear and it can't be the last dim." 
+ ) + + if isinstance(kwargs["weight"], torch.Tensor): + weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") + if target not in (acc_ops.linear, torch.ops.aten.linear): + weight_op = trt.MatrixOperation.TRANSPOSE + else: + weight_op = trt.MatrixOperation.NONE + else: + assert isinstance( + kwargs["weight"], TRTTensor + ), f"Expect weight to be trt tensor but got {type(kwargs['weight'])}" + weight = kwargs["weight"] + weight_op = trt.MatrixOperation.TRANSPOSE + + preset_diff = 0 + if len(input_val.shape) == 1: + preset_diff -= 1 + input_op = trt.MatrixOperation.VECTOR + else: + input_op = trt.MatrixOperation.NONE + + input_val, weight = broadcast( + network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff + ) + matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op) + set_layer_name(matmul_layer, target, f"{name}_matmul") + res = matmul_layer.get_output(0) + + if kwargs["bias"] is not None: + bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias") # type: ignore[arg-type] + res = add_binary_elementwise_layer( + network, + matmul_layer.get_output(0), + bias, + trt.ElementWiseOperation.SUM, + target, + f"{name}_add", + ) + return res + +def add_unary_layer( + network: TRTNetwork, + input_val: TRTTensor, + operation_type: trt.UnaryOperation, + target: Target, + name: str, +) -> TRTTensor: + """ + Add a TensorRT Unary layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Unary layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) + +def layer_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"LayerNorm received input {input_val} that is not part " + "of the TensorRT region!" + ) + + shape = kwargs["weight"].shape # type: ignore[union-attr] + broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape + gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr] + beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr] + eps = kwargs["eps"] + + axes = 0 + for d in range(len(shape)): + axes |= 1 << (len(input_val.shape) - d - 1) + + # E[x] + mean_expected_layer = network.add_reduce( + input_val, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") + + # X-E[x] + sub_trt = operator.add_binary_elementwise_layer( + network, + input_val, + mean_expected_layer.get_output(0), + trt.ElementWiseOperation.SUB, + target, + f"{name}_sub", + ) + # Variance = mean(pow(x_sub_mean,2)) + pow_tensor = network.add_constant( + (1,) * len(input_val.shape), + trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + ) + pow_tensor.name = f"{name}_power" + pow_var = operator.add_binary_elementwise_layer( + network, + sub_trt, + pow_tensor.get_output(0), + trt.ElementWiseOperation.POW, + target, + f"{name}_pow_var", + ) + mean_trt_layer = network.add_reduce( + pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_trt_layer, target, f"{name}_mean") + # Variance + eps + eps_tensor = network.add_constant( + (1,) * len(input_val.shape), + 
trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + ) + eps_tensor.name = f"{name}_eps" + add_trt = add_binary_elementwise_layer( + network, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), + trt.ElementWiseOperation.SUM, + target, + f"{name}_add", + ) + # SQRT((Var + eps)) + sqrt_trt = add_unary_layer( + network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" + ) + # (x - E[x]) / sqrt((var + eps)) + div_trt = add_binary_elementwise_layer( + network, + sub_trt, + sqrt_trt, + trt.ElementWiseOperation.DIV, + target, + f"{name}_div_trt", + ) + + assert gamma is not None + gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] + gamma_tensor.name = f"{name}_gamma" + assert beta is not None + beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] + beta_tensor.name = f"{name}_beta" + # y * gamma + beta + scale_layer = add_binary_elementwise_layer( + network, + div_trt, + gamma_tensor.get_output(0), + trt.ElementWiseOperation.PROD, + target, + f"{name}_scale", + ) + return add_binary_elementwise_layer( + network, + scale_layer, + beta_tensor.get_output(0), + trt.ElementWiseOperation.SUM, + target, + name, + ) + +def add_add(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.SUM, + target, + name, + ) + +def add_matmul(network, target, kwargs, name): + input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") + other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") + + for i in [input_val, other_val]: + if not isinstance(i, TRTTensor): + raise RuntimeError( + f"matmul received input {i} that is not part of the TensorRT region!" 
+ ) + + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input_val.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other_val.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input_val, other_val = broadcast( + network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff + ) + layer = network.add_matrix_multiply( + input_val, input_matrix_op, other_val, other_matrix_op + ) + set_layer_name(layer, target, name) + return layer.get_output(0) + +def add_layer_norm(network, target, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError( + f"LayerNorm received input {input_val} that is not part " + "of the TensorRT region!" + ) + + gamma = kwargs["weight"].detach().cpu().float().numpy() + gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) + beta = kwargs["bias"].detach().cpu().float().numpy() + beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) + eps_field = trt.PluginField( + "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 + ) + try: + normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) + except TypeError: + _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") + normalized_shape = np.array([], dtype=np.int32) + + normalized_shape_filed = trt.PluginField( + "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 + ) + field_collection = trt.PluginFieldCollection( + [gamma_field, beta_field, eps_field, normalized_shape_filed] + ) + + try: + if network.has_implicit_batch_dimension: + plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") + else: + plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") + except AssertionError: + _LOGGER.error( + "Unable to find layer norm plugin, fall back to TensorRT 
implementation." + ) + return layer_norm(network, target, args, kwargs, name) + layer = network.add_plugin_v2([input_val], plugin) + layer.name = name + return layer.get_output(0) + +def add_cumsum(network, target, kwargs, name): + input_val = kwargs["input"] + dim = cast(int, kwargs["dim"]) + input_shape = input_val.shape # type: ignore[union-attr] + input_dim_size = len(input_val.shape) # type: ignore[union-attr] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"cumsum received input {input_val} that is not part " + "of the TensorRT region!" + ) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "cumsum converter currently doesn't support implicit batch dimension" + ) + dim = get_positive_dim(dim, input_dim_size) + loop = network.add_loop() + trip_limit = None + if input_shape[dim] > 0: + axis = torch.tensor(input_shape[dim], dtype=torch.int32) + trip_limit_layer = network.add_constant(axis.shape, to_numpy(axis)) + else: + input_shape = network.add_shape(input_val).get_output(0) + dim_value = torch.tensor(dim, dtype=torch.int32) + axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) + trip_limit_layer = network.add_gather(input_shape, axis, 0) + set_layer_name(trip_limit_layer, target, f"{name}_trip_limit") + trip_limit = trip_limit_layer.get_output(0) + + loop.add_trip_limit(trip_limit, trt.TripLimit(0)) + iterator = loop.add_iterator(input_val, dim, False) + data = iterator.get_output(0) + new_dims = tuple(data.shape) + zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype)) + zero_tensor = network.add_constant( + zero_tensor.shape, to_numpy(zero_tensor) + ).get_output(0) + + running_sum = loop.add_recurrence(zero_tensor) + set_layer_name(running_sum, target, f"{name}_running_sum_1") + running_sum_tensor = running_sum.get_output(0) + + current_sum = add_binary_elementwise_layer( + network, + data, + running_sum_tensor, + trt.ElementWiseOperation.SUM, + target, + 
f"{name}_sum_1",
+    )
+    running_sum.set_input(1, current_sum)
+
+    running_sum = loop.add_recurrence(zero_tensor)
+    set_layer_name(running_sum, target, f"{name}_running_sum_2")
+    running_sum_tensor = running_sum.get_output(0)
+
+    current_sum = add_binary_elementwise_layer(
+        network,
+        data,
+        running_sum_tensor,
+        trt.ElementWiseOperation.SUM,
+        target,
+        f"{name}_sum_2",
+    )
+    running_sum.set_input(1, current_sum)
+
+    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
+    set_layer_name(loop_output, target, f"{name}_loop_output")
+    loop_output.set_input(1, trip_limit)
+    return loop_output.get_output(0)
+
+def add_maximum(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.MAX,
+        target,
+        name,
+    )
+
+def add_mul(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.PROD,
+        target,
+        name,
+    )
+
+def add_pow(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.POW,
+        target,
+        name,
+    )
+
+def add_floor_div(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.FLOOR_DIV,
+        target,
+        name,
+    )
+
+def add_div(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.DIV,
+        target,
+        name,
+    )
+
+def add_sub(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.SUB,
+        target,
+        name,
+    )
+def add_minimum(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.MIN,
+        target,
+        name,
+    )
+
+def 
add_logical_and(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `ne` function should be called with explicit batch dimension." + ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + + input_t = get_trt_tensor(network, input_t, f"{name}_input_t") + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + + input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) + eq_t = add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + ) + + return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + +def add_ne(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `ne` function should be called with explicit batch dimension." + ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + + input_t = get_trt_tensor(network, input_t, f"{name}_input_t") + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + + input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) + eq_t = add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + ) + + return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + +def add_eq(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `eq` function should be called with explicit batch dimension." 
+ ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + + input_t = get_trt_tensor(network, input_t, f"{name}_input_t") + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + + input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + ) + +def add_gt(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `gt` function should be called with explicit batch dimension." + ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + + input_t = get_trt_tensor(network, input_t, f"{name}_input_t") + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + + input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name + ) + +def add_lt(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `le` function should be called with explicit batch dimension." + ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + + input_t = get_trt_tensor(network, input_t, f"{name}_input_t") + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + + input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name + ) + +def add_logical_or(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `logical_or` function should be called with explicit batch dimension." 
+ ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + if isinstance(other_t, (torch.Tensor, bool)): + if isinstance(other_t, bool): + other_t = int(other_t) + elif other_t.dtype == torch.bool: + other_t = other_t.to(torch.int32) + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + if input_t.dtype != trt.bool: + layer_i = network.add_identity(input_t) + layer_i.set_output_type(0, trt.bool) + set_layer_name(layer_i, target, f"{name}_input_dtype_change") + input_t = layer_i.get_output(0) + if other_t.dtype != trt.bool: + layer_o = network.add_identity(other_t) + layer_o.set_output_type(0, trt.bool) + set_layer_name(layer_o, target, f"{name}_other_dtype_change") + other_t = layer_o.get_output(0) + + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.OR, target, name + ) + +def add_logical_xor(network, target, kwargs, name): + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `logical_xor` function should be called with explicit batch dimension." 
+ ) + + input_t = kwargs["input"] + other_t = kwargs["other"] + if isinstance(other_t, (torch.Tensor, bool)): + if isinstance(other_t, bool): + other_t = int(other_t) + elif other_t.dtype == torch.bool: + other_t = other_t.to(torch.int32) + other_t = get_trt_tensor(network, other_t, f"{name}_other_t") + if input_t.dtype != trt.bool: + layer_i = network.add_identity(input_t) + layer_i.set_output_type(0, trt.bool) + set_layer_name(layer_i, target, f"{name}_input_dtype_change") + input_t = layer_i.get_output(0) + if other_t.dtype != trt.bool: + layer_o = network.add_identity(other_t) + layer_o.set_output_type(0, trt.bool) + set_layer_name(layer_o, target, f"{name}_other_dtype_change") + other_t = layer_o.get_output(0) + + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name + ) + +def add_fmod(network, target, kwargs, name): + # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it + trunc_div_value = trunc_div( + kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" + ) + prod_value = add_binary_elementwise_layer( + network, + trunc_div_value, + kwargs["other"], + trt.ElementWiseOperation.PROD, + target, + name + "_prod", + ) + sub_value = add_binary_elementwise_layer( + network, + kwargs["input"], + prod_value, + trt.ElementWiseOperation.SUB, + target, + name + "_sub", + ) + return sub_value + +def add_trunc_div(network, target, kwargs, name): + return trunc_div(kwargs["input"], kwargs["other"], network, target, name) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py new file mode 100644 index 0000000000..a220b67d67 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt 
import DispatchTestCase, InputTensorSpec + + +class TestReLUConverter(DispatchTestCase): + def test_leaky_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file From 3f3a925b3df6cef8c4b8464069850fcccf33b600 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 16 Mar 2023 23:29:33 -0700 Subject: [PATCH 02/39] aten converter- matmul, tanh, gelu, slice select --- .../fx/converters/acc_ops_converters.py | 104 +-------------- py/torch_tensorrt/fx/converters/activation.py | 30 +++++ .../fx/converters/aten_ops_converters.py | 76 +++++++---- .../fx/converters/converter_utils.py | 2 +- py/torch_tensorrt/fx/converters/operator.py | 120 ++++++++++++++++++ .../test/converters/aten_op/test_gelu_aten.py | 52 ++++++++ .../converters/aten_op/test_matmul_aten.py | 88 +++++++++++++ .../converters/aten_op/test_select_aten.py | 0 .../converters/aten_op/test_slice_aten.py | 0 
.../test/converters/aten_op/test_tanh_aten.py | 52 ++++++++ 10 files changed, 400 insertions(+), 124 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 689a52743e..85e8e22252 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -1980,53 +1980,8 @@ def acc_ops_slice_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input_val} that is not part " - "of the TensorRT region!" - ) - - ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" - ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
- - start_int = cast(int, kwargs["start"]) - stop_int = cast(int, kwargs["stop"]) - step_int = cast(int, kwargs["step"]) - start = [0] * len(input_val.shape) - start[dim] = start_int - stride = [1] * len(start) - stride[dim] = step_int - output_shape = list(input_val.shape) - output_shape[dim] = (stop_int - start_int) // step_int - - if dynamic_shape > 0: - output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name - ) - layer = network.add_slice( - input_val, - start=start, - shape=[] if dynamic_shape else output_shape, - stride=stride, - ) - if dynamic_shape: - layer.set_input(2, output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) - + return operator.add_slice(network, target, kwargs, name) + @tensorrt_converter(acc_ops.expand) def acc_ops_expand_tensor( @@ -2036,29 +1991,8 @@ def acc_ops_expand_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_t = kwargs["input"] - shape = list(kwargs["sizes"]) - - input_val = get_trt_tensor(network, input_t, f"{name}_input") - - if network.has_implicit_batch_dimension: - shape = shape[1:] - - ranks = len(input_val.shape) - # TRT does not support different dimension size - assert len(shape) == ranks - shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] - - inshape = tuple(input_val.shape) - shape = tuple(shape) - start = tuple([0] * ranks) - stride = tuple( - [int(i == o) for i, o in zip(inshape, shape)] - ) # stride == 1 if dimensions match, 0 otherwise - layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) - set_layer_name(layer, target, name) - return layer.get_output(0) - + return operator.add_expand(network, target, kwargs, name) + @tensorrt_converter(acc_ops.where) def acc_ops_where( @@ -2754,34 +2688,8 @@ def acc_ops_gelu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - approximate = 
kwargs["approximate"] - if approximate != "none": - raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"GELU received input {input_val} that is not part " - "of the TensorRT region!" - ) - if network.has_implicit_batch_dimension: - raise RuntimeError( - "GeLU converter currently doesn't support implicit batch dimension" - ) - - plugin_name = "CustomGeluPluginDynamic" - # type_id 0 for float32, 1 for float16 - type_id = trt.PluginField( - "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 - ) - field_collection = TRTPluginFieldCollection([type_id]) - plugin_version = "1" - - plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) - - layer = network.add_plugin_v2([input_val], plugin) - set_layer_name(layer, target, name) - return layer.get_output(0) - + return activation.add_gelu(network, target, kwargs, name) + @tensorrt_converter(acc_ops.chunk) def acc_ops_chunk( diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py index 9d38dcbbc2..38b734dd67 100644 --- a/py/torch_tensorrt/fx/converters/activation.py +++ b/py/torch_tensorrt/fx/converters/activation.py @@ -12,6 +12,7 @@ from .converter_utils import mark_as_int8_layer from .converter_utils import set_layer_name +from .converter_utils import get_trt_plugin from ..types import ( Shape, @@ -106,6 +107,35 @@ def add_tanh(network, target, kwargs, name): operation_type = trt.ActivationType.TANH return add_activation_layer(network, input_val, operation_type, target, name) +def add_gelu(network, target, kwargs, name): + input_val = kwargs["input"] + approximate = kwargs["approximate"] + if approximate != "none": + raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"GELU received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "GeLU converter currently doesn't support implicit batch dimension" + ) + + plugin_name = "CustomGeluPluginDynamic" + # type_id 0 for float32, 1 for float16 + type_id = trt.PluginField( + "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 + ) + field_collection = TRTPluginFieldCollection([type_id]) + plugin_version = "1" + + plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) + + layer = network.add_plugin_v2([input_val], plugin) + set_layer_name(layer, target, name) + return layer.get_output(0) + def add_hard_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index b83a338cfd..bee8eddfce 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -254,19 +254,6 @@ def aten_ops_mul( } return operator.add_mul(network, target, None, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.mul.Tensor) -def aten_ops_mul( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - "other": args[1], - } - return operator.add_mul(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.matmul.Tensor) def aten_ops_matmul( @@ -283,7 +270,6 @@ def aten_ops_matmul( return operator.add_matmul(network, target, None, kwargs_new, name) - @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) def aten_ops_pow( @@ -392,10 +378,8 @@ def aten_ops_expand( "input": args[0], "sizes": args[1], } - return acc_ops_converters.acc_ops_expand_tensor( - network, target, None, kwargs_new, name - ) - + return operator.add_expand(network, target, kwargs_new, name) + 
@tensorrt_converter(operator.floordiv) def aten_ops_operator_floordiv( @@ -497,8 +481,42 @@ def aten_ops_sym_size( set_layer_name(slice_layer, target, "_slice_layer") return slice_layer.get_output(0) + +@tensorrt_converter(torch.ops.aten.slice.Tensor) +def aten_ops_slice( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input" : args[0], + "dim" : args[1], + "start" : args[2], + "stop" : args[3], + "step" : args[4], + } + return operator.add_slice(network, target. kwargs_new, name) + +@tensorrt_converter(torch.ops.aten.select.Tensor) +def aten_ops_select( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input" : args[0], + "dim" : args[1], + "index" : args[2], + } + return operator.add_select(network, target. kwargs_new, name) + + @tensorrt_converter(torch.ops.aten.leaky_relu.default) -def aten_ops_relu( +def aten_ops_leaky_relu( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], @@ -510,8 +528,9 @@ def aten_ops_relu( } return activation.add_leaky_relu(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.elu.default) -def aten_ops_relu( +def aten_ops_elu( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], @@ -521,7 +540,8 @@ def aten_ops_relu( kwargs_new = { "input": args[0], } - return activation.elu(network, target, kwargs_new, name) + return activation.add_elu(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.selu.default) def aten_ops_selu( @@ -534,10 +554,11 @@ def aten_ops_selu( kwargs_new = { "input": args[0], } - return activation.add_selu(network, target, kwargs_new, name) + return activation.selu(network, target, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.selu.default) -def aten_ops_selu( + 
+@tensorrt_converter(torch.ops.aten.gelu.default) +def aten_ops_gelu( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], @@ -547,7 +568,8 @@ def aten_ops_selu( kwargs_new = { "input": args[0], } - return activation.add_selu(network, target, kwargs_new, name) + return activation.add_gelu(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.softsign.default) def aten_ops_softsign( @@ -562,6 +584,7 @@ def aten_ops_softsign( } return activation.add_softsign(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.tanh.default) def aten_ops_tanh( network: TRTNetwork, @@ -588,6 +611,7 @@ def aten_ops_softsign( } return activation.add_softsign(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.softsign.default) def aten_ops_hard_sigmoid( network: TRTNetwork, @@ -601,6 +625,7 @@ def aten_ops_hard_sigmoid( } return activation.add_hard_sigmoid(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.sigmoid.default) def aten_ops_hard_tanh( network: TRTNetwork, @@ -614,6 +639,7 @@ def aten_ops_hard_tanh( } return activation.add_hard_tanh(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.sigmoid.default) def aten_ops_sigmoid( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 48e8d0a301..4820a1c782 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -396,7 +396,7 @@ def get_shape_with_dynamic_shape( ) set_layer_name(zero_layer, target, f"{name}_zeros") - condition_val = add_binary_elementwise_layer( + condition_val = operator.add_binary_elementwise_layer( network, scale_res, zero_layer.get_output(0), diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 2dab0c6475..0fa09918d1 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ 
b/py/torch_tensorrt/fx/converters/operator.py @@ -18,6 +18,9 @@ from .converter_utils import dtype_uniform from .converter_utils import get_trt_plugin from .converter_utils import get_positive_dim +from .converter_utils import prepend_ones +from .converter_utils import has_dynamic_shape +from .converter_utils import get_shape_with_dynamic_shape from ..types import ( Shape, @@ -900,3 +903,120 @@ def add_fmod(network, target, kwargs, name): def add_trunc_div(network, target, kwargs, name): return trunc_div(kwargs["input"], kwargs["other"], network, target, name) + +def add_expand(network, target, kwargs, name): + input_t = kwargs["input"] + shape = list(kwargs["sizes"]) + + input_val = get_trt_tensor(network, input_t, f"{name}_input") + + if network.has_implicit_batch_dimension: + shape = shape[1:] + + ranks = len(input_val.shape) + # TRT does not support different dimension size + assert len(shape) == ranks + shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] + + inshape = tuple(input_val.shape) + shape = tuple(shape) + start = tuple([0] * ranks) + stride = tuple( + [int(i == o) for i, o in zip(inshape, shape)] + ) # stride == 1 if dimensions match, 0 otherwise + layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) + set_layer_name(layer, target, name) + return layer.get_output(0) + +def add_slice(network, target, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" 
+ ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + + start_int = cast(int, kwargs["start"]) + stop_int = cast(int, kwargs["stop"]) + step_int = cast(int, kwargs["step"]) + start = [0] * len(input_val.shape) + start[dim] = start_int + stride = [1] * len(start) + stride[dim] = step_int + output_shape = list(input_val.shape) + output_shape[dim] = (stop_int - start_int) // step_int + + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, output_shape, input_val, target, name + ) + layer = network.add_slice( + input_val, + start=start, + shape=[] if dynamic_shape else output_shape, + stride=stride, + ) + if dynamic_shape: + layer.set_input(2, output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) + +def add_select(network, target, kwargs, name): + input_val = kwargs["input"] + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't select on negative shape dimension!" + index = kwargs[2] + if index >= input_val.shape[dim]: + raise RuntimeError( + f"cannot have index greater than the dimension length! 
{input_val.shape[dim]}" + ) + output_shape = list(input_val.shape) + output_shape[dim] = 1 + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, output_shape, input_val, target, name + ) + layer = network.add_gather( + input_val, + dim, + index + ) + out = layer.getOutput(0) + if(len(out.shape) != 1): + layer = network.add_shuffle(out) + return layer.getOutput(0) + + + diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py new file mode 100644 index 0000000000..35f9498b18 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +class TestGeLUConverter(DispatchTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + +if __name__ == "__main__": + 
run_tests() \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py new file mode 100644 index 0000000000..44d011ef4e --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -0,0 +1,88 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +class TestMatMulConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.convolution.default}, + test_explicit_precision=True, + ) + + def test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, 
expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py new file mode 100644 index 0000000000..35f9498b18 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +class TestGeLUConverter(DispatchTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) 
+ + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file From 8c8e8979c16ac9c1d55a94d69de7ad134d55f5ba Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Mar 2023 10:08:15 -0700 Subject: [PATCH 03/39] Fixing matmul, select, tanh tests --- .../fx/converters/acc_ops_converters.py | 5 +- py/torch_tensorrt/fx/converters/activation.py | 9 +- .../fx/converters/aten_ops_converters.py | 106 ++++++------------ py/torch_tensorrt/fx/converters/operator.py | 3 +- .../converters/aten_op/test_matmul_aten.py | 85 ++++---------- .../converters/aten_op/test_select_aten.py | 39 +++++++ .../converters/aten_op/test_slice_aten.py | 0 .../test/converters/aten_op/test_tanh_aten.py | 20 ++-- 8 files changed, 113 insertions(+), 154 deletions(-) delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 85e8e22252..e070376797 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -26,8 +26,9 @@ trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -import activation -import operator + +from .activation import * +from .operator import * _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py index 38b734dd67..4b774b8d9e 100644 --- a/py/torch_tensorrt/fx/converters/activation.py +++ b/py/torch_tensorrt/fx/converters/activation.py @@ -33,7 +33,7 @@ def add_activation_layer( name: str, alpha: Optional[Any] = None, beta: Optional[Any] = None, - 
dyn_range_fn: Optional[Callable[Tuple[float, float]]] = None + dyn_range_fn: Optional[Callable[[float, float], Any]] = None ) -> TRTTensor: """ Add a TensorRT Activation layer to `network`. @@ -109,9 +109,10 @@ def add_tanh(network, target, kwargs, name): def add_gelu(network, target, kwargs, name): input_val = kwargs["input"] - approximate = kwargs["approximate"] - if approximate != "none": - raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") + if "approximate" in kwargs.keys(): + approximate = kwargs["approximate"] + if approximate != "none": + raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") if not isinstance(input_val, TRTTensor): raise RuntimeError( f"GELU received input {input_val} that is not part " diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index bee8eddfce..d26507e805 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -21,9 +21,11 @@ from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt from .converter_utils import * # noqa: F403 +from .activation import * +from .operator import * + import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -import activation -import operator + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -40,7 +42,7 @@ def aten_ops_add( "input": args[0], "other": args[1], } - return operator.add_add(network, target, None, kwargs_new, name) + return add_add(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.mean.dim) @@ -143,13 +145,13 @@ def aten_ops_div( } rounding_mode = kwargs.get("rounding_mode") if rounding_mode is None: - return operator.add_div(network, target, None, kwargs_new, name) + return add_div(network, target, None, kwargs_new, name) elif rounding_mode == "floor": - return operator.add_floor_div( + return add_floor_div( network, 
target, None, kwargs_new, name ) elif rounding_mode == "trunc": - return operator.add_trunc_div( + return add_trunc_div( network, target, None, kwargs_new, name ) else: @@ -170,7 +172,7 @@ def aten_ops_floor_div( "input": args[0], "other": args[1], } - return operator.add_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.fmod.Scalar) @@ -186,7 +188,7 @@ def aten_ops_fmod( "input": args[0], "other": args[1], } - return operator.add_fmod(network, target, None, kwargs_new, name) + return add_fmod(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.linear) @@ -203,7 +205,7 @@ def aten_ops_linear( "bias": args[2], } - return operator.add_linear(network, target, None, kwargs_new, name) + return add_linear(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.max_pool3d) @@ -252,10 +254,11 @@ def aten_ops_mul( "input": args[0], "other": args[1], } - return operator.add_mul(network, target, None, kwargs_new, name) + return add_mul(network, target, None, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.matmul.Tensor) +@tensorrt_converter(torch.ops.aten.matmul) +@tensorrt_converter(torch.ops.aten.mm.default) def aten_ops_matmul( network: TRTNetwork, target: Target, @@ -267,7 +270,7 @@ def aten_ops_matmul( "input": args[0], "other": args[1], } - return operator.add_matmul(network, target, None, kwargs_new, name) + return add_matmul(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @@ -283,7 +286,7 @@ def aten_ops_pow( "input": args[0], "exponent": args[1], } - return operator.add_pow(network, target, None, kwargs_new, name) + return add_pow(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.relu.default) @@ -297,7 +300,7 @@ def aten_ops_relu( kwargs_new = { "input": args[0], } - return activation.add_relu(network, target, kwargs_new, name) + return 
add_relu(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( @@ -311,7 +314,7 @@ def aten_ops_sub( "input": args[0], "other": args[1], } - return operator.add_sub(network, target, None, kwargs_new, name) + return add_sub(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.view.default) @@ -378,7 +381,7 @@ def aten_ops_expand( "input": args[0], "sizes": args[1], } - return operator.add_expand(network, target, kwargs_new, name) + return add_expand(network, target, kwargs_new, name) @tensorrt_converter(operator.floordiv) @@ -393,7 +396,7 @@ def aten_ops_operator_floordiv( "input": args[0], "other": args[1], } - return operator.add_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, None, kwargs_new, name) @tensorrt_converter(operator.mul) @@ -408,7 +411,7 @@ def aten_ops_operator_mul( "input": args[0], "other": args[1], } - return operator.add_mul(network, target, None, kwargs_new, name) + return add_mul(network, target, None, kwargs_new, name) @tensorrt_converter(operator.add) @@ -423,7 +426,7 @@ def aten_ops_operator_add( "input": args[0], "other": args[1], } - return operator.add_add(network, target, None, kwargs_new, name) + return add_add(network, target, None, kwargs_new, name) @tensorrt_converter(operator.sub) @@ -438,7 +441,7 @@ def aten_ops_operator_sub( "input": args[0], "other": args[1], } - return operator.add_sub(network, target, None, kwargs_new, name) + return add_sub(network, target, None, kwargs_new, name) @tensorrt_converter(torch.ops.aten.sym_numel) @@ -497,9 +500,10 @@ def aten_ops_slice( "stop" : args[3], "step" : args[4], } - return operator.add_slice(network, target. kwargs_new, name) + return add_slice(network, target. 
kwargs_new, name) -@tensorrt_converter(torch.ops.aten.select.Tensor) + +@tensorrt_converter(torch.ops.aten.select) def aten_ops_select( network: TRTNetwork, target: Target, @@ -512,7 +516,7 @@ def aten_ops_select( "dim" : args[1], "index" : args[2], } - return operator.add_select(network, target. kwargs_new, name) + return add_select(network, target. kwargs_new, name) @tensorrt_converter(torch.ops.aten.leaky_relu.default) @@ -526,7 +530,7 @@ def aten_ops_leaky_relu( kwargs_new = { "input": args[0], } - return activation.add_leaky_relu(network, target, kwargs_new, name) + return add_leaky_relu(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.elu.default) @@ -540,7 +544,7 @@ def aten_ops_elu( kwargs_new = { "input": args[0], } - return activation.add_elu(network, target, kwargs_new, name) + return add_elu(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.selu.default) @@ -554,7 +558,7 @@ def aten_ops_selu( kwargs_new = { "input": args[0], } - return activation.selu(network, target, kwargs_new, name) + return add_selu(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.gelu.default) @@ -568,22 +572,7 @@ def aten_ops_gelu( kwargs_new = { "input": args[0], } - return activation.add_gelu(network, target, kwargs_new, name) - - -@tensorrt_converter(torch.ops.aten.softsign.default) -def aten_ops_softsign( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } - return activation.add_softsign(network, target, kwargs_new, name) - + return add_gelu(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.tanh.default) def aten_ops_tanh( @@ -596,34 +585,7 @@ def aten_ops_tanh( kwargs_new = { "input": args[0], } - return activation.add_tanh(network, target, kwargs_new, name) - -@tensorrt_converter(torch.ops.aten.softsign.default) -def aten_ops_softsign( - network: 
TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } - return activation.add_softsign(network, target, kwargs_new, name) - - -@tensorrt_converter(torch.ops.aten.softsign.default) -def aten_ops_hard_sigmoid( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } - return activation.add_hard_sigmoid(network, target, kwargs_new, name) + return add_tanh(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.sigmoid.default) @@ -637,7 +599,7 @@ def aten_ops_hard_tanh( kwargs_new = { "input": args[0], } - return activation.add_hard_tanh(network, target, kwargs_new, name) + return add_hard_tanh(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.sigmoid.default) @@ -651,7 +613,7 @@ def aten_ops_sigmoid( kwargs_new = { "input": args[0], } - return activation.add_sigmoid(network, target, kwargs_new, name) + return add_sigmoid(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 0fa09918d1..25eb2a00ca 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1,6 +1,7 @@ import numpy as np import operator import warnings +import logging from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import tensorrt as trt @@ -687,7 +688,7 @@ def add_pow(network, target, kwargs, name): network, kwargs["input"], kwargs["other"], - trt.ElementWiseOperation.PROD, + trt.ElementWiseOperation.POW, target, name, ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py index 44d011ef4e..c632425e19 100644 --- 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -7,82 +7,37 @@ from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec class TestMatMulConverter(DispatchTestCase): - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1), (1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv1d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): + def test_matmul(self): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32)] + def forward(self, x, y): + return torch.matmul(x, y) + inputOne = torch.randn(1, 32) + inputTwo = torch.randn(32, 3) + inputs = [inputOne, inputTwo] self.run_test( - TestModule(), - inputs, - expected_ops={torch.ops.aten.convolution.default}, - test_explicit_precision=True, + TestModule(), inputs, expected_ops={torch.ops.aten.matmul}, ) - def test_conv1d_with_dynamic_shape( - self, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): + def test_matmul_with_dynamic_shape(self): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) + def forward(self, x, y): + return torch.matmul(x, y) input_specs = [ InputTensorSpec( - shape=(-1, 3, 3), + shape=(-1, 3), dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + shape_ranges=[((3, 3, 1), (3, 3, 3))], + ), + InputTensorSpec( + shape=(3, -1), + dtype=torch.float32, + shape_ranges=[((3, 3, 3), (5, 3, 
3))], ), ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + TestModule(), input_specs, expected_ops={torch.ops.aten.mm}, ) - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) \ No newline at end of file +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py index e69de29bb2..2a92e206aa 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -0,0 +1,39 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +class TestSelectConverter(DispatchTestCase): + def test_select(self): + class TestModule(torch.nn.Module): + def forward(self, input, dim, index): + return torch.select(input, dim, index) + input = [torch.randn(1, 3, 32)] + dim = 2 + index = 1 + inputs = (input, dim, index) + self.run_test( + TestModule(), input, expected_ops={torch.ops.aten.select.Tensor}, test_explicit_precision=True, + ) + + def test_select_with_dynamic_shape(self, x, y): + class TestModule(torch.nn.Module): + def forward(self, input, dim, index): + return torch.select(input, dim, index) + + input_spec = [ + InputTensorSpec( + shape=(-1, 3, 32), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))], + ), + ] + dim = 2 + index = 1 + inputs_spec = (input_spec, dim, index) + self.run_test_with_dynamic_shape( + TestModule(), 
inputs_spec, expected_ops={torch.ops.aten.select.Tensor} + ) \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py index 35f9498b18..8854284938 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py @@ -3,21 +3,21 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec -class TestGeLUConverter(DispatchTestCase): - def test_gelu(self): +class TestTanhConverter(DispatchTestCase): + def test_tanh(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.gelu(x) + return nn.functional.tanh(x) inputs = [torch.randn(1, 10)] self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default} + TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default} ) - def test_gelu_with_dynamic_shape(self): + def test_tanh_with_dynamic_shape(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.gelu(x) + return nn.functional.tanh(x) input_specs = [ InputTensorSpec( @@ -27,13 +27,13 @@ def forward(self, x): ), ] self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} ) - def test_gelu_with_dynamic_shape_four_dimensions(self): + def test_tanh_with_dynamic_shape_four_dimensions(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.gelu(x) + return nn.functional.tanh(x) input_specs = [ InputTensorSpec( @@ -44,7 +44,7 @@ def forward(self, x): ] self.run_test_with_dynamic_shape( - 
TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} ) From 4f742a99ad2c3f47ff540dff60ca9256ef8b4e4d Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Mar 2023 10:34:11 -0700 Subject: [PATCH 04/39] Modifications to matmul and select tests --- .../converters/aten_op/test_matmul_aten.py | 2 +- .../converters/aten_op/test_select_aten.py | 32 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py index c632425e19..344eeb9a2b 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -11,7 +11,7 @@ def test_matmul(self): class TestModule(torch.nn.Module): def forward(self, x, y): return torch.matmul(x, y) - inputOne = torch.randn(1, 32) + inputOne = torch.randn(3, 32) inputTwo = torch.randn(32, 3) inputs = [inputOne, inputTwo] self.run_test( diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py index 2a92e206aa..6ba0b7918a 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -7,19 +7,25 @@ from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec class TestSelectConverter(DispatchTestCase): - def test_select(self): + @parameterized.expand( + [ + ("select_dim_index", 2, 1), + ] + ) + def test_select(self, dim_test, index_test): class TestModule(torch.nn.Module): - def forward(self, input, dim, index): - return torch.select(input, dim, index) + def __init__(self, dim, index): + super().__init__() + self.dim = dim + self.index = index + def forward(self, input): + return torch.select(input, self.dim, self.index) input = 
[torch.randn(1, 3, 32)] - dim = 2 - index = 1 - inputs = (input, dim, index) self.run_test( - TestModule(), input, expected_ops={torch.ops.aten.select.Tensor}, test_explicit_precision=True, + TestModule(dim_test, index_test), input, expected_ops={torch.ops.aten.select}, test_explicit_precision=True, ) - def test_select_with_dynamic_shape(self, x, y): + def test_select_with_dynamic_shape(self, dim_test, index_test): class TestModule(torch.nn.Module): def forward(self, input, dim, index): return torch.select(input, dim, index) @@ -31,9 +37,9 @@ def forward(self, input, dim, index): shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))], ), ] - dim = 2 - index = 1 - inputs_spec = (input_spec, dim, index) self.run_test_with_dynamic_shape( - TestModule(), inputs_spec, expected_ops={torch.ops.aten.select.Tensor} - ) \ No newline at end of file + TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select} + ) + +if __name__ == "__main__": + run_tests() \ No newline at end of file From e8c2786d3bcd138fa2424c2f779ae9761677ef69 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Mar 2023 10:43:26 -0700 Subject: [PATCH 05/39] Fixing aten::select test --- .../fx/test/converters/aten_op/test_select_aten.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py index 6ba0b7918a..a5dcfd9742 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -12,7 +12,7 @@ class TestSelectConverter(DispatchTestCase): ("select_dim_index", 2, 1), ] ) - def test_select(self, dim_test, index_test): + def test_select(self, _, dim_test, index_test): class TestModule(torch.nn.Module): def __init__(self, dim, index): super().__init__() @@ -25,10 +25,14 @@ def forward(self, input): TestModule(dim_test, index_test), input, 
expected_ops={torch.ops.aten.select}, test_explicit_precision=True, ) - def test_select_with_dynamic_shape(self, dim_test, index_test): + def test_select_with_dynamic_shape(self, _, dim_test, index_test): class TestModule(torch.nn.Module): - def forward(self, input, dim, index): - return torch.select(input, dim, index) + def __init__(self, dim, index): + super().__init__() + self.dim = dim + self.index = index + def forward(self, input): + return torch.select(input, self.dim, self.index) input_spec = [ InputTensorSpec( From 038520d20117de09aa7feca7f1a9e4a1ccafd1a7 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Mar 2023 12:05:55 -0700 Subject: [PATCH 06/39] Removing matmul and select operator --- .../fx/converters/acc_ops_converters.py | 72 +++++++++---------- py/torch_tensorrt/fx/converters/operator.py | 41 ----------- .../converters/aten_op/test_matmul_aten.py | 43 ----------- .../converters/aten_op/test_select_aten.py | 49 ------------- 4 files changed, 36 insertions(+), 169 deletions(-) delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index e070376797..c7df3bd7ec 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -80,7 +80,7 @@ def trt_transposed_linear_converter(network, target, args, kwargs, name): trt.MatrixOperation.NONE, ) set_layer_name(layer, target, f"{name}_mm") - return operator.add_binary_elementwise_layer( + return add_binary_elementwise_layer( network, layer.get_output(0), bias, @@ -679,7 +679,7 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) def acc_ops_layer_norm(network, target, args, kwargs, name): - return operator.add_layer_norm(network, target, kwargs, name) + return add_layer_norm(network, target, 
kwargs, name) @tensorrt_converter(acc_ops.softmax) def acc_ops_softmax( @@ -730,7 +730,7 @@ def acc_ops_tile( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_tile(network, target, kwargs, name) + return add_tile(network, target, kwargs, name) @tensorrt_converter(acc_ops.sign) def acc_ops_sign( @@ -758,7 +758,7 @@ def acc_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_relu(network, target, kwargs, name) + return add_relu(network, target, kwargs, name) @tensorrt_converter(acc_ops.leaky_relu) def acc_ops_leaky_relu( @@ -768,7 +768,7 @@ def acc_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_leaky_relu(network, target, kwargs, name) + return add_leaky_relu(network, target, kwargs, name) @tensorrt_converter(acc_ops.elu) def acc_ops_elu( @@ -778,7 +778,7 @@ def acc_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_elu(network, target, kwargs, name) + return add_elu(network, target, kwargs, name) @tensorrt_converter(acc_ops.selu) def acc_ops_selu( @@ -788,7 +788,7 @@ def acc_ops_selu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_selu(network, target, kwargs, name) + return add_selu(network, target, kwargs, name) @tensorrt_converter(acc_ops.softsign) def acc_ops_softsign( @@ -798,7 +798,7 @@ def acc_ops_softsign( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_softsign(network, target, kwargs, name) + return add_softsign(network, target, kwargs, name) @tensorrt_converter(acc_ops.sin) def acc_ops_sin( @@ -873,7 +873,7 @@ def acc_ops_tanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_tanh(network, target, kwargs, name) + 
return add_tanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.asin) def acc_ops_asin( @@ -1190,7 +1190,7 @@ def acc_ops_maximum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_maximum(network, target, kwargs, name) + return add_maximum(network, target, kwargs, name) @tensorrt_converter(acc_ops.minimum) def acc_ops_minimum( @@ -1200,7 +1200,7 @@ def acc_ops_minimum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_minimum(network, target, kwargs, name) + return add_minimum(network, target, kwargs, name) @tensorrt_converter(acc_ops.dtype) def acc_ops_dtype( @@ -1269,7 +1269,7 @@ def acc_ops_logical_and( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_logical_and(network, target, kwargs, name) + return add_logical_and(network, target, kwargs, name) @tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True) def acc_ops_ne( @@ -1279,7 +1279,7 @@ def acc_ops_ne( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_ne(network, target, kwargs, name) + return add_ne(network, target, kwargs, name) @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) def acc_ops_eq( @@ -1289,7 +1289,7 @@ def acc_ops_eq( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_eq(network, target, kwargs, name) + return add_eq(network, target, kwargs, name) @tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True) def acc_ops_gt( @@ -1299,7 +1299,7 @@ def acc_ops_gt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_gt(network, target, kwargs, name) + return add_gt(network, target, kwargs, name) @tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True) def acc_ops_lt( @@ -1309,7 +1309,7 @@ def acc_ops_lt( kwargs: Dict[str, Argument], name: str, 
) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_lt(network, target, kwargs, name) + return add_lt(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True) @@ -1320,7 +1320,7 @@ def acc_ops_logical_or( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_logical_or(network, target, kwargs, name) + return add_logical_or(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True) def acc_ops_logical_xor( @@ -1330,7 +1330,7 @@ def acc_ops_logical_xor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_logical_xor(network, target, kwargs, name) + return add_logical_xor(network, target, kwargs, name) # T113156424 Have some accuracy problems in hf_T5. # [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights. @@ -1423,7 +1423,7 @@ def acc_ops_fmod( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_fmod(network, target, kwargs, name) + return add_fmod(network, target, kwargs, name) # T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64. # if we cast to int32, it will create accuracy issues. We'd better leave it to future implementation. 
@@ -1651,7 +1651,7 @@ def acc_ops_add( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_add(network, target, kwargs, name) + return add_add(network, target, kwargs, name) @tensorrt_converter(acc_ops.sub) def acc_ops_sub( @@ -1661,7 +1661,7 @@ def acc_ops_sub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_sub(network, target, kwargs, name) + return add_sub(network, target, kwargs, name) @tensorrt_converter(acc_ops.div) def acc_ops_div( @@ -1671,7 +1671,7 @@ def acc_ops_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_div(network, target, kwargs, name) + return add_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.floor_div) def acc_ops_floor_div( @@ -1681,7 +1681,7 @@ def acc_ops_floor_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_floor_div(network, target, kwargs, name) + return add_floor_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.trunc_div) def acc_ops_trunc_div( @@ -1691,7 +1691,7 @@ def acc_ops_trunc_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_trunc_div(network, target, kwargs, name) + return add_trunc_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.mul) def acc_ops_mul( @@ -1701,7 +1701,7 @@ def acc_ops_mul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_mul(network, target, kwargs, name) + return add_mul(network, target, kwargs, name) @tensorrt_converter(acc_ops.pow) def acc_ops_pow( @@ -1711,7 +1711,7 @@ def acc_ops_pow( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_pow(network, target, kwargs, name) + return add_pow(network, target, kwargs, name) @tensorrt_converter(acc_ops.unsqueeze) def 
acc_ops_unsqueeze( @@ -1981,7 +1981,7 @@ def acc_ops_slice_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_slice(network, target, kwargs, name) + return add_slice(network, target, kwargs, name) @tensorrt_converter(acc_ops.expand) @@ -1992,7 +1992,7 @@ def acc_ops_expand_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_expand(network, target, kwargs, name) + return add_expand(network, target, kwargs, name) @tensorrt_converter(acc_ops.where) @@ -2214,7 +2214,7 @@ def acc_ops_linear( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_linear(network, target, kwargs, name) + return add_linear(network, target, kwargs, name) def add_clamp(network, input, val, op, name): if not len(input.shape): @@ -2310,7 +2310,7 @@ def acc_ops_getitem( input_val = kwargs["input"] slices = kwargs["idx"] if not isinstance(input_val, TRTTensor): - return operator.getitem(input_val, slices) # type: ignore[arg-type] + return getitem(input_val, slices) # type: ignore[arg-type] if not isinstance(slices, tuple) and not isinstance(slices, list): slices = (slices,) @@ -2467,7 +2467,7 @@ def acc_ops_matmul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_matmul(network, target, kwargs, name) + return add_matmul(network, target, kwargs, name) @tensorrt_converter(acc_ops.hardsigmoid) def acc_ops_hard_sigmoid( @@ -2477,7 +2477,7 @@ def acc_ops_hard_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_hard_sigmoid(network, target, kwargs, name) + return add_hard_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.sigmoid) @@ -2488,7 +2488,7 @@ def acc_ops_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return 
activation.add_sigmoid(network, target, kwargs, name) + return add_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.permute) @@ -2689,7 +2689,7 @@ def acc_ops_gelu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_gelu(network, target, kwargs, name) + return add_gelu(network, target, kwargs, name) @tensorrt_converter(acc_ops.chunk) @@ -2766,7 +2766,7 @@ def acc_ops_cumsum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return operator.add_cumsum(network, target, kwargs, name) + return add_cumsum(network, target, kwargs, name) @tensorrt_converter(acc_ops.hardtanh) def acc_ops_hardtanh( @@ -2776,7 +2776,7 @@ def acc_ops_hardtanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.add_hardtanh(network, target, kwargs, name) + return add_hardtanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.interpolate) def acc_ops_interpolate( diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 25eb2a00ca..294d76f357 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -977,47 +977,6 @@ def add_slice(network, target, kwargs, name): set_layer_name(layer, target, name) return layer.get_output(0) -def add_select(network, target, kwargs, name): - input_val = kwargs["input"] - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input_val} that is not part " - "of the TensorRT region!" - ) - - ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" 
- ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't select on negative shape dimension!" - index = kwargs[2] - if index >= input_val.shape[dim]: - raise RuntimeError( - f"cannot have index greater than the dimension length! {input_val.shape[dim]}" - ) - output_shape = list(input_val.shape) - output_shape[dim] = 1 - if dynamic_shape > 0: - output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name - ) - layer = network.add_gather( - input_val, - dim, - index - ) - out = layer.getOutput(0) - if(len(out.shape) != 1): - layer = network.add_shuffle(out) - return layer.getOutput(0) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py deleted file mode 100644 index 344eeb9a2b..0000000000 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec - -class TestMatMulConverter(DispatchTestCase): - def test_matmul(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - return torch.matmul(x, y) - inputOne = torch.randn(3, 32) - inputTwo = torch.randn(32, 3) - inputs = [inputOne, inputTwo] - self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.matmul}, - ) - - def test_matmul_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - return torch.matmul(x, y) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3), - dtype=torch.float32, - shape_ranges=[((3, 3, 1), (3, 3, 3))], - ), - InputTensorSpec( - shape=(3, -1), - dtype=torch.float32, - shape_ranges=[((3, 3, 3), 
(5, 3, 3))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.mm}, - ) - -if __name__ == "__main__": - run_tests() \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py deleted file mode 100644 index a5dcfd9742..0000000000 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec - -class TestSelectConverter(DispatchTestCase): - @parameterized.expand( - [ - ("select_dim_index", 2, 1), - ] - ) - def test_select(self, _, dim_test, index_test): - class TestModule(torch.nn.Module): - def __init__(self, dim, index): - super().__init__() - self.dim = dim - self.index = index - def forward(self, input): - return torch.select(input, self.dim, self.index) - input = [torch.randn(1, 3, 32)] - self.run_test( - TestModule(dim_test, index_test), input, expected_ops={torch.ops.aten.select}, test_explicit_precision=True, - ) - - def test_select_with_dynamic_shape(self, _, dim_test, index_test): - class TestModule(torch.nn.Module): - def __init__(self, dim, index): - super().__init__() - self.dim = dim - self.index = index - def forward(self, input): - return torch.select(input, self.dim, self.index) - - input_spec = [ - InputTensorSpec( - shape=(-1, 3, 32), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select} - ) - -if __name__ == "__main__": - run_tests() \ No newline at end of file From e8a8e38956d2c8b37b58ec2144af08b8aafe5fc9 Mon 
Sep 17 00:00:00 2001 From: apbose Date: Mon, 20 Mar 2023 13:02:06 -0700 Subject: [PATCH 07/39] fx2trt fixing add converter and python linting changes --- .../fx/converters/acc_ops_converters.py | 72 +++++++++++++------ py/torch_tensorrt/fx/converters/activation.py | 21 ++++-- .../fx/converters/aten_ops_converters.py | 43 +++++------ .../fx/converters/converter_utils.py | 5 +- .../fx/converters/nn_ops_converters.py | 10 +-- py/torch_tensorrt/fx/converters/operator.py | 36 ++++++++-- .../test/converters/aten_op/test_gelu_aten.py | 7 +- .../aten_op/test_leaky_relu_aten.py | 2 +- .../test/converters/aten_op/test_tanh_aten.py | 7 +- 9 files changed, 130 insertions(+), 73 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index c7df3bd7ec..0e77a1c659 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -680,7 +680,8 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) def acc_ops_layer_norm(network, target, args, kwargs, name): return add_layer_norm(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.softmax) def acc_ops_softmax( network: TRTNetwork, @@ -731,7 +732,8 @@ def acc_ops_tile( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_tile(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.sign) def acc_ops_sign( network: TRTNetwork, @@ -760,6 +762,7 @@ def acc_ops_relu( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_relu(network, target, kwargs, name) + @tensorrt_converter(acc_ops.leaky_relu) def acc_ops_leaky_relu( network: TRTNetwork, @@ -769,7 +772,8 @@ def acc_ops_leaky_relu( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_leaky_relu(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.elu) def acc_ops_elu( network: TRTNetwork, @@ -779,7 +783,8 @@ def acc_ops_elu( name: str, ) -> Union[TRTTensor, 
Sequence[TRTTensor]]: return add_elu(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.selu) def acc_ops_selu( network: TRTNetwork, @@ -790,6 +795,7 @@ def acc_ops_selu( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_selu(network, target, kwargs, name) + @tensorrt_converter(acc_ops.softsign) def acc_ops_softsign( network: TRTNetwork, @@ -799,7 +805,8 @@ def acc_ops_softsign( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_softsign(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.sin) def acc_ops_sin( network: TRTNetwork, @@ -875,6 +882,7 @@ def acc_ops_tanh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_tanh(network, target, kwargs, name) + @tensorrt_converter(acc_ops.asin) def acc_ops_asin( network: TRTNetwork, @@ -1191,7 +1199,8 @@ def acc_ops_maximum( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_maximum(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.minimum) def acc_ops_minimum( network: TRTNetwork, @@ -1201,7 +1210,8 @@ def acc_ops_minimum( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_minimum(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.dtype) def acc_ops_dtype( network: TRTNetwork, @@ -1271,6 +1281,7 @@ def acc_ops_logical_and( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_logical_and(network, target, kwargs, name) + @tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True) def acc_ops_ne( network: TRTNetwork, @@ -1280,7 +1291,8 @@ def acc_ops_ne( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_ne(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) def acc_ops_eq( network: TRTNetwork, @@ -1290,7 +1302,8 @@ def acc_ops_eq( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_eq(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True) def acc_ops_gt( network: TRTNetwork, @@ -1300,7 +1313,8 @@ 
def acc_ops_gt( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_gt(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True) def acc_ops_lt( network: TRTNetwork, @@ -1310,7 +1324,7 @@ def acc_ops_lt( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_lt(network, target, kwargs, name) - + @tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True) def acc_ops_logical_or( @@ -1321,7 +1335,8 @@ def acc_ops_logical_or( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_logical_or(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True) def acc_ops_logical_xor( network: TRTNetwork, @@ -1331,7 +1346,8 @@ def acc_ops_logical_xor( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_logical_xor(network, target, kwargs, name) - + + # T113156424 Have some accuracy problems in hf_T5. # [TRT] [W] Weights [name=isinf_1_inf_t]: Converted FP32 value in weights (either FP32 infinity or FP32 value outside FP16 range) to corresponding FP16 infinity. If this is not the desired behavior, please modify the weights or retrain with regularization to reduce the magnitude of the weights. # @tensorrt_converter(acc_ops.isinf) @@ -1425,6 +1441,7 @@ def acc_ops_fmod( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_fmod(network, target, kwargs, name) + # T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64. # if we cast to int32, it will create accuracy issues. We'd better leave it to future implementation. 
# @tensorrt_converter(acc_ops.embedding, no_implicit_batch_dim=True) @@ -1653,6 +1670,7 @@ def acc_ops_add( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_add(network, target, kwargs, name) + @tensorrt_converter(acc_ops.sub) def acc_ops_sub( network: TRTNetwork, @@ -1663,6 +1681,7 @@ def acc_ops_sub( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_sub(network, target, kwargs, name) + @tensorrt_converter(acc_ops.div) def acc_ops_div( network: TRTNetwork, @@ -1673,6 +1692,7 @@ def acc_ops_div( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_div(network, target, kwargs, name) + @tensorrt_converter(acc_ops.floor_div) def acc_ops_floor_div( network: TRTNetwork, @@ -1682,7 +1702,8 @@ def acc_ops_floor_div( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_floor_div(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.trunc_div) def acc_ops_trunc_div( network: TRTNetwork, @@ -1692,7 +1713,8 @@ def acc_ops_trunc_div( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_trunc_div(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.mul) def acc_ops_mul( network: TRTNetwork, @@ -1702,7 +1724,8 @@ def acc_ops_mul( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_mul(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.pow) def acc_ops_pow( network: TRTNetwork, @@ -1713,6 +1736,7 @@ def acc_ops_pow( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_pow(network, target, kwargs, name) + @tensorrt_converter(acc_ops.unsqueeze) def acc_ops_unsqueeze( network: TRTNetwork, @@ -1982,7 +2006,7 @@ def acc_ops_slice_tensor( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_slice(network, target, kwargs, name) - + @tensorrt_converter(acc_ops.expand) def acc_ops_expand_tensor( @@ -1993,7 +2017,7 @@ def acc_ops_expand_tensor( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_expand(network, target, kwargs, name) - + @tensorrt_converter(acc_ops.where) 
def acc_ops_where( @@ -2215,7 +2239,8 @@ def acc_ops_linear( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_linear(network, target, kwargs, name) - + + def add_clamp(network, input, val, op, name): if not len(input.shape): # clamping scalar @@ -2468,7 +2493,8 @@ def acc_ops_matmul( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_matmul(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.hardsigmoid) def acc_ops_hard_sigmoid( network: TRTNetwork, @@ -2690,7 +2716,7 @@ def acc_ops_gelu( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_gelu(network, target, kwargs, name) - + @tensorrt_converter(acc_ops.chunk) def acc_ops_chunk( @@ -2767,7 +2793,8 @@ def acc_ops_cumsum( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_cumsum(network, target, kwargs, name) - + + @tensorrt_converter(acc_ops.hardtanh) def acc_ops_hardtanh( network: TRTNetwork, @@ -2778,6 +2805,7 @@ def acc_ops_hardtanh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_hardtanh(network, target, kwargs, name) + @tensorrt_converter(acc_ops.interpolate) def acc_ops_interpolate( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py index 4b774b8d9e..118f6d3105 100644 --- a/py/torch_tensorrt/fx/converters/activation.py +++ b/py/torch_tensorrt/fx/converters/activation.py @@ -25,6 +25,7 @@ TRTTensor, ) + def add_activation_layer( network: TRTNetwork, input_val: TRTTensor, @@ -33,7 +34,7 @@ def add_activation_layer( name: str, alpha: Optional[Any] = None, beta: Optional[Any] = None, - dyn_range_fn: Optional[Callable[[float, float], Any]] = None + dyn_range_fn: Optional[Callable[[float, float], Any]] = None, ) -> TRTTensor: """ Add a TensorRT Activation layer to `network`. @@ -51,7 +52,7 @@ def add_activation_layer( beta (Optional[Any]): If not None, we will use it to set the beta attribute of the created TensorRT activation layer. 
dyn_range_fn: Optional[Callable[Tuple[float, float]]]: A function which takes the dynamic range of a TensorRT Tensor and returns the output dynamic range - + Returns: The output of TensorRT Activation layer. @@ -73,11 +74,13 @@ def add_activation_layer( mark_as_int8_layer(layer, dyn_range) return layer.get_output(0) + def add_relu(network, target, kwargs, name): input_val = kwargs["input"] operation_type = trt.ActivationType.RELU return add_activation_layer(network, input_val, operation_type, target, name) + def add_leaky_relu(network, target, kwargs, name): input_val = kwargs["input"] negative_slope = kwargs["negative_slope"] @@ -86,33 +89,40 @@ def add_leaky_relu(network, target, kwargs, name): network, input_val, operation_type, target, name, negative_slope ) + def add_elu(network, target, kwargs, name): input_val = kwargs["input"] alpha = kwargs["alpha"] operation_type = trt.ActivationType.ELU return add_activation_layer(network, input_val, operation_type, target, name, alpha) + def add_selu(network, target, kwargs, name): input_val = kwargs["input"] operation_type = trt.ActivationType.SELU return add_activation_layer(network, input_val, operation_type, target, name) + def add_softsign(network, target, kwargs, name): input_val = kwargs["input"] operation_type = trt.ActivationType.SOFTSIGN return add_activation_layer(network, input_val, operation_type, target, name) + def add_tanh(network, target, kwargs, name): input_val = kwargs["input"] operation_type = trt.ActivationType.TANH return add_activation_layer(network, input_val, operation_type, target, name) + def add_gelu(network, target, kwargs, name): input_val = kwargs["input"] if "approximate" in kwargs.keys(): approximate = kwargs["approximate"] if approximate != "none": - raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") + raise RuntimeError( + "GeLU converter currently doesn't support fast gelu compute" + ) if not isinstance(input_val, TRTTensor): raise RuntimeError( f"GELU 
received input {input_val} that is not part " @@ -137,6 +147,7 @@ def add_gelu(network, target, kwargs, name): set_layer_name(layer, target, name) return layer.get_output(0) + def add_hard_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] @@ -156,6 +167,7 @@ def add_hard_sigmoid(network, target, kwargs, name): beta=0.5, ) + def add_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] @@ -169,11 +181,13 @@ def add_sigmoid(network, target, kwargs, name): network, input_val, trt.ActivationType.SIGMOID, target, name ) + def add_hard_tanh(network, target, kwargs, name): input_val = kwargs["input"] operation_type = trt.ActivationType.TANH return add_activation_layer(network, input_val, operation_type, target, name) + def add_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] @@ -192,4 +206,3 @@ def add_sigmoid(network, target, kwargs, name): alpha=1 / 6, beta=0.5, ) - diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index d26507e805..a026e7b0cf 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -42,7 +42,7 @@ def aten_ops_add( "input": args[0], "other": args[1], } - return add_add(network, target, None, kwargs_new, name) + return add_add(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.mean.dim) @@ -56,7 +56,6 @@ def aten_ops_adaptive_avg_poolnd( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if target == torch.ops.aten.mean.dim: - if list(args[1]) != [-1, -2]: raise RuntimeError(f"We do not support {target} has dim={args[1]}") else: @@ -147,13 +146,9 @@ def aten_ops_div( if rounding_mode is None: return add_div(network, target, None, kwargs_new, name) elif rounding_mode == "floor": - return add_floor_div( - network, target, None, kwargs_new, name - ) + return add_floor_div(network, target, None, kwargs_new, name) elif rounding_mode == 
"trunc": - return add_trunc_div( - network, target, None, kwargs_new, name - ) + return add_trunc_div(network, target, None, kwargs_new, name) else: raise RuntimeError( f"Target {target} does not support rounding mode {rounding_mode}" @@ -301,7 +296,8 @@ def aten_ops_relu( "input": args[0], } return add_relu(network, target, kwargs_new, name) - + + @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( network: TRTNetwork, @@ -382,7 +378,7 @@ def aten_ops_expand( "sizes": args[1], } return add_expand(network, target, kwargs_new, name) - + @tensorrt_converter(operator.floordiv) def aten_ops_operator_floordiv( @@ -494,13 +490,13 @@ def aten_ops_slice( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: kwargs_new = { - "input" : args[0], - "dim" : args[1], - "start" : args[2], - "stop" : args[3], - "step" : args[4], + "input": args[0], + "dim": args[1], + "start": args[2], + "stop": args[3], + "step": args[4], } - return add_slice(network, target. kwargs_new, name) + return add_slice(network, target.kwargs_new, name) @tensorrt_converter(torch.ops.aten.select) @@ -512,11 +508,11 @@ def aten_ops_select( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: kwargs_new = { - "input" : args[0], - "dim" : args[1], - "index" : args[2], + "input": args[0], + "dim": args[1], + "index": args[2], } - return add_select(network, target. 
kwargs_new, name) + return add_select(network, target.kwargs_new, name) @tensorrt_converter(torch.ops.aten.leaky_relu.default) @@ -574,6 +570,7 @@ def aten_ops_gelu( } return add_gelu(network, target, kwargs_new, name) + @tensorrt_converter(torch.ops.aten.tanh.default) def aten_ops_tanh( network: TRTNetwork, @@ -614,9 +611,3 @@ def aten_ops_sigmoid( "input": args[0], } return add_sigmoid(network, target, kwargs_new, name) - - - - - - diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 4820a1c782..34ba88371e 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -409,9 +409,6 @@ def get_shape_with_dynamic_shape( return select_layer.get_output(0) - - - def squeeze_left(const: torch.Tensor): """ Squeeze the size-1 dimensions on the left side of the shape tuple. @@ -422,6 +419,7 @@ def squeeze_left(const: torch.Tensor): const = const.squeeze(dim=0) return const + def add_reduce_layer( network: TRTNetwork, target: Target, @@ -641,6 +639,7 @@ def trunc_div( return output + def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py index 3276285c86..9da6a71bfc 100644 --- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py @@ -9,23 +9,25 @@ from .converter_utils import mark_as_int8_layer import activation + @tensorrt_converter(torch.nn.functional.relu) @tensorrt_converter(torch.nn.modules.activation.ReLU) def relu(network, submod, args, kwargs, layer_name): # args/kwargs should have already been normalized to kwargs assert len(args) == 0 - return activation.add_relu(network,"tensorrt", kwargs, layer_name) + return activation.add_relu(network, "tensorrt", kwargs, layer_name) + 
@tensorrt_converter(torch.nn.functional.leaky_relu) @tensorrt_converter(torch.nn.modules.activation.leaky_relu) def leaky_relu(network, submod, args, kwargs, layer_name): # args/kwargs should have already been normalized to kwargs assert len(args) == 0 - return activation.add_leaky_relu(network,"tensorrt", kwargs, layer_name) + return activation.add_leaky_relu(network, "tensorrt", kwargs, layer_name) + @tensorrt_converter(torch.nn.modules.activation.Sigmoid) def sigmoid(network, submod, args, kwargs, layer_name): # args/kwargs should have already been normalized to kwargs assert len(args) == 0 - return activation.add_sigmoid(network,"tensorrt", kwargs, layer_name) - + return activation.add_sigmoid(network, "tensorrt", kwargs, layer_name) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 294d76f357..ffca964633 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -36,6 +36,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) + def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, ) -> Callable[[Any, Any], Any]: @@ -52,6 +53,7 @@ def get_python_op_from_trt_elementwise_op( else: raise RuntimeError(f"{trt_op} is not supported yet!") + def add_binary_elementwise_layer( network: TRTNetwork, lhs_val: Union[int, float, TRTTensor, torch.Tensor], @@ -159,6 +161,7 @@ def add_binary_elementwise_layer( output.name = output.name + "_" + target.__name__ return output + def trunc_div( input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str ) -> TRTTensor: @@ -215,6 +218,7 @@ def trunc_div( return output + def add_tile(network, target, kwargs, name): input_t = kwargs["input"] input_val = get_trt_tensor(network, input_t, f"{name}_input") @@ -314,6 +318,7 @@ def add_tile(network, target, kwargs, name): return layer.get_output(0) + def add_linear(network, target, kwargs, name): input_val = kwargs["input"] @@ -368,6 +373,7 @@ 
def add_linear(network, target, kwargs, name): ) return res + def add_unary_layer( network: TRTNetwork, input_val: TRTTensor, @@ -399,6 +405,7 @@ def add_unary_layer( output.name = output.name + "_" + target.__name__ return layer.get_output(0) + def layer_norm( network: TRTNetwork, target: Target, @@ -509,6 +516,7 @@ def layer_norm( name, ) + def add_add(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -519,6 +527,7 @@ def add_add(network, target, kwargs, name): name, ) + def add_matmul(network, target, kwargs, name): input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") @@ -549,6 +558,7 @@ def add_matmul(network, target, kwargs, name): set_layer_name(layer, target, name) return layer.get_output(0) + def add_layer_norm(network, target, kwargs, name): input_val = kwargs["input"] @@ -592,6 +602,7 @@ def add_layer_norm(network, target, kwargs, name): layer.name = name return layer.get_output(0) + def add_cumsum(network, target, kwargs, name): input_val = kwargs["input"] dim = cast(int, kwargs["dim"]) @@ -663,6 +674,7 @@ def add_cumsum(network, target, kwargs, name): loop_output.set_input(1, trip_limit) return loop_output.get_output(0) + def add_maximum(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -673,6 +685,7 @@ def add_maximum(network, target, kwargs, name): name, ) + def add_mul(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -683,6 +696,7 @@ def add_mul(network, target, kwargs, name): name, ) + def add_pow(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -693,6 +707,7 @@ def add_pow(network, target, kwargs, name): name, ) + def add_floor_div(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -703,6 +718,7 @@ def add_floor_div(network, target, kwargs, name): name, ) + def add_div(network, target, kwargs, name): return 
add_binary_elementwise_layer( network, @@ -713,6 +729,7 @@ def add_div(network, target, kwargs, name): name, ) + def add_sub(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -722,6 +739,8 @@ def add_sub(network, target, kwargs, name): target, name, ) + + def add_minimum(network, target, kwargs, name): return add_binary_elementwise_layer( network, @@ -732,6 +751,7 @@ def add_minimum(network, target, kwargs, name): name, ) + def add_logical_and(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -751,6 +771,7 @@ def add_logical_and(network, target, kwargs, name): return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + def add_ne(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -770,6 +791,7 @@ def add_ne(network, target, kwargs, name): return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + def add_eq(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -787,6 +809,7 @@ def add_eq(network, target, kwargs, name): network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name ) + def add_gt(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -804,6 +827,7 @@ def add_gt(network, target, kwargs, name): network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name ) + def add_lt(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -821,6 +845,7 @@ def add_lt(network, target, kwargs, name): network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name ) + def add_logical_or(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -850,6 +875,7 @@ def add_logical_or(network, target, kwargs, name): network, input_t, other_t, trt.ElementWiseOperation.OR, target, name ) + def add_logical_xor(network, target, kwargs, 
name): if network.has_implicit_batch_dimension: raise RuntimeError( @@ -879,8 +905,9 @@ def add_logical_xor(network, target, kwargs, name): network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name ) + def add_fmod(network, target, kwargs, name): - # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it + # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it trunc_div_value = trunc_div( kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" ) @@ -902,9 +929,11 @@ def add_fmod(network, target, kwargs, name): ) return sub_value + def add_trunc_div(network, target, kwargs, name): return trunc_div(kwargs["input"], kwargs["other"], network, target, name) + def add_expand(network, target, kwargs, name): input_t = kwargs["input"] shape = list(kwargs["sizes"]) @@ -929,6 +958,7 @@ def add_expand(network, target, kwargs, name): set_layer_name(layer, target, name) return layer.get_output(0) + def add_slice(network, target, kwargs, name): input_val = kwargs["input"] @@ -976,7 +1006,3 @@ def add_slice(network, target, kwargs, name): layer.set_input(2, output_shape) set_layer_name(layer, target, name) return layer.get_output(0) - - - - diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py index 35f9498b18..bcc9cbb761 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py @@ -3,6 +3,7 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + class TestGeLUConverter(DispatchTestCase): def test_gelu(self): class TestModule(nn.Module): @@ -10,9 +11,7 @@ def forward(self, x): return nn.functional.gelu(x) inputs = [torch.randn(1, 10)] - self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default} - ) + 
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default}) def test_gelu_with_dynamic_shape(self): class TestModule(nn.Module): @@ -49,4 +48,4 @@ def forward(self, x): if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py index a220b67d67..e694acf05a 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py @@ -50,4 +50,4 @@ def forward(self, x): if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py index 8854284938..581a5e589f 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py @@ -3,6 +3,7 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + class TestTanhConverter(DispatchTestCase): def test_tanh(self): class TestModule(nn.Module): @@ -10,9 +11,7 @@ def forward(self, x): return nn.functional.tanh(x) inputs = [torch.randn(1, 10)] - self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default} - ) + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default}) def test_tanh_with_dynamic_shape(self): class TestModule(nn.Module): @@ -49,4 +48,4 @@ def forward(self, x): if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests() From d7c82ab967e900dfb9fd7e925ed38a82ec0447cf Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 20 Mar 2023 18:38:54 -0700 Subject: [PATCH 08/39] binary operator changes in aten --- .../fx/converters/aten_ops_converters.py | 42 ++----- 
.../fx/converters/converter_utils.py | 116 ------------------ py/torch_tensorrt/fx/converters/operator.py | 61 ++++++++- 3 files changed, 73 insertions(+), 146 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index a026e7b0cf..4399f37a44 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -144,11 +144,11 @@ def aten_ops_div( } rounding_mode = kwargs.get("rounding_mode") if rounding_mode is None: - return add_div(network, target, None, kwargs_new, name) + return add_div(network, target, kwargs_new, name) elif rounding_mode == "floor": - return add_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, kwargs_new, name) elif rounding_mode == "trunc": - return add_trunc_div(network, target, None, kwargs_new, name) + return add_trunc_div(network, target, kwargs_new, name) else: raise RuntimeError( f"Target {target} does not support rounding mode {rounding_mode}" @@ -167,7 +167,7 @@ def aten_ops_floor_div( "input": args[0], "other": args[1], } - return add_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.fmod.Scalar) @@ -183,7 +183,7 @@ def aten_ops_fmod( "input": args[0], "other": args[1], } - return add_fmod(network, target, None, kwargs_new, name) + return add_fmod(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.linear) @@ -200,7 +200,7 @@ def aten_ops_linear( "bias": args[2], } - return add_linear(network, target, None, kwargs_new, name) + return add_linear(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.max_pool3d) @@ -249,7 +249,7 @@ def aten_ops_mul( "input": args[0], "other": args[1], } - return add_mul(network, target, None, kwargs_new, name) + return add_mul(network, target, kwargs_new, name) 
@tensorrt_converter(torch.ops.aten.matmul) @@ -265,7 +265,7 @@ def aten_ops_matmul( "input": args[0], "other": args[1], } - return add_matmul(network, target, None, kwargs_new, name) + return add_matmul(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @@ -310,7 +310,7 @@ def aten_ops_sub( "input": args[0], "other": args[1], } - return add_sub(network, target, None, kwargs_new, name) + return add_sub(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.view.default) @@ -392,7 +392,7 @@ def aten_ops_operator_floordiv( "input": args[0], "other": args[1], } - return add_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, kwargs_new, name) @tensorrt_converter(operator.mul) @@ -407,7 +407,7 @@ def aten_ops_operator_mul( "input": args[0], "other": args[1], } - return add_mul(network, target, None, kwargs_new, name) + return add_mul(network, target, kwargs_new, name) @tensorrt_converter(operator.add) @@ -422,7 +422,7 @@ def aten_ops_operator_add( "input": args[0], "other": args[1], } - return add_add(network, target, None, kwargs_new, name) + return add_add(network, target, kwargs_new, name) @tensorrt_converter(operator.sub) @@ -437,7 +437,7 @@ def aten_ops_operator_sub( "input": args[0], "other": args[1], } - return add_sub(network, target, None, kwargs_new, name) + return add_sub(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.sym_numel) @@ -499,22 +499,6 @@ def aten_ops_slice( return add_slice(network, target.kwargs_new, name) -@tensorrt_converter(torch.ops.aten.select) -def aten_ops_select( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - "dim": args[1], - "index": args[2], - } - return add_select(network, target.kwargs_new, name) - - @tensorrt_converter(torch.ops.aten.leaky_relu.default) def 
aten_ops_leaky_relu( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 34ba88371e..062d4dde5b 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -524,122 +524,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names): return inputs -def sign( - network: TRTNetwork, input_val: TRTTensor, target: Target, name: str -) -> TRTTensor: - """ - Sign is calculated as below: - x = input - sign = (exp(x) // exp(abs(x))) * 2 - 1 - For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. - With multiply 2, the value become 2(for pos and 0) and 0(for neg). - Finally minus 1, the value become 1(for pos and 0) and -1(for neg). - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): The input tensor. - target (Target): fx node target. - name (str): Name of the fx node with optional suffix. - - Returns: - A TensorRT tensor represent the result of sign operator. 
- """ - input_exp_output = add_unary_layer( - network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp" - ) - input_abs_output = add_unary_layer( - network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs" - ) - input_abs_exp_output = add_unary_layer( - network, - input_abs_output, - trt.UnaryOperation.EXP, - target, - f"{name}_prod_abs_exp", - ) - floor_div_output = add_binary_elementwise_layer( - network, - input_exp_output, - input_abs_exp_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_exp_floor_div", - ) - double_floor_div_output = add_binary_elementwise_layer( - network, - floor_div_output, - 2, - trt.ElementWiseOperation.PROD, - target, - f"{name}_floor_div*2", - ) - return add_binary_elementwise_layer( - network, - double_floor_div_output, - 1, - trt.ElementWiseOperation.SUB, - target, - f"{name}_sign", - ) - - -def trunc_div( - input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str -) -> TRTTensor: - """ - Perform trunc divide on Tensor, result of divide will be round toward zero. - This means for positive number, it will be floor round; for negative number, - it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. - - Args: - input: divisor. - other: dividend. - network: INetworkDefinition. - target: node target. - name: namespace for the op - - Returns: - A TensorRT tensor represent the result of trunc divide. 
- """ - prod_output = add_binary_elementwise_layer( - network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" - ) - sign_output = sign(network, prod_output, target, name) - - # Convert constant input into ITensor for UnaryOperation - if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): - other = get_trt_tensor( - network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) - ) - - abs_input_output = add_unary_layer( - network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" - ) - abs_other_output = add_unary_layer( - network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" - ) - abs_floor_output = add_binary_elementwise_layer( - network, - abs_input_output, - abs_other_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_floor_div", - ) - output = add_binary_elementwise_layer( - network, - abs_floor_output, - sign_output, - trt.ElementWiseOperation.PROD, - target, - f"{name}_output", - ) - - return output - - def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index ffca964633..e24e6509ca 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -162,6 +162,65 @@ def add_binary_elementwise_layer( return output +def sign( + network: TRTNetwork, input_val: TRTTensor, target: Target, name: str +) -> TRTTensor: + """ + Sign is calculated as below: + x = input + sign = (exp(x) // exp(abs(x))) * 2 - 1 + For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. + With multiply 2, the value become 2(for pos and 0) and 0(for neg). + Finally minus 1, the value become 1(for pos and 0) and -1(for neg). 
+ + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): The input tensor. + target (Target): fx node target. + name (str): Name of the fx node with optional suffix. + + Returns: + A TensorRT tensor represent the result of sign operator. + """ + input_exp_output = add_unary_layer( + network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp" + ) + input_abs_output = add_unary_layer( + network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs" + ) + input_abs_exp_output = add_unary_layer( + network, + input_abs_output, + trt.UnaryOperation.EXP, + target, + f"{name}_prod_abs_exp", + ) + floor_div_output = add_binary_elementwise_layer( + network, + input_exp_output, + input_abs_exp_output, + trt.ElementWiseOperation.FLOOR_DIV, + target, + f"{name}_exp_floor_div", + ) + double_floor_div_output = add_binary_elementwise_layer( + network, + floor_div_output, + 2, + trt.ElementWiseOperation.PROD, + target, + f"{name}_floor_div*2", + ) + return add_binary_elementwise_layer( + network, + double_floor_div_output, + 1, + trt.ElementWiseOperation.SUB, + target, + f"{name}_sign", + ) + + def trunc_div( input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str ) -> TRTTensor: @@ -701,7 +760,7 @@ def add_pow(network, target, kwargs, name): return add_binary_elementwise_layer( network, kwargs["input"], - kwargs["other"], + kwargs["exponent"], trt.ElementWiseOperation.POW, target, name, From 36f9a3f61573ad4c03b1c04b7d12da51fe501111 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 21 Mar 2023 16:15:05 -0700 Subject: [PATCH 09/39] Removing the aten.slice and add_slice --- .../fx/converters/aten_ops_converters.py | 18 ------- py/torch_tensorrt/fx/converters/operator.py | 49 ------------------- 2 files changed, 67 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 4399f37a44..0711a94f00 100644 --- 
a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -481,24 +481,6 @@ def aten_ops_sym_size( return slice_layer.get_output(0) -@tensorrt_converter(torch.ops.aten.slice.Tensor) -def aten_ops_slice( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - "dim": args[1], - "start": args[2], - "stop": args[3], - "step": args[4], - } - return add_slice(network, target.kwargs_new, name) - - @tensorrt_converter(torch.ops.aten.leaky_relu.default) def aten_ops_leaky_relu( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index e24e6509ca..255e8a937b 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1016,52 +1016,3 @@ def add_expand(network, target, kwargs, name): layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name) return layer.get_output(0) - - -def add_slice(network, target, kwargs, name): - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input_val} that is not part " - "of the TensorRT region!" - ) - - ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" - ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
- - start_int = cast(int, kwargs["start"]) - stop_int = cast(int, kwargs["stop"]) - step_int = cast(int, kwargs["step"]) - start = [0] * len(input_val.shape) - start[dim] = start_int - stride = [1] * len(start) - stride[dim] = step_int - output_shape = list(input_val.shape) - output_shape[dim] = (stop_int - start_int) // step_int - - if dynamic_shape > 0: - output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name - ) - layer = network.add_slice( - input_val, - start=start, - shape=[] if dynamic_shape else output_shape, - stride=stride, - ) - if dynamic_shape: - layer.set_input(2, output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) From 979ab42c0982b37fe7594e2df6d389501b6a2d18 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 12:13:36 -0700 Subject: [PATCH 10/39] correcting selu, hard_tanh ops, adding tests for sigmoid, selu, elu and hard_tanh --- py/torch_tensorrt/fx/converters/activation.py | 21 ++------ .../fx/converters/aten_ops_converters.py | 31 +++++------ .../test/converters/aten_op/test_elu_aten.py | 51 ++++++++++++++++++ .../converters/aten_op/test_hard_tanh_aten.py | 53 +++++++++++++++++++ .../aten_op/test_leaky_relu_aten.py | 2 +- .../test/converters/aten_op/test_selu_aten.py | 51 ++++++++++++++++++ .../converters/aten_op/test_sigmoid_aten.py | 53 +++++++++++++++++++ 7 files changed, 227 insertions(+), 35 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py index 118f6d3105..e27da49f1d 100644 --- a/py/torch_tensorrt/fx/converters/activation.py +++ 
b/py/torch_tensorrt/fx/converters/activation.py @@ -184,25 +184,14 @@ def add_sigmoid(network, target, kwargs, name): def add_hard_tanh(network, target, kwargs, name): input_val = kwargs["input"] - operation_type = trt.ActivationType.TANH - return add_activation_layer(network, input_val, operation_type, target, name) - - -def add_sigmoid(network, target, kwargs, name): - input_val = kwargs["input"] - + alpha = kwargs["min_val"] + beta = kwargs["max_val"] if not isinstance(input_val, TRTTensor): raise RuntimeError( - f"Hard sigmoid received input {input_val} that is not part " + f"hardtanh received input {input_val} that is not part " "of the TensorRT region!" ) - + operation_type = trt.ActivationType.CLIP return add_activation_layer( - network, - input_val, - trt.ActivationType.HARD_SIGMOID, - target, - name, - alpha=1 / 6, - beta=0.5, + network, input_val, operation_type, target, name, alpha, beta ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 0711a94f00..3f55ab3827 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -503,26 +503,19 @@ def aten_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } + if len(args) > 2: + kwargs_new = { + "input": args[0], + } + return add_selu(network, target, kwargs_new, name) + else: + kwargs_new = { + "input": args[0], + "alpha": args[1], + } return add_elu(network, target, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.selu.default) -def aten_ops_selu( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } - return add_selu(network, target, kwargs_new, name) - - @tensorrt_converter(torch.ops.aten.gelu.default) def aten_ops_gelu( network: 
TRTNetwork, @@ -551,7 +544,7 @@ def aten_ops_tanh( return add_tanh(network, target, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.sigmoid.default) +@tensorrt_converter(torch.ops.aten.hardtanh.default) def aten_ops_hard_tanh( network: TRTNetwork, target: Target, @@ -561,6 +554,8 @@ def aten_ops_hard_tanh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: kwargs_new = { "input": args[0], + "min_val": args[1], + "max_val": args[2], } return add_hard_tanh(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py new file mode 100644 index 0000000000..cd8ef1b48a --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestELUConverter(DispatchTestCase): + def test_elu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_elu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_elu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, 
expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py new file mode 100644 index 0000000000..644c09345f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestHardTanHConverter(DispatchTestCase): + def test_hardtanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py index e694acf05a..9c26e441ff 100644 --- 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py @@ -4,7 +4,7 @@ from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec -class TestReLUConverter(DispatchTestCase): +class TestLeakyReLUConverter(DispatchTestCase): def test_leaky_relu(self): class TestModule(nn.Module): def forward(self, x): diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py new file mode 100644 index 0000000000..a6e501daa0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSeLUConverter(DispatchTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_selu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + 
run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py new file mode 100644 index 0000000000..d9557d271b --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSigmoidConverter(DispatchTestCase): + def test_sigmoid(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + +if __name__ == "__main__": + run_tests() From 79286e7390f7c91ccbcc63723f6ccc06f860eb1a Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 13:25:25 -0700 Subject: [PATCH 11/39] Moving funcs to_numpy and trt_dtype_to_torch_dtype from converter_util to operator --- .../fx/converters/converter_utils.py | 34 ------------------- py/torch_tensorrt/fx/converters/operator.py | 34 
+++++++++++++++++++ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 062d4dde5b..ba6421701e 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -120,30 +120,6 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int): return extend_attr_to_tuple(val, size) -def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: - """ - Convert a PyTorch Tensor to a Numpy Array. If the tensor is - quantized it will be dequantized first. - - Args: - tensor (Optional[torch.Tensor]): A PyTorch tensor or None. - - Returns: - A Numpy array. - """ - - if tensor is None: - return tensor - - assert isinstance( - tensor, torch.Tensor - ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" - if tensor.is_quantized: - tensor = tensor.dequantize() - - return tensor.cpu().detach().contiguous().numpy() - - def has_dynamic_shape(shape: Shape) -> bool: """ Determine if the given shape has dynamic dim. i.e. if there're -1 in shape. @@ -567,13 +543,3 @@ def type_cast( layer_i.set_output_type(0, cast_type) set_layer_name(layer_i, target, f"{name}_dtype_change") return layer_i.get_output(0) - - -def trt_dtype_to_torch_dtype(trt_dtype): - table = { - trt.bool: torch.bool, - trt.int32: torch.int32, - trt.float16: torch.float16, - trt.float32: torch.float32, - } - return table[trt_dtype] diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 255e8a937b..539b766336 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -278,6 +278,40 @@ def trunc_div( return output +def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: + """ + Convert a PyTorch Tensor to a Numpy Array. If the tensor is + quantized it will be dequantized first. 
+ + Args: + tensor (Optional[torch.Tensor]): A PyTorch tensor or None. + + Returns: + A Numpy array. + """ + + if tensor is None: + return tensor + + assert isinstance( + tensor, torch.Tensor + ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" + if tensor.is_quantized: + tensor = tensor.dequantize() + + return tensor.cpu().detach().contiguous().numpy() + + +def trt_dtype_to_torch_dtype(trt_dtype): + table = { + trt.bool: torch.bool, + trt.int32: torch.int32, + trt.float16: torch.float16, + trt.float32: torch.float32, + } + return table[trt_dtype] + + def add_tile(network, target, kwargs, name): input_t = kwargs["input"] input_val = get_trt_tensor(network, input_t, f"{name}_input") From 2bcf5f4cb08f8cf51f6f5f39c608d2ffac77a919 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 14:12:53 -0700 Subject: [PATCH 12/39] Move to_numpy implementation to converter_util --- .../fx/converters/converter_utils.py | 24 ++++++++++++++++++ py/torch_tensorrt/fx/converters/operator.py | 25 +------------------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index ba6421701e..0b9561b6db 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -543,3 +543,27 @@ def type_cast( layer_i.set_output_type(0, cast_type) set_layer_name(layer_i, target, f"{name}_dtype_change") return layer_i.get_output(0) + + +def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: + """ + Convert a PyTorch Tensor to a Numpy Array. If the tensor is + quantized it will be dequantized first. + + Args: + tensor (Optional[torch.Tensor]): A PyTorch tensor or None. + + Returns: + A Numpy array. 
+ """ + + if tensor is None: + return tensor + + assert isinstance( + tensor, torch.Tensor + ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" + if tensor.is_quantized: + tensor = tensor.dequantize() + + return tensor.cpu().detach().contiguous().numpy() diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 539b766336..140b293223 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -22,6 +22,7 @@ from .converter_utils import prepend_ones from .converter_utils import has_dynamic_shape from .converter_utils import get_shape_with_dynamic_shape +from .converter_utils import to_numpy from ..types import ( Shape, @@ -278,30 +279,6 @@ def trunc_div( return output -def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: - """ - Convert a PyTorch Tensor to a Numpy Array. If the tensor is - quantized it will be dequantized first. - - Args: - tensor (Optional[torch.Tensor]): A PyTorch tensor or None. - - Returns: - A Numpy array. 
- """ - - if tensor is None: - return tensor - - assert isinstance( - tensor, torch.Tensor - ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" - if tensor.is_quantized: - tensor = tensor.dequantize() - - return tensor.cpu().detach().contiguous().numpy() - - def trt_dtype_to_torch_dtype(trt_dtype): table = { trt.bool: torch.bool, From 96bdead71c60fe08b51812c959fa8566f5d52c24 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 15:18:21 -0700 Subject: [PATCH 13/39] Implementation of slice and select operations --- .../fx/converters/aten_ops_converters.py | 16 +++++ .../fx/converters/converter_utils.py | 24 +++++++ py/torch_tensorrt/fx/converters/operator.py | 66 ++++++++++++------- .../converters/aten_op/test_select_aten.py | 56 ++++++++++++++++ 4 files changed, 138 insertions(+), 24 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 3f55ab3827..275887ae57 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -572,3 +572,19 @@ def aten_ops_sigmoid( "input": args[0], } return add_sigmoid(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.select) +def aten_ops_select( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "dim": args[1], + "index": args[2], + } + return add_select(network, target.kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index ba6421701e..0b9561b6db 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -543,3 +543,27 @@ def type_cast( 
layer_i.set_output_type(0, cast_type) set_layer_name(layer_i, target, f"{name}_dtype_change") return layer_i.get_output(0) + + +def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: + """ + Convert a PyTorch Tensor to a Numpy Array. If the tensor is + quantized it will be dequantized first. + + Args: + tensor (Optional[torch.Tensor]): A PyTorch tensor or None. + + Returns: + A Numpy array. + """ + + if tensor is None: + return tensor + + assert isinstance( + tensor, torch.Tensor + ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" + if tensor.is_quantized: + tensor = tensor.dequantize() + + return tensor.cpu().detach().contiguous().numpy() diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 539b766336..de9ac6bc4e 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -22,6 +22,7 @@ from .converter_utils import prepend_ones from .converter_utils import has_dynamic_shape from .converter_utils import get_shape_with_dynamic_shape +from .converter_utils import to_numpy from ..types import ( Shape, @@ -278,30 +279,6 @@ def trunc_div( return output -def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: - """ - Convert a PyTorch Tensor to a Numpy Array. If the tensor is - quantized it will be dequantized first. - - Args: - tensor (Optional[torch.Tensor]): A PyTorch tensor or None. - - Returns: - A Numpy array. 
- """ - - if tensor is None: - return tensor - - assert isinstance( - tensor, torch.Tensor - ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" - if tensor.is_quantized: - tensor = tensor.dequantize() - - return tensor.cpu().detach().contiguous().numpy() - - def trt_dtype_to_torch_dtype(trt_dtype): table = { trt.bool: torch.bool, @@ -1050,3 +1027,44 @@ def add_expand(network, target, kwargs, name): layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name) return layer.get_output(0) + + +def add_select(network, target, kwargs, name): + input_val = kwargs["input"] + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert ( + input_val.shape[dim] != -1 + ), "Can't select on negative shape dimension!" + index = kwargs[2] + if index >= input_val.shape[dim]: + raise RuntimeError( + f"cannot have index greater than the dimension length! 
{input_val.shape[dim]}" + ) + output_shape = list(input_val.shape) + output_shape[dim] = 1 + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, output_shape, input_val, target, name + ) + layer = network.add_gather(input_val, dim, index) + out = layer.getOutput(0) + if len(out.shape) != 1: + layer = network.add_shuffle(out) + return layer.getOutput(0) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py new file mode 100644 index 0000000000..8868db2668 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -0,0 +1,56 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverter(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 2, 1), + ] + ) + def test_select(self, _, dim_test, index_test): + class TestModule(torch.nn.Module): + def __init__(self, dim, index): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, input): + return torch.select(input, self.dim, self.index) + + input = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(dim_test, index_test), + input, + expected_ops={torch.ops.aten.select}, + test_explicit_precision=True, + ) + + # def test_select_with_dynamic_shape(self, _, dim_test, index_test): + # class TestModule(torch.nn.Module): + # def __init__(self, dim, index): + # super().__init__() + # self.dim = dim + # self.index = index + # def forward(self, input): + # return torch.select(input, self.dim, self.index) + + # input_spec = [ + # InputTensorSpec( + # shape=(-1, 3, 32), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))], + # ), + # ] + # 
self.run_test_with_dynamic_shape( + # TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select} + # ) + + +if __name__ == "__main__": + run_tests() From a1d94c1c7a9e4be7769d605da480b18b60d44be7 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 16:51:19 -0700 Subject: [PATCH 14/39] Fixing the acc tests, logical_and operator and the leaky_relu test --- .../fx/converters/acc_ops_converters.py | 11 +- .../fx/converters/aten_ops_converters.py | 4 +- .../fx/converters/converter_utils.py | 73 ----------- py/torch_tensorrt/fx/converters/operator.py | 115 ++++++++++++++++-- .../aten_op/test_leaky_relu_aten.py | 6 +- 5 files changed, 115 insertions(+), 94 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 0e77a1c659..e556e81bb5 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -1262,12 +1262,7 @@ def acc_ops_logical_not( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.UnaryOperation.NOT - # cast to bool type - if input_val.dtype in (trt.float32, trt.float16, trt.int32): - input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool) - return add_unary_layer(network, input_val, operation_type, target, name) + return add_logical_not(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True) @@ -2335,7 +2330,7 @@ def acc_ops_getitem( input_val = kwargs["input"] slices = kwargs["idx"] if not isinstance(input_val, TRTTensor): - return getitem(input_val, slices) # type: ignore[arg-type] + return operator.getitem(input_val, slices) # type: ignore[arg-type] if not isinstance(slices, tuple) and not isinstance(slices, list): slices = (slices,) @@ -2803,7 +2798,7 @@ def acc_ops_hardtanh( kwargs: Dict[str, Argument], name: str, ) -> 
Union[TRTTensor, Sequence[TRTTensor]]: - return add_hardtanh(network, target, kwargs, name) + return add_hard_tanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.interpolate) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 3f55ab3827..42aaded2d3 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -489,9 +489,7 @@ def aten_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - kwargs_new = { - "input": args[0], - } + kwargs_new = {"input": args[0], "negative_slope": args[1]} return add_leaky_relu(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 0b9561b6db..551c18652d 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -328,63 +328,6 @@ def broadcast( return a, b -def get_shape_with_dynamic_shape( - network: TRTNetwork, - shape: Union[list, tuple, torch.Tensor], - input_val: TRTTensor, - target: Target, - name: str, -) -> TRTTensor: - """ - Prepare the real output tensor shape for dynamic shape mode tensor input. - How this functions works: - Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation - output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual - reduce operation output shape. Steps of calculations are: - 1. get the actual tensor shape of input_val via add_shape layer; - 2. create a all 0 tensor [0, 0, 0]; - 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; - 4. 
use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace - all -1 dynamic shape dimensions with actual batch_size value; - 5. output shape with actual batch_size as [2048, 128, 256] - - Args: - network (TRTNetwork): TensorRT network object. - shape: calculated shape of the expected output tensor - input_val (TRTTensor): A TensorRT ITensor. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - Returns: - TensorRT ITensors that represents the actual shape of the input_val - """ - # Ger real shape info for input_val - input_shape = network.add_shape(input_val).get_output(0) - - scale_layer = network.add_constant( - input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) - ) - set_layer_name(scale_layer, target, f"{name}_scale") - scale_res = scale_layer.get_output(0) - - length = input_shape.shape[0] - zero_layer = network.add_constant( - input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) - ) - set_layer_name(zero_layer, target, f"{name}_zeros") - - condition_val = operator.add_binary_elementwise_layer( - network, - scale_res, - zero_layer.get_output(0), - trt.ElementWiseOperation.LESS, - target, - f"{name}_shape", - ) - select_layer = network.add_select(condition_val, input_shape, scale_res) - set_layer_name(select_layer, target, f"{name}_select") - return select_layer.get_output(0) - - def squeeze_left(const: torch.Tensor): """ Squeeze the size-1 dimensions on the left side of the shape tuple. 
@@ -529,22 +472,6 @@ def dtype_uniform( return input, other -def type_cast( - network: TRTNetwork, - target: Target, - name: str, - input: TRTTensor, - cast_type: TRTDataType, -): - """ - This function helps to cast the input type to cast_type - """ - layer_i = network.add_identity(input) - layer_i.set_output_type(0, cast_type) - set_layer_name(layer_i, target, f"{name}_dtype_change") - return layer_i.get_output(0) - - def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: """ Convert a PyTorch Tensor to a Numpy Array. If the tensor is diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 140b293223..ccee6b1c42 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -21,7 +21,6 @@ from .converter_utils import get_positive_dim from .converter_utils import prepend_ones from .converter_utils import has_dynamic_shape -from .converter_utils import get_shape_with_dynamic_shape from .converter_utils import to_numpy from ..types import ( @@ -289,6 +288,79 @@ def trt_dtype_to_torch_dtype(trt_dtype): return table[trt_dtype] +def get_shape_with_dynamic_shape( + network: TRTNetwork, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, + target: Target, + name: str, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. 
use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + Returns: + TensorRT ITensors that represents the actual shape of the input_val + """ + # Ger real shape info for input_val + input_shape = network.add_shape(input_val).get_output(0) + + scale_layer = network.add_constant( + input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + ) + set_layer_name(scale_layer, target, f"{name}_scale") + scale_res = scale_layer.get_output(0) + + length = input_shape.shape[0] + zero_layer = network.add_constant( + input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + ) + set_layer_name(zero_layer, target, f"{name}_zeros") + + condition_val = add_binary_elementwise_layer( + network, + scale_res, + zero_layer.get_output(0), + trt.ElementWiseOperation.LESS, + target, + f"{name}_shape", + ) + select_layer = network.add_select(condition_val, input_shape, scale_res) + set_layer_name(select_layer, target, f"{name}_select") + return select_layer.get_output(0) + + +def type_cast( + network: TRTNetwork, + target: Target, + name: str, + input: TRTTensor, + cast_type: TRTDataType, +): + """ + This function helps to cast the input type to cast_type + """ + layer_i = network.add_identity(input) + layer_i.set_output_type(0, cast_type) + set_layer_name(layer_i, target, f"{name}_dtype_change") + return layer_i.get_output(0) + + def add_tile(network, target, kwargs, name): input_t = kwargs["input"] input_val = get_trt_tensor(network, input_t, f"{name}_input") @@ -822,25 +894,54 @@ def add_minimum(network, target, kwargs, 
name): ) +def add_logical_not(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.NOT + # cast to bool type + if input_val.dtype in (trt.float32, trt.float16, trt.int32): + input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool) + return add_unary_layer(network, input_val, operation_type, target, name) + + def add_logical_and(network, target, kwargs, name): if network.has_implicit_batch_dimension: raise RuntimeError( - "The `ne` function should be called with explicit batch dimension." + "The `logical_and` function should be called with explicit batch dimension." ) input_t = kwargs["input"] other_t = kwargs["other"] + # we only support both inputs are bool type + if target == acc_ops.bitwise_and: + + def check_is_bool(input_t): + if isinstance(input_t, TRTTensor): + assert ( + input_t.dtype == trt.bool + ), "We currently do not support input is non-bool" + elif isinstance(input_t, torch.Tensor): + assert ( + input_t.dtype == torch.bool + ), "We currently do not support input is non-bool" + else: + assert isinstance( + input_t.bool + ), "We currently do not support input is non-bool" + + check_is_bool(input_t) + check_is_bool(other_t) input_t = get_trt_tensor(network, input_t, f"{name}_input_t") other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - eq_t = add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + if input_t.dtype != trt.bool: + input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool) + if other_t.dtype != trt.bool: + other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool) + return add_binary_elementwise_layer( + network, input_t, other_t, trt.ElementWiseOperation.AND, target, name ) - return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) - def add_ne(network, target, kwargs, name): if 
network.has_implicit_batch_dimension: diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py index 9c26e441ff..7cdce77092 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py @@ -8,7 +8,7 @@ class TestLeakyReLUConverter(DispatchTestCase): def test_leaky_relu(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.leaky_relu(x) + return nn.functional.leaky_relu(x, negative_slope=0.05) inputs = [torch.randn(1, 10)] self.run_test( @@ -18,7 +18,7 @@ def forward(self, x): def test_leaky_relu_with_dynamic_shape(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.leaky_relu(x) + return nn.functional.leaky_relu(x, negative_slope=0.05) input_specs = [ InputTensorSpec( @@ -34,7 +34,7 @@ def forward(self, x): def test_leaky_relu_with_dynamic_shape_four_dimensions(self): class TestModule(nn.Module): def forward(self, x): - return nn.functional.leaky_relu(x) + return nn.functional.leaky_relu(x, negative_slope=0.05) input_specs = [ InputTensorSpec( From c8811cd1788cf220fc4d809e34b56cc28c71a19f Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 17:52:02 -0700 Subject: [PATCH 15/39] select test implementation --- .../fx/converters/aten_ops_converters.py | 4 ++-- py/torch_tensorrt/fx/converters/operator.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 042298a471..228194fbe9 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -572,7 +572,7 @@ def aten_ops_sigmoid( return add_sigmoid(network, target, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.select) 
+@tensorrt_converter(torch.ops.aten.select.int) def aten_ops_select( network: TRTNetwork, target: Target, @@ -585,4 +585,4 @@ def aten_ops_select( "dim": args[1], "index": args[2], } - return add_select(network, target.kwargs_new, name) + return add_select(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index bdda8338d1..cf5ffe0349 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1153,7 +1153,8 @@ def add_select(network, target, kwargs, name): assert ( input_val.shape[dim] != -1 ), "Can't select on negative shape dimension!" - index = kwargs[2] + index = kwargs["index"] + if index >= input_val.shape[dim]: raise RuntimeError( f"cannot have index greater than the dimension length! {input_val.shape[dim]}" @@ -1164,8 +1165,11 @@ def add_select(network, target, kwargs, name): output_shape = get_shape_with_dynamic_shape( network, output_shape, input_val, target, name ) - layer = network.add_gather(input_val, dim, index) - out = layer.getOutput(0) + input_shape = network.add_shape(input_val).get_output(0) + dim_value = torch.tensor(dim, dtype=torch.int32) + axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) + layer = network.add_gather(input_shape, axis, index) + out = layer.get_output(0) if len(out.shape) != 1: layer = network.add_shuffle(out) - return layer.getOutput(0) + return layer.get_output(0) From 4f18c0f800e2580c97a988c78d9c747b16fe850f Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 17:52:40 -0700 Subject: [PATCH 16/39] select aten test --- .../fx/test/converters/aten_op/test_select_aten.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py index 8868db2668..e21ab0dd61 100644 --- 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -13,21 +13,19 @@ class TestSelectConverter(DispatchTestCase): ("select_dim_index", 2, 1), ] ) - def test_select(self, _, dim_test, index_test): + def test_select(self, _, dim, index): class TestModule(torch.nn.Module): - def __init__(self, dim, index): + def __init__(self): super().__init__() - self.dim = dim - self.index = index def forward(self, input): - return torch.select(input, self.dim, self.index) + return torch.select(input, dim, index) input = [torch.randn(1, 3, 32)] self.run_test( - TestModule(dim_test, index_test), + TestModule(), input, - expected_ops={torch.ops.aten.select}, + expected_ops={torch.ops.aten.select.int}, test_explicit_precision=True, ) From cf96dec63bc81cf931aef2b95767f6ce5ddbbc1a Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 22 Mar 2023 17:59:20 -0700 Subject: [PATCH 17/39] Adding add_slice function in operator.py --- py/torch_tensorrt/fx/converters/operator.py | 48 +++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index ccee6b1c42..bec0f6f0fc 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1128,3 +1128,51 @@ def add_expand(network, target, kwargs, name): layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name) return layer.get_output(0) + + +def add_slice(network, target, kwargs, name): + input_val = kwargs["input"] + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + + start_int = cast(int, kwargs["start"]) + stop_int = cast(int, kwargs["stop"]) + step_int = cast(int, kwargs["step"]) + start = [0] * len(input_val.shape) + start[dim] = start_int + stride = [1] * len(start) + stride[dim] = step_int + output_shape = list(input_val.shape) + output_shape[dim] = (stop_int - start_int) // step_int + + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, output_shape, input_val, target, name + ) + layer = network.add_slice( + input_val, + start=start, + shape=[] if dynamic_shape else output_shape, + stride=stride, + ) + if dynamic_shape: + layer.set_input(2, output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) From 8303cd55669177b189f1000e769e8d55356836ef Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 24 Mar 2023 16:50:22 -0700 Subject: [PATCH 18/39] aten::matmul, aten::slice, aten::select converters --- .../fx/converters/aten_ops_converters.py | 34 +++++++ py/torch_tensorrt/fx/converters/operator.py | 89 ++++++++++++++++++- .../converters/aten_op/test_matmul_aten.py | 27 ++++++ .../converters/aten_op/test_select_aten.py | 73 ++++++++++----- .../converters/aten_op/test_slice_aten.py | 58 ++++++++++++ 5 files changed, 255 insertions(+), 26 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py diff --git 
a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 228194fbe9..1dbfa14076 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -586,3 +586,37 @@ def aten_ops_select( "index": args[2], } return add_select(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.slice.Tensor) +def aten_ops_slice( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "dim": args[1], + "start": args[2], + "stop": args[3], + "step": args[4], + } + return add_slice(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.matmul) +@tensorrt_converter(torch.ops.aten.mm.default) +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return add_matmul(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index cf5ffe0349..5955e598f5 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1165,11 +1165,92 @@ def add_select(network, target, kwargs, name): output_shape = get_shape_with_dynamic_shape( network, output_shape, input_val, target, name ) - input_shape = network.add_shape(input_val).get_output(0) - dim_value = torch.tensor(dim, dtype=torch.int32) - axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) - layer = network.add_gather(input_shape, axis, index) + index_value = torch.tensor(index, dtype=torch.int32) + indices_tensor = network.add_constant( + index_value.shape, to_numpy(index_value) + ).get_output(0) + 
layer = network.add_gather(input_val, indices_tensor, dim) out = layer.get_output(0) if len(out.shape) != 1: layer = network.add_shuffle(out) return layer.get_output(0) + + +def add_slice(network, target, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + + start_int = cast(int, kwargs["start"]) + stop_int = cast(int, kwargs["stop"]) + step_int = cast(int, kwargs["step"]) + start = [0] * len(input_val.shape) + start[dim] = start_int + stride = [1] * len(start) + stride[dim] = step_int + output_shape = list(input_val.shape) + output_shape[dim] = (stop_int - start_int) // step_int + 1 + + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, output_shape, input_val, target, name + ) + layer = network.add_slice( + input_val, + start=start, + shape=[] if dynamic_shape else output_shape, + stride=stride, + ) + if dynamic_shape: + layer.set_input(2, output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_matmul(network, target, kwargs, name): + input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") + other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") + + for i in [input_val, other_val]: + if not isinstance(i, TRTTensor): + raise RuntimeError( + f"matmul received input {i} that is not part 
of the TensorRT region!" + ) + + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input_val.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other_val.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input_val, other_val = broadcast( + network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff + ) + layer = network.add_matrix_multiply( + input_val, input_matrix_op, other_val, other_matrix_op + ) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py new file mode 100644 index 0000000000..0b0cd8d0b5 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -0,0 +1,27 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestMatMulConverter(DispatchTestCase): + def test_matmul(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + return torch.matmul(x, y) + + inputOne = torch.randn(2, 32) + inputTwo = torch.randn(32, 2) + inputs = [inputOne, inputTwo] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py index e21ab0dd61..1d5cb84f31 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -7,10 +7,10 @@ from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, 
InputTensorSpec -class TestSelectConverter(DispatchTestCase): +class TestSelectConverterOne(DispatchTestCase): @parameterized.expand( [ - ("select_dim_index", 2, 1), + ("select_dim_index", 1, 0), ] ) def test_select(self, _, dim, index): @@ -21,7 +21,7 @@ def __init__(self): def forward(self, input): return torch.select(input, dim, index) - input = [torch.randn(1, 3, 32)] + input = [torch.randn(1, 2)] self.run_test( TestModule(), input, @@ -29,25 +29,54 @@ def forward(self, input): test_explicit_precision=True, ) - # def test_select_with_dynamic_shape(self, _, dim_test, index_test): - # class TestModule(torch.nn.Module): - # def __init__(self, dim, index): - # super().__init__() - # self.dim = dim - # self.index = index - # def forward(self, input): - # return torch.select(input, self.dim, self.index) - - # input_spec = [ - # InputTensorSpec( - # shape=(-1, 3, 32), - # dtype=torch.float32, - # shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))], - # ), - # ] - # self.run_test_with_dynamic_shape( - # TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select} - # ) + +class TestSelectConverterTwo(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(4, 4, 4, 4)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + test_explicit_precision=True, + ) + + +class TestSelectConverterWithDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select_with_dynamic_shape(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input_spec = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + 
shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_spec, expected_ops={torch.ops.aten.select.int} + ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py new file mode 100644 index 0000000000..b018aff73e --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py @@ -0,0 +1,58 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 0, 0, 7, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 2, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + test_explicit_precision=True, + ) + + +if __name__ == "__main__": + run_tests() From f1098f2520f30a6082597c93895c350905c8245d Mon Sep 17 00:00:00 2001 From: gs-olive 
<113141689+gs-olive@users.noreply.github.com> Date: Mon, 20 Mar 2023 14:45:35 -0700 Subject: [PATCH 19/39] feat: Add sample torch.compile backend for tensorrt aten path - Add backend adapted from previous `fx2trt_compiler` provided by Dynamo - Currently, the TRTSplitter needs work to fully support the `aten` path - Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here --- .../fx/tracer/dispatch_tracer/aten_tracer.py | 8 +- .../tensorrt_dynamo_backend.py | 107 ++++++++++++++++++ 2 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index e60c8f8d13..356ddc978e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -130,7 +130,7 @@ def trace(f, args, *rest): @req_torch_version("2.dev") -def opt_trace(f, args, *rest): +def opt_trace(f, args, perform_trace=True, *rest): """ Optimized trace with necessary passes which re-compose some ops or replace some ops These passes should be general and functional purpose @@ -148,7 +148,11 @@ def opt_trace(f, args, *rest): replace_inplace_ops, # remove it once functionalization is enabled ] - fx_module, _ = trace(f, args) + if perform_trace: + fx_module, _ = trace(f, args) + else: + fx_module = f + print(fx_module.graph) for passes in passes_list: pr: PassResult = passes(fx_module) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py new file mode 100644 index 0000000000..bb6e68b0b5 --- /dev/null +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -0,0 +1,107 @@ +import torch +import traceback +import torch._dynamo as td + +from 
torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt +from torch_tensorrt.fx.tools.trt_splitter import ( + TRTSplitter, + TRTSplitterSetting, +) +from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + +MAX_SPLITS_THRESHOLD = 10 + + +def tensorrt_backend(gm, sample_inputs): + # Invoke AOTAutograd to compile model + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(fx2trt_compiler), + ) + + +def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): + model = gm + inputs = example_inputs + + # Perform lowering pass on model + model = aten_tracer.opt_trace(model, inputs, perform_trace=False) + + # Split out unsupported ops --> Needs rewrite/revision for ATEN + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + + splitter.node_support_preview() + split_mod = splitter() + num_piece = 0 + + for name, _ in split_mod.named_children(): + print(f"Graph is split into {name}") + num_pieces += 1 + + # Select threshold above which segmentation is not beneficial and run graph in Torch + if num_pieces > MAX_SPLITS_THRESHOLD: + raise AssertionError( + f"The graph module is split into {num_piece} which is large than the \ + threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." 
+ ) + + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + logger_level=trt.Logger.VERBOSE, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, + ) + + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + + return split_mod + + +@td.register_backend +def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): + try: + trt_compiled = fx2trt(gm, example_inputs) + return trt_compiled + except Exception: + traceback.print_exc() + print( + "FX2TRT conversion failed on the subgraph. See trace above. 
Returning GraphModule forward instead" + ) + return gm.forward From 243bf9bc340e27837a33c3d6fc3c0998381aff0a Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 21 Mar 2023 16:17:51 -0700 Subject: [PATCH 20/39] Add decompositions to aot call --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index bb6e68b0b5..a76162b93b 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -17,6 +17,9 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +from torch._inductor.decomposition import decompositions + +DECOMPOSITIONS = decompositions.copy() MAX_SPLITS_THRESHOLD = 10 @@ -26,6 +29,7 @@ def tensorrt_backend(gm, sample_inputs): gm, sample_inputs, fw_compiler=make_boxed_compiler(fx2trt_compiler), + decompositions=DECOMPOSITIONS, ) From 76fd3c8207bdf017af294f1883863a755045b1a8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:31:22 -0700 Subject: [PATCH 21/39] Mark FX2TRT converter as fake tensor unsupported --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index a76162b93b..20cea4ffd5 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -15,6 +15,8 @@ from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision +from torch._dynamo.backends.common import 
fake_tensor_unsupported + from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler from torch._inductor.decomposition import decompositions @@ -99,6 +101,7 @@ def get_input(self, inputs): @td.register_backend +@fake_tensor_unsupported def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): try: trt_compiled = fx2trt(gm, example_inputs) From 6a8102c14f3c0fa7a200222979888e9d213d0d84 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 28 Mar 2023 18:52:12 -0700 Subject: [PATCH 22/39] Minor naming bugfix --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 20cea4ffd5..55c5e2df33 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -49,7 +49,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): splitter.node_support_preview() split_mod = splitter() - num_piece = 0 + num_pieces = 0 for name, _ in split_mod.named_children(): print(f"Graph is split into {name}") @@ -58,7 +58,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): # Select threshold above which segmentation is not beneficial and run graph in Torch if num_pieces > MAX_SPLITS_THRESHOLD: raise AssertionError( - f"The graph module is split into {num_piece} which is large than the \ + f"The graph module is split into {num_pieces} which is large than the \ threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." 
) From e97ed50eeb17b661cb7da060b5dd24bc32d9bb43 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 7 Apr 2023 11:12:12 -0700 Subject: [PATCH 23/39] Implementing aten::chunk, aten::layer_norm, aten::softmax, aten::where, aten::rsub, aten::rsqrt --- .../fx/converters/acc_ops_converters.py | 220 +------------- .../fx/converters/aten_ops_converters.py | 113 ++++++++ py/torch_tensorrt/fx/converters/operator.py | 269 +++++++++++++++++- .../converters/aten_op/test_chunk_aten.py | 58 ++++ .../aten_op/test_layer_norm_aten.py | 45 +++ .../converters/aten_op/test_rsqrt_aten.py | 0 .../test/converters/aten_op/test_rsub_aten.py | 0 .../converters/aten_op/test_softmax_aten.py | 44 +++ .../converters/aten_op/test_squeeze_aten.py | 67 +++++ .../converters/aten_op/test_where_aten.py | 56 ++++ 10 files changed, 662 insertions(+), 210 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index e556e81bb5..a321bb8dfe 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -678,7 +678,13 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) -def acc_ops_layer_norm(network, target, args, kwargs, name): +def acc_ops_layer_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> 
Union[TRTTensor, Sequence[TRTTensor]]: return add_layer_norm(network, target, kwargs, name) @@ -690,37 +696,7 @@ def acc_ops_softmax( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"softmax received input {input_val} that is not part " - "of the TensorRT region!" - ) - - # Used to get dim when dim is None. Copied from PyTorch softmax implementation. - def get_softmax_dim(ndim: int) -> int: - if ndim == 0 or ndim == 1 or ndim == 3: - ret = 0 - else: - ret = 1 - return ret - - if kwargs["dim"] is None: - dim = get_softmax_dim(input_ranks) - else: - dim = cast(int, kwargs["dim"]) - - dim = get_positive_dim(dim, input_ranks) - if network.has_implicit_batch_dimension: - assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." - dim -= 1 - - layer = network.add_softmax(input_val) - layer.axes = 1 << dim - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_softmax(network, target, kwargs, name) @tensorrt_converter(acc_ops.tile) @@ -956,9 +932,7 @@ def acc_ops_sqrt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.UnaryOperation.SQRT - return add_unary_layer(network, input_val, operation_type, target, name) + return add_sqrt(network, target, kwargs, name) @tensorrt_converter(acc_ops.reciprocal) @@ -1619,40 +1593,7 @@ def acc_ops_squeeze( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"squeeze received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) - # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic - # dim, which is a very rare case. For now we just claim not supporting dim=None. - assert dim is not None, "We don't support dim=None right now for squeeze." - - dim = get_positive_dim( - dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - ) - if network.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 - - assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." - assert ( - len(get_dynamic_dims(input_val.shape)) <= 1 - ), "Currently more than one dynamic dim for input to squeeze is not supported." - - output_shape = [] - for i, s in enumerate(input_val.shape): - if i == dim and s == 1: - continue - output_shape.append(s) - layer = network.add_shuffle(input_val) - layer.reshape_dims = tuple(output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_squeeze(network, target, kwargs, name) @tensorrt_converter(acc_ops.add) @@ -2022,89 +1963,7 @@ def acc_ops_where( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - - condition_t = kwargs["condition"] - x_t = kwargs["x"] - y_t = kwargs["y"] - - if type(x_t) != TRTTensor: - assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" - - if type(y_t) != TRTTensor: - assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" 
- - # get output shape - - x_shape = list(x_t.shape) - y_shape = list(y_t.shape) - condition_shape = list(condition_t.shape) - output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) - - # expand shape - if type(condition_t) != TRTTensor: - assert condition_t.dtype == torch.bool, "condition dtype is not bool" - if condition_shape != output_shape: - condition_t.expand(output_shape) - condition_t = condition_t.to(torch.int32) - condition_const = get_trt_tensor(network, condition_t, f"{name}_condition") - condition_layer = network.add_identity(condition_const) - condition_layer.set_output_type(0, trt.bool) - set_layer_name(condition_layer, target, f"{name}_condition") - condition_val = condition_layer.get_output(0) - else: - assert condition_t.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != output_shape: - condition_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": condition_t, "sizes": output_shape}, - name=f"{name}_expand", - ) - else: - condition_val = condition_t - - if type(x_t) != TRTTensor: - if x_shape != output_shape: - # special case where 1 element in x_t - if len(x_t.shape) == 0: - x_t = x_t.unsqueeze(0) - x_t = x_t.expand(output_shape) - x_val = get_trt_tensor(network, x_t, f"{name}_x") - else: - x_val = x_t - if x_shape != output_shape: - x_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": x_val, "sizes": output_shape}, - name=f"{name}_x_expand", - ) - - if type(y_t) != TRTTensor: - if y_shape != output_shape: - # special case where 1 element in y_t - if len(y_t.shape) == 0: - y_t = y_t.unsqueeze(0) - y_t = y_t.expand(output_shape) - y_val = get_trt_tensor(network, y_t, f"{name}_y") - else: - y_val = y_t - if y_shape != output_shape: - y_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": y_val, "sizes": output_shape}, - name=f"{name}_y_expand", - ) - - select_layer = network.add_select(condition_val, x_val, y_val) - - set_layer_name(select_layer, 
target, f"{name}_select") - - return select_layer.get_output(0) + return add_where(network, target, kwargs, name) @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True) @@ -2721,62 +2580,7 @@ def acc_ops_chunk( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - chunks = cast(int, kwargs["chunks"]) - dim = cast(int, kwargs["dim"]) - input_dim_size = len(input_val.shape) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"chunk received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - input_dim_size += 1 - dim = get_positive_dim(dim, input_dim_size) - assert dim != 0, "Can't chunk on batch dim when it's implicit!" - dim -= 1 - else: - if dynamic_shape: - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" - dim = get_positive_dim(dim, input_dim_size) - - if chunks > input_val.shape[dim]: - warnings.warn( - f"Asked for {chunks} chunks along dimention " - f"{dim} on tensor with size {input_val.shape}, chunks " - f"will default to {input_val.shape[dim]}", - RuntimeWarning, - ) - chunks = input_val.shape[dim] - - start = [0] * len(input_val.shape) - stride = [1] * len(start) - offset = 0 - split_size = (input_val.shape[dim] + chunks - 1) // chunks - - max_offset = input_val.shape[dim] - # add slice layers - output = [] - for i in range(chunks): - shape = list(input_val.shape) - shape[dim] = min(split_size, max_offset - offset) - if dynamic_shape: - shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_{i}" - ) - start[dim] = offset - layer = network.add_slice( - input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride - ) - if dynamic_shape: - layer.set_input(2, shape) - offset += split_size - set_layer_name(layer, target, f"{name}_{i}") - 
output.append(layer.get_output(0)) - return output + return add_chunk(network, target, kwargs, name) @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 1dbfa14076..d47f30a790 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -620,3 +620,116 @@ def aten_ops_matmul( "other": args[1], } return add_matmul(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.layer_norm.default) +def aten_ops_layernorm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "normalized_shape": args[1], + "weight": args[2], + "bias": args[3], + "eps": args[4], + } + return add_layer_norm(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten._softmax.default) +def aten_ops_softmax( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "dim": args[1], + } + return add_softmax(network, target, kwargs_new, name) + + +# FIXME: need to look at case where dim is tuple +@tensorrt_converter(torch.ops.aten.squeeze.dim) +@tensorrt_converter(torch.ops.aten.squeeze.dims) +def aten_ops_squeeze( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "dim": args[1], + } + return add_squeeze(network, target, kwargs_new, name) + + +# FIXME: need to confirm lower basic passes +# @tensorrt_converter(torch.ops.aten.chunk) +# def aten_ops_chunk( +# network: TRTNetwork, +# target: Target, +# args: Tuple[Argument, ...], +# 
kwargs: Dict[str, Argument], +# name: str, +# ) -> Union[TRTTensor, Sequence[TRTTensor]]: +# kwargs_new = { +# "input": args[0], +# "chunks": args[1], +# "dim": args[2], +# } +# return add_chunk(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.where.self) +def aten_ops_where( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "condition": args[0], + "x": args[1], + "y": args[2], + } + return add_where(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.rsub) +def aten_ops_rsub( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + "alpha": args[2], + } + return add_rsub(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.rsqrt) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + } + return add_rsqrt(network, target, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 5955e598f5..8d45278548 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -580,7 +580,7 @@ def layer_norm( set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") # X-E[x] - sub_trt = operator.add_binary_elementwise_layer( + sub_trt = add_binary_elementwise_layer( network, input_val, mean_expected_layer.get_output(0), @@ -594,7 +594,7 @@ def layer_norm( trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), ) pow_tensor.name = f"{name}_power" - pow_var = operator.add_binary_elementwise_layer( + pow_var = 
add_binary_elementwise_layer( network, sub_trt, pow_tensor.get_output(0), @@ -739,6 +739,7 @@ def add_layer_norm(network, target, kwargs, name): _LOGGER.error( "Unable to find layer norm plugin, fall back to TensorRT implementation." ) + args = [] return layer_norm(network, target, args, kwargs, name) layer = network.add_plugin_v2([input_val], plugin) layer.name = name @@ -1254,3 +1255,267 @@ def add_matmul(network, target, kwargs, name): ) set_layer_name(layer, target, name) return layer.get_output(0) + + +def add_softmax(network, target, kwargs, name): + input_val = kwargs["input"] + input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"softmax received input {input_val} that is not part " + "of the TensorRT region!" + ) + + # Used to get dim when dim is None. Copied from PyTorch softmax implementation. + def get_softmax_dim(ndim: int) -> int: + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + if kwargs["dim"] is None: + dim = get_softmax_dim(input_ranks) + else: + dim = cast(int, kwargs["dim"]) + + dim = get_positive_dim(dim, input_ranks) + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." + dim -= 1 + + layer = network.add_softmax(input_val) + layer.axes = 1 << dim + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_squeeze(network, target, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"squeeze received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) + # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic + # dim, which is a very rare case. For now we just claim not supporting dim=None. 
+ assert dim is not None, "We don't support dim=None right now for squeeze." + + dim = get_positive_dim( + dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 + + assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." + + output_shape = [] + for i, s in enumerate(input_val.shape): + if i == dim and s == 1: + continue + output_shape.append(s) + layer = network.add_shuffle(input_val) + layer.reshape_dims = tuple(output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_chunk(network, target, kwargs, name): + input_val = kwargs["input"] + chunks = cast(int, kwargs["chunks"]) + dim = cast(int, kwargs["dim"]) + input_dim_size = len(input_val.shape) # type: ignore[union-attr] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"chunk received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + input_dim_size += 1 + dim = get_positive_dim(dim, input_dim_size) + assert dim != 0, "Can't chunk on batch dim when it's implicit!" + dim -= 1 + else: + if dynamic_shape: + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
+ dim = get_positive_dim(dim, input_dim_size) + + if chunks > input_val.shape[dim]: + warnings.warn( + f"Asked for {chunks} chunks along dimention " + f"{dim} on tensor with size {input_val.shape}, chunks " + f"will default to {input_val.shape[dim]}", + RuntimeWarning, + ) + chunks = input_val.shape[dim] + + start = [0] * len(input_val.shape) + stride = [1] * len(start) + offset = 0 + split_size = (input_val.shape[dim] + chunks - 1) // chunks + + max_offset = input_val.shape[dim] + # add slice layers + output = [] + for i in range(chunks): + shape = list(input_val.shape) + shape[dim] = min(split_size, max_offset - offset) + if dynamic_shape: + shape = get_shape_with_dynamic_shape( + network, shape, input_val, target, f"{name}_{i}" + ) + start[dim] = offset + layer = network.add_slice( + input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride + ) + if dynamic_shape: + layer.set_input(2, shape) + offset += split_size + set_layer_name(layer, target, f"{name}_{i}") + output.append(layer.get_output(0)) + return output + + +def add_where(network, target, kwargs, name): + condition_t = kwargs["condition"] + x_t = kwargs["x"] + y_t = kwargs["y"] + + if type(x_t) != TRTTensor: + assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" + + if type(y_t) != TRTTensor: + assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" 
+ + # get output shape + + x_shape = list(x_t.shape) + y_shape = list(y_t.shape) + condition_shape = list(condition_t.shape) + output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) + + # expand shape + if type(condition_t) != TRTTensor: + assert condition_t.dtype == torch.bool, "condition dtype is not bool" + if condition_shape != output_shape: + condition_t.expand(output_shape) + condition_t = condition_t.to(torch.int32) + condition_const = get_trt_tensor(network, condition_t, f"{name}_condition") + condition_layer = network.add_identity(condition_const) + condition_layer.set_output_type(0, trt.bool) + set_layer_name(condition_layer, target, f"{name}_condition") + condition_val = condition_layer.get_output(0) + else: + assert condition_t.dtype == trt.bool, "mask dtype is not bool!" + if condition_shape != output_shape: + condition_val = add_expand( + network, + target, + None, + {"input": condition_t, "sizes": output_shape}, + name=f"{name}_expand", + ) + else: + condition_val = condition_t + + if type(x_t) != TRTTensor: + if x_shape != output_shape: + # special case where 1 element in x_t + if len(x_t.shape) == 0: + x_t = x_t.unsqueeze(0) + x_t = x_t.expand(output_shape) + x_val = get_trt_tensor(network, x_t, f"{name}_x") + else: + x_val = x_t + if x_shape != output_shape: + x_val = add_expand( + network, + target, + None, + {"input": x_val, "sizes": output_shape}, + name=f"{name}_x_expand", + ) + + if type(y_t) != TRTTensor: + if y_shape != output_shape: + # special case where 1 element in y_t + if len(y_t.shape) == 0: + y_t = y_t.unsqueeze(0) + y_t = y_t.expand(output_shape) + y_val = get_trt_tensor(network, y_t, f"{name}_y") + else: + y_val = y_t + if y_shape != output_shape: + y_val = add_expand( + network, + target, + None, + {"input": y_val, "sizes": output_shape}, + name=f"{name}_y_expand", + ) + + select_layer = network.add_select(condition_val, x_val, y_val) + + set_layer_name(select_layer, target, f"{name}_select") + + return 
select_layer.get_output(0) + + +def add_scale(network, target, kwargs, name): + other = kwargs["other"] + scale = kwargs["scale"] + if isinstance(other, TRTTensor): + other_dtype = torch_dtype_from_trt(other.dtype) + is_other_trt_tensor = True + + if not is_other_trt_tensor: + warnings.warn( + f"The value to be scaled is constant" + "In this case, please consider constant fold the model first." + ) + return other * scale + layer = network.add_scale(other, trt.ScaleMode.UNIFORM, 0, scale, 1) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_rsub(network, target, kwargs, name): + scaled_tensor = add_scale(network, target, kwargs, name) + input = kwargs["input"] + return add_binary_elementwise_layer( + network, + kwargs["input"], + scaled_tensor, + trt.ElementWiseOperation.SUB, + target, + name, + ) + + +def add_sqrt(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.SQRT + return add_unary_layer(network, input_val, operation_type, target, name) + + +def add_rsqrt(network, target, kwargs, name): + sqrt_trt = add_sqrt(network, target, kwargs, name) + div_trt = add_binary_elementwise_layer( + network, + 1, + sqrt_trt, + trt.ElementWiseOperation.DIV, + target, + f"{name}_div_trt", + ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py new file mode 100644 index 0000000000..8fae6da293 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py @@ -0,0 +1,58 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_chunk_dim", 6, 0), + ] + ) + def 
test_chunk(self, _, chunk, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.chunk(input, chunk, dim) + return out + + input = [torch.randn(11)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.chunk}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_chunk_dim", 6, 0), + ] + ) + def test_chunk(self, _, chunk, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.chunk(input, chunk, dim) + return out + + input = [torch.randn(12)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.chunk}, + test_explicit_precision=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py new file mode 100644 index 0000000000..cf97e828d0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestLayerNormConverter(DispatchTestCase): + def test_layer_norm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + +def test_layernorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 224, 224), 
+ dtype=torch.float32, + shape_ranges=[(1, 3, 1, 1)], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py new file mode 100644 index 0000000000..31e293fc91 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py @@ -0,0 +1,44 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSoftMaxConverter(DispatchTestCase): + def test_softmax(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(1) + + def forward(self, x): + return self.softmax(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default} + ) + + def test_softmax_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(2) + + def forward(self, x): + return self.softmax(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py new file mode 100644 index 0000000000..5dd15a89e7 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (0), (2, 1)), + ("3d_one_dim", (0), (2, 2, 1)), + # ("3d_two_dim", (0, 1), (2, 2, 1)), + # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)), + ] + ) + def test_squeeze(self, _, dim, init_size): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + inputs = [torch.randn(*init_size)] + expected_op = {} + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + self.run_test( + Squeeze(), + inputs, + expected_ops=expected_op, + ) + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), + ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), + # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), + ] + ) + def test_squeeze(self, _, dim, init_size, shape_range): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + input_specs = [ + InputTensorSpec( + shape=init_size, + dtype=torch.float32, + shape_ranges=shape_range, + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), + input_specs, + expected_ops=expected_op, + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py new file mode 100644 index 0000000000..6c050eee2f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)), + ] + ) + def test_(self, _, condition, x_size, y_size): + class Where(nn.Module): + def forward(self, x): + return torch.where(x, dim) + + inputX = [torch.randn(*x_size)] + inputOther = [torch.randn(*y_size)] + expected_op = {} + self.run_test( + Where(), + inputs, + expected_ops=torch.ops.aten.where.self, + ) + + +# class TestWhereConverter(DispatchTestCase): +# @parameterized.expand( +# [ +# ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), +# ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), +# ] +# ) +# def test_where(self, _, dim, init_size, shape_range): +# class Squeeze(nn.Module): +# def forward(self, x): +# return torch.squeeze(x, dim) + +# input_specs = [ +# InputTensorSpec( +# shape=init_size, +# dtype=torch.float32, +# shape_ranges=shape_range, +# ), +# ] +# self.run_test_with_dynamic_shape( +# Squeeze(), +# input_specs, +# expected_ops=torch.ops.aten.where.self, +# ) From c5a4744867e8637a58042972bfee133372dcfbb1 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 10 Apr 2023 09:13:14 -0700 Subject: [PATCH 24/39] Transformer operator 
changes --- .../fx/converters/converter_utils.py | 33 ++++++++++ py/torch_tensorrt/fx/converters/operator.py | 64 +++++++++++++------ .../fx/passes/lower_basic_pass_aten.py | 1 + .../converters/aten_op/test_rsqrt_aten.py | 29 +++++++++ .../test/converters/aten_op/test_rsub_aten.py | 29 +++++++++ .../converters/aten_op/test_squeeze_aten.py | 4 +- .../converters/aten_op/test_where_aten.py | 57 +++++++++-------- .../tensorrt_dynamo_backend.py | 2 +- 8 files changed, 171 insertions(+), 48 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 551c18652d..9d405767ea 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -288,6 +288,39 @@ def prepend_ones( return layer.get_output(0) +def broadcastable( + a: TRTTensor, + b: TRTTensor, +) -> bool: + "Check if two tensors are broadcastable according to torch rules" + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + print("a shape is", a_shape) + print("b shape is", b_shape) + # check from the trailing + diff = len(a_shape) - len(b_shape) + if diff == 0: + return True + if diff > 0: + max = len(a_shape) + min = len(b_shape) + greater_tensor = a_shape + lesser_tensor = b_shape + elif diff < 0: + max = len(b_shape) + min = len(a_shape) + greater_tensor = b_shape + lesser_tensor = a_shape + j = min - 1 + for i in range(max - 1, diff - 1, -1): + if not ( + greater_tensor[i] != lesser_tensor[j] + and (greater_tensor[i] == 1 or lesser_tensor[i] == 1) + ): + return False + return True + + def broadcast( network: TRTNetwork, a: TRTTensor, diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 8d45278548..4449b7146e 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -15,6 +15,7 @@ from .converter_utils import set_layer_name from .converter_utils import get_trt_tensor 
from .converter_utils import broadcast +from .converter_utils import broadcastable from .converter_utils import squeeze_left from .converter_utils import dtype_uniform from .converter_utils import get_trt_plugin @@ -1119,7 +1120,6 @@ def add_expand(network, target, kwargs, name): # TRT does not support different dimension size assert len(shape) == ranks shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] - inshape = tuple(input_val.shape) shape = tuple(shape) start = tuple([0] * ranks) @@ -1299,27 +1299,36 @@ def add_squeeze(network, target, kwargs, name): f"squeeze received input {input_val} that is not part " "of the TensorRT region!" ) + dims = [] + if "dim" in kwargs: + if isinstance(kwargs["dim"], int): + dims.append(cast(Optional[int], kwargs["dim"])) + else: + for dim in kwargs["dim"]: + dims.append(cast(Optional[int], dim)) - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) + # dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic # dim, which is a very rare case. For now we just claim not supporting dim=None. - assert dim is not None, "We don't support dim=None right now for squeeze." + assert not (len(dims) == 0), "We don't support dim=None right now for squeeze." - dim = get_positive_dim( - dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - ) - if network.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 + for dim in dims: + dim = get_positive_dim( + dim, + len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0), + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 - assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." 
- assert ( - len(get_dynamic_dims(input_val.shape)) <= 1 - ), "Currently more than one dynamic dim for input to squeeze is not supported." + assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." output_shape = [] for i, s in enumerate(input_val.shape): - if i == dim and s == 1: + if (i in dims) and s == 1: continue output_shape.append(s) layer = network.add_shuffle(input_val) @@ -1392,14 +1401,32 @@ def add_where(network, target, kwargs, name): x_t = kwargs["x"] y_t = kwargs["y"] + x_t_dim = len(tuple(x_t.shape)) + y_t_dim = len(tuple(y_t.shape)) + condition_t_dim = len(tuple(condition_t.shape)) + if type(x_t) != TRTTensor: assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" if type(y_t) != TRTTensor: assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" + if not (broadcastable(x_t, y_t)): + assert f"The two torch tensors should be broadcastable" + # get output shape + # purpose of this is to bring x_t and y_t rank same as + # output_shape to input it to the add_expand operation + # condition_t will have dimension of either x_t or y_t + x_t, y_t = broadcast(network, x_t, y_t, f"{name}_x", f"{name}_y") + if len(tuple(condition_t.shape)) != len(tuple(x_t.shape)): + condition_t, x_t = broadcast( + network, condition_t, x_t, f"{name}_condition", f"{name}_x" + ) + print("x_t shape", x_t.shape) + print("y_t shape", y_t.shape) + print("condition_t shape", condition_t.shape) x_shape = list(x_t.shape) y_shape = list(y_t.shape) condition_shape = list(condition_t.shape) @@ -1418,11 +1445,10 @@ def add_where(network, target, kwargs, name): condition_val = condition_layer.get_output(0) else: assert condition_t.dtype == trt.bool, "mask dtype is not bool!" 
- if condition_shape != output_shape: + if condition_shape != condition_t_dim: condition_val = add_expand( network, target, - None, {"input": condition_t, "sizes": output_shape}, name=f"{name}_expand", ) @@ -1430,7 +1456,7 @@ def add_where(network, target, kwargs, name): condition_val = condition_t if type(x_t) != TRTTensor: - if x_shape != output_shape: + if x_shape != x_t_dim: # special case where 1 element in x_t if len(x_t.shape) == 0: x_t = x_t.unsqueeze(0) @@ -1442,7 +1468,6 @@ def add_where(network, target, kwargs, name): x_val = add_expand( network, target, - None, {"input": x_val, "sizes": output_shape}, name=f"{name}_x_expand", ) @@ -1456,11 +1481,10 @@ def add_where(network, target, kwargs, name): y_val = get_trt_tensor(network, y_t, f"{name}_y") else: y_val = y_t - if y_shape != output_shape: + if y_shape != y_t_dim: y_val = add_expand( network, target, - None, {"input": y_val, "sizes": output_shape}, name=f"{name}_y_expand", ) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 00063c3e21..30aeee6944 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -258,6 +258,7 @@ def remove_ops( for n in module.graph.nodes: if n.op == "call_function" and n.target in ( torch.ops.aten._unsafe_view.default, + torch.ops.aten.view.default, ): modified = True node = n diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py index e69de29bb2..da3aa30cb7 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + 
+class TestRSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input, input, alpha) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops=torch.ops.aten.rsqrt, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py index e69de29bb2..9be23fc419 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsub(self, _, x, alpha): + class rsub(nn.Module): + def forward(self, input): + return torch.rsub(input, input, alpha) + + inputs = [torch.randn(x)] + self.run_test( + rsub(), + inputs, + expected_ops=torch.ops.aten.rsub, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py index 5dd15a89e7..5c655422de 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py @@ -10,8 +10,8 @@ class TestSqueezeConverter(DispatchTestCase): [ ("2d_dim", (0), (2, 1)), ("3d_one_dim", (0), (2, 2, 1)), - # ("3d_two_dim", (0, 1), (2, 2, 1)), - # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)), + ("3d_two_dim", (0, 1), (2, 1, 1)), + 
("4d_dim", (0, 1, 2), (2, 2, 1, 1)), ] ) def test_squeeze(self, _, dim, init_size): diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py index 6c050eee2f..0d4849c21f 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py @@ -8,49 +8,56 @@ class TestWhereConverter(DispatchTestCase): @parameterized.expand( [ - ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)), - ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)), - ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)), - ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)), + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), ] ) - def test_(self, _, condition, x_size, y_size): + def test_(self, _, x_size, y_size): class Where(nn.Module): - def forward(self, x): - return torch.where(x, dim) + def forward(self, condition, x, y): + return torch.where(condition, x, y) - inputX = [torch.randn(*x_size)] - inputOther = [torch.randn(*y_size)] - expected_op = {} + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 self.run_test( Where(), - inputs, - expected_ops=torch.ops.aten.where.self, + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, ) +# FIXME: How to specify condition for dynamic shape +# InputTensorSpec like case below where one input is dynamic another is not # class TestWhereConverter(DispatchTestCase): # @parameterized.expand( # [ -# ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), -# ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# ("2d_dim", (-1, 2), [((1, 2), (2, 2), (2, 2))], (2,2)) +# #("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), 
(1, 2, 1), (3, 2, 1))]), # #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), # ] # ) -# def test_where(self, _, dim, init_size, shape_range): -# class Squeeze(nn.Module): -# def forward(self, x): -# return torch.squeeze(x, dim) - -# input_specs = [ -# InputTensorSpec( -# shape=init_size, +# def test_where(self, _, x_size, x_size_range, y_size): +# class Where(nn.Module): +# def forward(self, condition, x, y): +# return torch.where(condition, x, y) +# inputX = InputTensorSpec( +# shape=x_size, # dtype=torch.float32, -# shape_ranges=shape_range, -# ), +# shape_ranges=x_size_range, +# ) +# inputOther = torch.randn(*y_size) +# condition = (inputOther < 0) +# input_specs = [ +# inputX, inputOther, condition # ] # self.run_test_with_dynamic_shape( -# Squeeze(), +# Where(), # input_specs, # expected_ops=torch.ops.aten.where.self, # ) + +# if __name__ == "__main__": +# run_tests() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 55c5e2df33..e53f0bc64e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -22,7 +22,7 @@ from torch._inductor.decomposition import decompositions DECOMPOSITIONS = decompositions.copy() -MAX_SPLITS_THRESHOLD = 10 +MAX_SPLITS_THRESHOLD = 100 def tensorrt_backend(gm, sample_inputs): From 8d4e4b4f89c8e413a4420e4ba468e2e3d3e284ce Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 10 Apr 2023 23:39:59 -0700 Subject: [PATCH 25/39] Fixing acc split test --- py/torch_tensorrt/fx/converters/operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 5955e598f5..53e0d88557 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1207,7 +1207,7 @@ def 
add_slice(network, target, kwargs, name): stride = [1] * len(start) stride[dim] = step_int output_shape = list(input_val.shape) - output_shape[dim] = (stop_int - start_int) // step_int + 1 + output_shape[dim] = (stop_int - start_int) // step_int if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( From 1ab9af5ae5f6842fcd00f7e12d4fe2b308c1fbfe Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 11 Apr 2023 12:42:05 -0700 Subject: [PATCH 26/39] Bug fix for add_slice --- py/torch_tensorrt/fx/converters/operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 53e0d88557..37dd84d84e 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -2,6 +2,7 @@ import operator import warnings import logging +import math from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import tensorrt as trt @@ -1207,7 +1208,7 @@ def add_slice(network, target, kwargs, name): stride = [1] * len(start) stride[dim] = step_int output_shape = list(input_val.shape) - output_shape[dim] = (stop_int - start_int) // step_int + output_shape[dim] = math.ceil((stop_int - start_int) // step_int) if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( From 8de6c9d449518512c7a4f3c12cda06495866f284 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 11 Apr 2023 16:05:49 -0700 Subject: [PATCH 27/39] dynamic test for slice --- .../converters/aten_op/test_slice_aten.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py index b018aff73e..6ddc082657 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py @@ -34,6 +34,7 @@ class 
TestSelectConverterExplicitBatch(DispatchTestCase): @parameterized.expand( [ ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step_exact", 1, 0, 10, 2), ] ) def test_slice(self, _, dim, start, stop, step): @@ -54,5 +55,35 @@ def forward(self, input): ) +class TestSelectConverterDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input_specs = [ + InputTensorSpec( + shape=(1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + if __name__ == "__main__": run_tests() From ab89d2b045eb93f549c5c6de65b22830a47f9386 Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Fri, 14 Apr 2023 12:09:12 -0700 Subject: [PATCH 28/39] Correct the output_shape dimension for add_slice --- py/torch_tensorrt/fx/converters/operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 37dd84d84e..374db3f611 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1208,7 +1208,7 @@ def add_slice(network, target, kwargs, name): stride = [1] * len(start) stride[dim] = step_int output_shape = list(input_val.shape) - output_shape[dim] = math.ceil((stop_int - start_int) // step_int) + output_shape[dim] = math.ceil((stop_int - start_int) / step_int) if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( From 09a52b99082604279c59b66aa5cb5b9d417db71f Mon 
Sep 17 00:00:00 2001 From: apbose Date: Wed, 19 Apr 2023 08:04:31 -0700 Subject: [PATCH 29/39] matmul changes, bmm changes and adding broadcastable --- .../fx/converters/converter_utils.py | 37 +++++- py/torch_tensorrt/fx/converters/operator.py | 13 ++- .../fx/passes/lower_basic_pass_aten.py | 28 ++++- .../converters/aten_op/test_matmul_aten.py | 107 ++++++++++++++++-- 4 files changed, 167 insertions(+), 18 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 551c18652d..8205d32ecb 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -77,7 +77,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int: return dim -def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None: +def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> None: """ Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]" @@ -87,7 +87,7 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None: the node represents. name (str): Consists of fx node.name with optional suffix. 
""" - target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" + target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}" layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]" @@ -288,6 +288,39 @@ def prepend_ones( return layer.get_output(0) +def broadcastable( + a: TRTTensor, + b: TRTTensor, +) -> bool: + "Check if two tensors are broadcastable according to torch rules" + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + print("a shape is", a_shape) + print("b shape is", b_shape) + # check from the trailing + diff = len(a_shape) - len(b_shape) + if diff == 0: + return True + if diff > 0: + max = len(a_shape) + min = len(b_shape) + greater_tensor = a_shape + lesser_tensor = b_shape + elif diff < 0: + max = len(b_shape) + min = len(a_shape) + greater_tensor = b_shape + lesser_tensor = a_shape + j = min - 1 + for i in range(max - 1, diff - 1, -1): + if not ( + greater_tensor[i] != lesser_tensor[j] + and (greater_tensor[i] == 1 or lesser_tensor[i] == 1) + ): + return False + return True + + def broadcast( network: TRTNetwork, a: TRTTensor, diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 374db3f611..674942b1ce 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -16,6 +16,7 @@ from .converter_utils import set_layer_name from .converter_utils import get_trt_tensor from .converter_utils import broadcast +from .converter_utils import broadcastable from .converter_utils import squeeze_left from .converter_utils import dtype_uniform from .converter_utils import get_trt_plugin @@ -1117,7 +1118,17 @@ def add_expand(network, target, kwargs, name): ranks = len(input_val.shape) # TRT does not support different dimension size - assert len(shape) == ranks + #though this condition is not seen in the case of bmm + # where input_t and shape dimensions are not 
equal + assert len(shape) >= ranks + if(len(shape) != ranks): + shape_tuple = tuple([0] * len(shape)) + shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape") + input_val, shape_tensor = broadcast(network, input_val, shape_tensor, + f"{name}_input_val", + f"{name}_shape_val") + ranks = len(shape) + shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] inshape = tuple(input_val.shape) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 00063c3e21..f9b5b20fbf 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -416,29 +416,46 @@ def compose_bmm( node = n input_n = node.all_input_nodes[0] other_n = node.all_input_nodes[1] + + # If no input nodes are available, the bmm argument itself could be an input + # Alternatively, if the node has no users, it can be eliminated + if len(input_n.all_input_nodes) == 0 or len(node.users) == 0: + return PassResult(module, modified) + output = next(iter(node.users)) input_input_n = input_n.all_input_nodes[0] if ( input_input_n.target != torch.ops.aten.expand.default and input_n.target != torch.ops.aten.view.default ): - raise RuntimeError( - "Bmm is addressed in fixed pattern. A new pattern is met!" + _LOGGER.warn( + "Bmm is addressed in fixed pattern. " + + f"A new pattern {input_input_n.target}, {input_n.target} is met! " + + "Skipping bmm lowering on this operation" ) + return PassResult(module, modified) + real_input = input_input_n.all_input_nodes[0] input_other_n = other_n.all_input_nodes[0] if ( input_other_n.target != torch.ops.aten.expand.default and other_n.target != torch.ops.aten.view.default ): - raise RuntimeError( - "Bmm is addressed in fixed pattern. A new pattern is met!" + _LOGGER.warn( + "Bmm is addressed in fixed pattern. " + + f"A new pattern {input_other_n.target}, {other_n.target} is met! 
" + + "Skipping bmm lowering on this operation" ) + return PassResult(module, modified) + real_other = input_other_n.all_input_nodes[0] if len(real_other.meta["val"].size()) == 2: new_func = aten_compose_bmm_2d - if len(real_other.meta["val"].size()) == 3: + elif len(real_other.meta["val"].size()) == 3: new_func = aten_compose_bmm_3d + else: + # No valid bmm replacement exists for the specified dimensions + return PassResult(module, modified) with module.graph.inserting_after(node): new_args = (real_input, real_other) @@ -449,6 +466,7 @@ def compose_bmm( kwargs=None, ) output.replace_all_uses_with(new_node) + modified = True module.graph.eliminate_dead_code() module.recompile() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py index 0b0cd8d0b5..de4911bb08 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -6,22 +6,109 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec +import torch +import torch.nn as nn +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec + class TestMatMulConverter(DispatchTestCase): - def test_matmul(self): - class TestModule(torch.nn.Module): - def forward(self, x, y): - return torch.matmul(x, y) - - inputOne = torch.randn(2, 32) - inputTwo = torch.randn(32, 2) - inputs = [inputOne, inputTwo] + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("2_2", (2, 3), (3, 1)), + #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? 
+ # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + #("4_3", (3,1,3,2), (2,2,3)), + #("3_4", (3,1,3,2), (2,2,3)), + #("3_4", (2, 2, 3), (3, 1, 3, 3)), + #("4_2", (1, 2, 2, 3), (3, 2)), + ] + ) + def test_matmul_other_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.other = nn.Parameter(torch.randn(*other_shape)) + + def forward(self, input): + return torch.matmul(input, self.other) + + inputs = [torch.randn(*input_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + test_explicit_batch_dim=(len(input_shape) >= 1), + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("1_2", (1, 3), (3, 2)), + #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + #("4_3", (3,1,3,2), (2,2,3)), + #("3_4", (3,1,3,2), (2,2,3)), + #("3_4", (2, 2, 3), (3, 1, 3, 3)), + #("4_2", (1, 2, 2, 3), (3, 2)), + + ] + ) + def test_matmul_input_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.input = nn.Parameter(torch.randn(*input_shape)) + + def forward(self, other): + return torch.matmul(self.input, other) + + inputs = [torch.randn(*other_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + test_explicit_batch_dim=True + #test_explicit_batch_dim=(len(other_shape) <= 2), + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + # ("2_3", (2, 3), (2, 3, 4)), + # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)), + # ("4_2", (2, 1, 2, 3), (3, 2)), + # ("2_1", (2, 3), (3,)), + # ("1_2", (3,), (3, 2)), + # ("1_1", (3,), (3,)), + ] + ) + def test_matmul(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def forward(self, input, other): + return torch.matmul(input, other) + + inputs = 
[torch.randn(*input_shape), torch.randn(*other_shape)] + test_explicit_batch_dim = not( + input_shape[0] == other_shape[0] + and len(input_shape) > 2 + and len(other_shape) > 2 + ) self.run_test( - TestModule(), + MatMul(), inputs, expected_ops={torch.ops.aten.mm.default}, + test_explicit_batch_dim=test_explicit_batch_dim, ) + #FIXME: dynamic shape is giving bmm if __name__ == "__main__": - run_tests() + run_tests() \ No newline at end of file From d1fd1d7b30548281f8155abfda872e7e76a8f362 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 19 Apr 2023 08:43:49 -0700 Subject: [PATCH 30/39] Correcting pre-commit hooks --- .../fx/converters/converter_utils.py | 8 +++- py/torch_tensorrt/fx/converters/operator.py | 18 ++++----- .../fx/passes/lower_basic_pass_aten.py | 8 ++-- .../converters/aten_op/test_matmul_aten.py | 38 +++++++++---------- 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 8205d32ecb..2d79014ebd 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -87,7 +87,13 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> N the node represents. name (str): Consists of fx node.name with optional suffix. 
""" - target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}" + target_name = ( + target + if isinstance(target, str) + else f"acc_ops.{target.__name__}" + if is_acc + else f"aten_ops.{target.__name__}" + ) layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]" diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 674942b1ce..9487894506 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1118,17 +1118,17 @@ def add_expand(network, target, kwargs, name): ranks = len(input_val.shape) # TRT does not support different dimension size - #though this condition is not seen in the case of bmm + # though this condition is not seen in the case of bmm # where input_t and shape dimensions are not equal assert len(shape) >= ranks - if(len(shape) != ranks): - shape_tuple = tuple([0] * len(shape)) - shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape") - input_val, shape_tensor = broadcast(network, input_val, shape_tensor, - f"{name}_input_val", - f"{name}_shape_val") - ranks = len(shape) - + if len(shape) != ranks: + shape_tuple = tuple([0] * len(shape)) + shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape") + input_val, shape_tensor = broadcast( + network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val" + ) + ranks = len(shape) + shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] inshape = tuple(input_val.shape) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index f9b5b20fbf..6790962621 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -417,11 +417,11 @@ def compose_bmm( input_n = node.all_input_nodes[0] other_n = node.all_input_nodes[1] - # If no input nodes are 
available, the bmm argument itself could be an input + # If no input nodes are available, the bmm argument itself could be an input # Alternatively, if the node has no users, it can be eliminated if len(input_n.all_input_nodes) == 0 or len(node.users) == 0: return PassResult(module, modified) - + output = next(iter(node.users)) input_input_n = input_n.all_input_nodes[0] if ( @@ -434,7 +434,7 @@ def compose_bmm( + "Skipping bmm lowering on this operation" ) return PassResult(module, modified) - + real_input = input_input_n.all_input_nodes[0] input_other_n = other_n.all_input_nodes[0] if ( @@ -447,7 +447,7 @@ def compose_bmm( + "Skipping bmm lowering on this operation" ) return PassResult(module, modified) - + real_other = input_other_n.all_input_nodes[0] if len(real_other.meta["val"].size()) == 2: new_func = aten_compose_bmm_2d diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py index de4911bb08..e0dc05fded 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py @@ -19,13 +19,13 @@ class TestMatMulConverter(DispatchTestCase): [ ("2_2", (2, 3), (3, 2)), ("2_2", (2, 3), (3, 1)), - #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? 
# (2,3), (3,) torch.ops.aten.mv.default - # Following cases use torch.ops.aten.bmm.defauly - #("4_3", (3,1,3,2), (2,2,3)), - #("3_4", (3,1,3,2), (2,2,3)), - #("3_4", (2, 2, 3), (3, 1, 3, 3)), - #("4_2", (1, 2, 2, 3), (3, 2)), + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), ] ) def test_matmul_other_constant(self, _, input_shape, other_shape): @@ -38,7 +38,7 @@ def forward(self, input): return torch.matmul(input, self.other) inputs = [torch.randn(*input_shape)] - + self.run_test( MatMul(), inputs, @@ -50,14 +50,13 @@ def forward(self, input): [ ("2_2", (2, 3), (3, 2)), ("1_2", (1, 3), (3, 2)), - #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? # (2,3), (3,) torch.ops.aten.mv.default - # Following cases use torch.ops.aten.bmm.defauly - #("4_3", (3,1,3,2), (2,2,3)), - #("3_4", (3,1,3,2), (2,2,3)), - #("3_4", (2, 2, 3), (3, 1, 3, 3)), - #("4_2", (1, 2, 2, 3), (3, 2)), - + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), ] ) def test_matmul_input_constant(self, _, input_shape, other_shape): @@ -75,8 +74,8 @@ def forward(self, other): MatMul(), inputs, expected_ops={torch.ops.aten.mm.default}, - test_explicit_batch_dim=True - #test_explicit_batch_dim=(len(other_shape) <= 2), + test_explicit_batch_dim=True + # test_explicit_batch_dim=(len(other_shape) <= 2), ) @parameterized.expand( @@ -96,7 +95,7 @@ def forward(self, input, other): return torch.matmul(input, other) inputs = [torch.randn(*input_shape), torch.randn(*other_shape)] - test_explicit_batch_dim = not( + test_explicit_batch_dim = not ( input_shape[0] == other_shape[0] and len(input_shape) > 2 and len(other_shape) > 2 @@ 
-108,7 +107,8 @@ def forward(self, input, other): test_explicit_batch_dim=test_explicit_batch_dim, ) - #FIXME: dynamic shape is giving bmm + # FIXME: dynamic shape is giving bmm + if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests() From ce7f122aa47d632120ec7c36e91f359afbe612d8 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 12:33:02 -0700 Subject: [PATCH 31/39] Correcting rsqrt and rsub operator --- .../fx/test/converters/aten_op/test_rsqrt_aten.py | 4 ++-- .../fx/test/converters/aten_op/test_rsub_aten.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py index da3aa30cb7..c80216654c 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase): def test_rsqrt(self, _, x, alpha): class rsqrt(nn.Module): def forward(self, input): - return torch.rsqrt(input, input, alpha) + return torch.rsqrt(input) inputs = [torch.randn(x) + 1] self.run_test( rsqrt(), inputs, - expected_ops=torch.ops.aten.rsqrt, + expected_ops={torch.ops.aten.rsqrt.default}, ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py index 9be23fc419..dddd72f732 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py @@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase): def test_rsub(self, _, x, alpha): class rsub(nn.Module): def forward(self, input): - return torch.rsub(input, input, alpha) + return torch.rsub(input, input, alpha = alpha) inputs = [torch.randn(x)] self.run_test( rsub(), inputs, - expected_ops=torch.ops.aten.rsub, + expected_ops={torch.ops.aten.rsub.Tensor}, ) From 
30c5fd6e654f0ac3a7025c49d60b24cd8f96df40 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 13:08:34 -0700 Subject: [PATCH 32/39] python linting issues and removing chunk test --- .../fx/converters/aten_ops_converters.py | 25 ++------ py/torch_tensorrt/fx/converters/operator.py | 12 +++- .../converters/aten_op/test_chunk_aten.py | 58 ------------------- .../test/converters/aten_op/test_rsub_aten.py | 2 +- 4 files changed, 15 insertions(+), 82 deletions(-) delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index d47f30a790..defa88d18b 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -672,23 +672,6 @@ def aten_ops_squeeze( return add_squeeze(network, target, kwargs_new, name) -# FIXME: need to confirm lower basic passes -# @tensorrt_converter(torch.ops.aten.chunk) -# def aten_ops_chunk( -# network: TRTNetwork, -# target: Target, -# args: Tuple[Argument, ...], -# kwargs: Dict[str, Argument], -# name: str, -# ) -> Union[TRTTensor, Sequence[TRTTensor]]: -# kwargs_new = { -# "input": args[0], -# "chunks": args[1], -# "dim": args[2], -# } -# return add_chunk(network, target, kwargs_new, name) - - @tensorrt_converter(torch.ops.aten.where.self) def aten_ops_where( network: TRTNetwork, @@ -705,7 +688,7 @@ def aten_ops_where( return add_where(network, target, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.rsub) +@tensorrt_converter(torch.ops.aten.rsub.Tensor) def aten_ops_rsub( network: TRTNetwork, target: Target, @@ -713,15 +696,17 @@ def aten_ops_rsub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: + if "alpha" in kwargs: + alpha = kwargs["alpha"] kwargs_new = { "input": args[0], "other": args[1], - "alpha": args[2], + "alpha": alpha, } return add_rsub(network, target, kwargs_new, 
name) -@tensorrt_converter(torch.ops.aten.rsqrt) +@tensorrt_converter(torch.ops.aten.rsqrt.default) def aten_ops_rsqrt( network: TRTNetwork, target: Target, diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 1e53b1ccc5..ffd6a1bab5 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -1526,7 +1526,13 @@ def add_scale(network, target, kwargs, name): def add_rsub(network, target, kwargs, name): - scaled_tensor = add_scale(network, target, kwargs, name) + kwargs_new = {} + if "alpha" in kwargs: + kwargs_new["input"] = kwargs["other"] + kwargs_new["other"] = kwargs["alpha"] + scaled_tensor = add_mul(network, target, kwargs_new, name + "_mul") + else: + scaled_tensor = kwargs["other"] input = kwargs["input"] return add_binary_elementwise_layer( network, @@ -1534,7 +1540,7 @@ def add_rsub(network, target, kwargs, name): scaled_tensor, trt.ElementWiseOperation.SUB, target, - name, + name + "_sub", ) @@ -1546,7 +1552,7 @@ def add_sqrt(network, target, kwargs, name): def add_rsqrt(network, target, kwargs, name): sqrt_trt = add_sqrt(network, target, kwargs, name) - div_trt = add_binary_elementwise_layer( + return add_binary_elementwise_layer( network, 1, sqrt_trt, diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py deleted file mode 100644 index 8fae6da293..0000000000 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest - -import torch -import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from parameterized import param, parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec - - -class TestSelectConverterImplicitBatch(DispatchTestCase): - @parameterized.expand( - [ - ("select_chunk_dim", 6, 0), - ] 
- ) - def test_chunk(self, _, chunk, dim): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - out = torch.ops.aten.chunk(input, chunk, dim) - return out - - input = [torch.randn(11)] - self.run_test( - TestModule(), - input, - expected_ops={torch.ops.aten.chunk}, - ) - - -class TestSelectConverterExplicitBatch(DispatchTestCase): - @parameterized.expand( - [ - ("select_chunk_dim", 6, 0), - ] - ) - def test_chunk(self, _, chunk, dim): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - out = torch.ops.aten.chunk(input, chunk, dim) - return out - - input = [torch.randn(12)] - self.run_test( - TestModule(), - input, - expected_ops={torch.ops.aten.chunk}, - test_explicit_precision=True, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py index dddd72f732..268df8ccfd 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py @@ -15,7 +15,7 @@ class TestRSubConverter(DispatchTestCase): def test_rsub(self, _, x, alpha): class rsub(nn.Module): def forward(self, input): - return torch.rsub(input, input, alpha = alpha) + return torch.rsub(input, input, alpha=alpha) inputs = [torch.randn(x)] self.run_test( From 7ab071d91cc09281d8d518ce4f0dd406c6537955 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 16:00:48 -0700 Subject: [PATCH 33/39] Correcting acc squeeze test --- .../fx/test/converters/acc_op/test_squeeze.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py index d265def896..c9b4776dd3 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py +++ 
b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py @@ -12,7 +12,12 @@ def forward(self, x): return x.squeeze(2) inputs = [torch.randn(1, 2, 1)] - self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze}) + self.run_test( + Squeeze(), + inputs, + expected_ops={acc_ops.squeeze}, + test_implicit_batch_dim=False, + ) # Testing with shape=(-1, -1, -1, -1) results in error: # AssertionError: We don't support squeeze dynamic dim. From 36ac0cf341286865cdea67c87fac2f3f9cf8b8b9 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 17:23:15 -0700 Subject: [PATCH 34/39] test_reshape expected ops aten.reshape since aten.view has been removed in lowering --- .../fx/test/converters/aten_op/test_reshape_aten.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py index 538e575d6e..385ec05b8b 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py @@ -31,7 +31,7 @@ def forward(self, x): self.run_test( TestModule(target_shape), inputs, - expected_ops={torch.ops.aten.view.default}, + expected_ops={torch.ops.aten.reshape}, ) @parameterized.expand( @@ -64,7 +64,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( TestModule(target_shape), input_specs, - expected_ops={torch.ops.aten.view.default}, + expected_ops={torch.ops.aten.reshape}, ) @unittest.skipIf( @@ -94,7 +94,7 @@ def forward(self, x, y): self.run_test_with_dynamic_shape( TestModule(), input_specs, - expected_ops={torch.ops.aten.view.default}, + expected_ops={torch.ops.aten.reshape}, ) From eb851b19880dbc00acf6c69e78dd509e87bd1e81 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 21:43:07 -0700 Subject: [PATCH 35/39] removing aten.view in lowering pass --- py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py | 1 - 
.../fx/test/converters/aten_op/test_reshape_aten.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 0d6b1c28de..6790962621 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -258,7 +258,6 @@ def remove_ops( for n in module.graph.nodes: if n.op == "call_function" and n.target in ( torch.ops.aten._unsafe_view.default, - torch.ops.aten.view.default, ): modified = True node = n diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py index 385ec05b8b..538e575d6e 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py @@ -31,7 +31,7 @@ def forward(self, x): self.run_test( TestModule(target_shape), inputs, - expected_ops={torch.ops.aten.reshape}, + expected_ops={torch.ops.aten.view.default}, ) @parameterized.expand( @@ -64,7 +64,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( TestModule(target_shape), input_specs, - expected_ops={torch.ops.aten.reshape}, + expected_ops={torch.ops.aten.view.default}, ) @unittest.skipIf( @@ -94,7 +94,7 @@ def forward(self, x, y): self.run_test_with_dynamic_shape( TestModule(), input_specs, - expected_ops={torch.ops.aten.reshape}, + expected_ops={torch.ops.aten.view.default}, ) From 6b234e0f34a9a27851eb438d70327c316976368e Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 22:47:12 -0700 Subject: [PATCH 36/39] layer_norm test --- .../aten_op/test_layer_norm_aten.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py index cf97e828d0..e204f4ec8b 100644 
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -19,26 +19,26 @@ def forward(self, x): ) -def test_layernorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.ln = torch.nn.LayerNorm([3, 224, 224]) - - def forward(self, x): - return self.ln(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, 224, 224), - dtype=torch.float32, - shape_ranges=[(1, 3, 1, 1)], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} - ) + def test_layernorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[(1, 3, 1, 1)], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) if __name__ == "__main__": From 95c1adab0143bf7d2f1aeb99988bde00f54a6be4 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 20 Apr 2023 22:49:21 -0700 Subject: [PATCH 37/39] correcting linting error --- .../fx/test/converters/aten_op/test_layer_norm_aten.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py index e204f4ec8b..6662d91b9a 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -18,7 +18,6 @@ def forward(self, x): TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} ) - def test_layernorm_with_dynamic_shape(self): class TestModule(torch.nn.Module): def __init__(self): From 
1a1b809b7b2f90043cbcab0318141f6302057021 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 21 Apr 2023 05:06:07 -0700 Subject: [PATCH 38/39] correcting dynamic shape layer norm --- .../fx/test/converters/aten_op/test_layer_norm_aten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py index 6662d91b9a..fab398ac0f 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -31,12 +31,12 @@ def forward(self, x): InputTensorSpec( shape=(-1, 3, 224, 224), dtype=torch.float32, - shape_ranges=[(1, 3, 1, 1)], + shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], ), ] self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default} ) From 01e5aa14e9d7a37818cf2cdd053a4a00636fb69f Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 24 Apr 2023 10:54:03 -0700 Subject: [PATCH 39/39] removing aten_tracer and lower_basic_pass_aten changes --- .../fx/converters/nn_ops_converters.py | 33 ------------------- .../fx/passes/lower_basic_pass_aten.py | 28 +++------------- .../fx/tracer/dispatch_tracer/aten_tracer.py | 8 ++--- 3 files changed, 7 insertions(+), 62 deletions(-) delete mode 100644 py/torch_tensorrt/fx/converters/nn_ops_converters.py diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py deleted file mode 100644 index 9da6a71bfc..0000000000 --- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch - -from ..converter_registry import tensorrt_converter - -from .converter_utils import 
mark_as_int8_layer -import activation - - -@tensorrt_converter(torch.nn.functional.relu) -@tensorrt_converter(torch.nn.modules.activation.ReLU) -def relu(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - return activation.add_relu(network, "tensorrt", kwargs, layer_name) - - -@tensorrt_converter(torch.nn.functional.leaky_relu) -@tensorrt_converter(torch.nn.modules.activation.leaky_relu) -def leaky_relu(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - return activation.add_leaky_relu(network, "tensorrt", kwargs, layer_name) - - -@tensorrt_converter(torch.nn.modules.activation.Sigmoid) -def sigmoid(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - return activation.add_sigmoid(network, "tensorrt", kwargs, layer_name) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 6790962621..00063c3e21 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -416,46 +416,29 @@ def compose_bmm( node = n input_n = node.all_input_nodes[0] other_n = node.all_input_nodes[1] - - # If no input nodes are available, the bmm argument itself could be an input - # Alternatively, if the node has no users, it can be eliminated - if len(input_n.all_input_nodes) == 0 or len(node.users) == 0: - return PassResult(module, modified) - output = next(iter(node.users)) input_input_n = input_n.all_input_nodes[0] if ( input_input_n.target != torch.ops.aten.expand.default and input_n.target != torch.ops.aten.view.default ): - _LOGGER.warn( - "Bmm is addressed in fixed pattern. " - + f"A new pattern {input_input_n.target}, {input_n.target} is met! 
" - + "Skipping bmm lowering on this operation" + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" ) - return PassResult(module, modified) - real_input = input_input_n.all_input_nodes[0] input_other_n = other_n.all_input_nodes[0] if ( input_other_n.target != torch.ops.aten.expand.default and other_n.target != torch.ops.aten.view.default ): - _LOGGER.warn( - "Bmm is addressed in fixed pattern. " - + f"A new pattern {input_other_n.target}, {other_n.target} is met! " - + "Skipping bmm lowering on this operation" + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" ) - return PassResult(module, modified) - real_other = input_other_n.all_input_nodes[0] if len(real_other.meta["val"].size()) == 2: new_func = aten_compose_bmm_2d - elif len(real_other.meta["val"].size()) == 3: + if len(real_other.meta["val"].size()) == 3: new_func = aten_compose_bmm_3d - else: - # No valid bmm replacement exists for the specified dimensions - return PassResult(module, modified) with module.graph.inserting_after(node): new_args = (real_input, real_other) @@ -466,7 +449,6 @@ def compose_bmm( kwargs=None, ) output.replace_all_uses_with(new_node) - modified = True module.graph.eliminate_dead_code() module.recompile() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index 356ddc978e..e60c8f8d13 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -130,7 +130,7 @@ def trace(f, args, *rest): @req_torch_version("2.dev") -def opt_trace(f, args, perform_trace=True, *rest): +def opt_trace(f, args, *rest): """ Optimized trace with necessary passes which re-compose some ops or replace some ops These passes should be general and functional purpose @@ -148,11 +148,7 @@ def opt_trace(f, args, perform_trace=True, *rest): replace_inplace_ops, # remove it once functionalization is 
enabled ] - if perform_trace: - fx_module, _ = trace(f, args) - else: - fx_module = f - + fx_module, _ = trace(f, args) print(fx_module.graph) for passes in passes_list: pr: PassResult = passes(fx_module)