diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 77a9b92dfe..a321bb8dfe 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -27,6 +27,9 @@ ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous +from .activation import * +from .operator import * + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -675,159 +678,14 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) -def acc_ops_layer_norm(network, target, args, kwargs, name): - input_val = kwargs["input"] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"LayerNorm received input {input_val} that is not part " - "of the TensorRT region!" - ) - - gamma = kwargs["weight"].detach().cpu().float().numpy() - gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) - beta = kwargs["bias"].detach().cpu().float().numpy() - beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) - eps_field = trt.PluginField( - "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 - ) - try: - normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) - except TypeError: - _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") - normalized_shape = np.array([], dtype=np.int32) - - normalized_shape_filed = trt.PluginField( - "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 - ) - field_collection = trt.PluginFieldCollection( - [gamma_field, beta_field, eps_field, normalized_shape_filed] - ) - - try: - if network.has_implicit_batch_dimension: - plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") - else: - plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") - except AssertionError: - _LOGGER.error( - "Unable to find layer norm plugin, fall back to TensorRT implementation." - ) - return layer_norm(network, target, args, kwargs, name) - layer = network.add_plugin_v2([input_val], plugin) - layer.name = name - return layer.get_output(0) - - -def layer_norm( +def acc_ops_layer_norm( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"LayerNorm received input {input_val} that is not part " - "of the TensorRT region!" 
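+    # The plugin lookup and its native-TRT fallback removed here are assumed
+    # to be preserved inside operator.py's add_layer_norm.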
- ) - - shape = kwargs["weight"].shape # type: ignore[union-attr] - broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape - gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr] - beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr] - eps = kwargs["eps"] - - axes = 0 - for d in range(len(shape)): - axes |= 1 << (len(input_val.shape) - d - 1) - - # E[x] - mean_expected_layer = network.add_reduce( - input_val, trt.ReduceOperation.AVG, axes, keep_dims=True - ) - set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") - - # X-E[x] - sub_trt = add_binary_elementwise_layer( - network, - input_val, - mean_expected_layer.get_output(0), - trt.ElementWiseOperation.SUB, - target, - f"{name}_sub", - ) - # Variance = mean(pow(x_sub_mean,2)) - pow_tensor = network.add_constant( - (1,) * len(input_val.shape), - trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), - ) - pow_tensor.name = f"{name}_power" - pow_var = add_binary_elementwise_layer( - network, - sub_trt, - pow_tensor.get_output(0), - trt.ElementWiseOperation.POW, - target, - f"{name}_pow_var", - ) - mean_trt_layer = network.add_reduce( - pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True - ) - set_layer_name(mean_trt_layer, target, f"{name}_mean") - # Variance + eps - eps_tensor = network.add_constant( - (1,) * len(input_val.shape), - trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), - ) - eps_tensor.name = f"{name}_eps" - add_trt = add_binary_elementwise_layer( - network, - mean_trt_layer.get_output(0), - eps_tensor.get_output(0), - trt.ElementWiseOperation.SUM, - target, - f"{name}_add", - ) - # SQRT((Var + eps)) - sqrt_trt = add_unary_layer( - network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" - ) - # (x - E[x]) / sqrt((var + eps)) - div_trt = add_binary_elementwise_layer( - network, - sub_trt, - sqrt_trt, - trt.ElementWiseOperation.DIV, - target, - f"{name}_div_trt", - ) - - assert gamma is not None - gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] - gamma_tensor.name = f"{name}_gamma" - assert beta is not None - beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] - beta_tensor.name = f"{name}_beta" - # y * gamma + beta - scale_layer = add_binary_elementwise_layer( - network, - div_trt, - gamma_tensor.get_output(0), - trt.ElementWiseOperation.PROD, - target, - f"{name}_scale", - ) - return add_binary_elementwise_layer( - network, - scale_layer, - beta_tensor.get_output(0), - trt.ElementWiseOperation.SUM, - target, - name, - ) + return add_layer_norm(network, target, kwargs, name) @tensorrt_converter(acc_ops.softmax) @@ -838,37 +696,7 @@ def acc_ops_softmax( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"softmax received input {input_val} that is not part " - "of the TensorRT region!" - ) - - # Used to get dim when dim is None. Copied from PyTorch softmax implementation. 
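+    # dim defaulting and implicit-batch adjustment are expected to happen
+    # inside operator.py's add_softmax now.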
- def get_softmax_dim(ndim: int) -> int: - if ndim == 0 or ndim == 1 or ndim == 3: - ret = 0 - else: - ret = 1 - return ret - - if kwargs["dim"] is None: - dim = get_softmax_dim(input_ranks) - else: - dim = cast(int, kwargs["dim"]) - - dim = get_positive_dim(dim, input_ranks) - if network.has_implicit_batch_dimension: - assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." - dim -= 1 - - layer = network.add_softmax(input_val) - layer.axes = 1 << dim - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_softmax(network, target, kwargs, name) @tensorrt_converter(acc_ops.tile) @@ -879,103 +707,7 @@ def acc_ops_tile( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_t = kwargs["input"] - input_val = get_trt_tensor(network, input_t, f"{name}_input") - - dims = tuple(cast(Sequence[int], kwargs["dims"])) - n_input_dims = len(input_val.shape) + ( - 1 if network.has_implicit_batch_dimension else 0 - ) - - if len(dims) > n_input_dims: - assert not network.has_implicit_batch_dimension - layer = network.add_shuffle(input_val) - layer.name = f"{name}_reshape" - num_preceding_ones = len(dims) - n_input_dims - - if len(get_dynamic_dims(input_val.shape)) > 1: - input_shape_layer = network.add_shape(input_val) - input_shape_layer.name = f"{name}_input_shape" - preceding_ones = network.add_constant( - (num_preceding_ones,), - np.ascontiguousarray([1] * num_preceding_ones, np.int32), - ).get_output(0) - reshape_layer = network.add_concatenation( - [preceding_ones, input_shape_layer.get_output(0)] - ) - reshape_layer.axis = 0 - reshape_layer.name = f"{name}_reshape_dims" - layer.set_input(1, reshape_layer.get_output(0)) - else: - layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple( - input_val.shape - ) - input_val = layer.get_output(0) - else: - dims = (1,) * (n_input_dims - len(dims)) + dims - - if network.has_implicit_batch_dimension: - assert dims[0] == 1, "Can't tile the batch dim when it's implicit." - dims = dims[1:] - starts = [0] * len(dims) - shapes = [] - if all(isinstance(d, int) for d in dims): - shapes = [i * j for i, j in zip(input_val.shape, dims)] # type: ignore[union-attr] - else: - shape = [] - for i, (s, d) in enumerate(zip(input_val.shape, dims)): - if isinstance(d, TRTTensor) and len(d.shape) == 0: - d = prepend_ones(network, d, f"{name}_{i}", 1) - else: - d = get_trt_tensor(network, d, f"{name}_{i}") - shape.append(d) - mul = add_binary_elementwise_layer( - network, - s, - d, - trt.ElementWiseOperation.PROD, - target, - f"{name}_mul_{i}", - ) - shapes.append(mul) - dims = shape - # If there's dynmaic dim then there would be negative dims in shapes which is not allowed. - # Here we build a dummy shapes array. 
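+    # The dynamic-shape slice logic removed here is assumed to be preserved
+    # by operator.py's add_tile.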
- if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] - shapes = [1] * len(dims) - strides = [1] * len(dims) - layer = network.add_slice(input_val, starts, shapes, strides) - layer.mode = trt.SliceMode.WRAP - set_layer_name(layer, target, name) - - if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] - starts_tensor = network.add_constant( - (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32) - ).get_output(0) - if all(isinstance(d, int) for d in dims): - dims_tensor = network.add_constant( - (len(dims),), np.ascontiguousarray(dims, np.int32) - ).get_output(0) - else: - assert all(isinstance(d, TRTTensor) for d in dims) - concat_dims_layer = network.add_concatenation(inputs=dims) - concat_dims_layer.axis = 0 - concat_dims_layer.name = f"{name}_tile_dim" - dims_tensor = concat_dims_layer.get_output(0) - input_shape_layer = network.add_shape(input_val) - input_shape_layer.name = f"{name}_slice_input_shape" - slice_shapes_tensor = add_binary_elementwise_layer( - network, - input_shape_layer.get_output(0), - dims_tensor, - trt.ElementWiseOperation.PROD, - target, - f"{name}_slice_shapes", - ) - layer.set_input(1, starts_tensor) - layer.set_input(2, slice_shapes_tensor) - - return layer.get_output(0) + return add_tile(network, target, kwargs, name) @tensorrt_converter(acc_ops.sign) @@ -1004,9 +736,7 @@ def acc_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.RELU - return add_activation_layer(network, input_val, operation_type, target, name) + return add_relu(network, target, kwargs, name) @tensorrt_converter(acc_ops.leaky_relu) @@ -1017,12 +747,7 @@ def acc_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - negative_slope = kwargs["negative_slope"] - operation_type = trt.ActivationType.LEAKY_RELU - return add_activation_layer( - network, input_val, operation_type, target, name, negative_slope - ) + return add_leaky_relu(network, target, kwargs, name) @tensorrt_converter(acc_ops.elu) @@ -1033,10 +758,7 @@ def acc_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - alpha = kwargs["alpha"] - operation_type = trt.ActivationType.ELU - return add_activation_layer(network, input_val, operation_type, target, name, alpha) + return add_elu(network, target, kwargs, name) @tensorrt_converter(acc_ops.selu) @@ -1047,9 +769,7 @@ def acc_ops_selu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.SELU - return add_activation_layer(network, input_val, operation_type, target, name) + return add_selu(network, target, kwargs, name) @tensorrt_converter(acc_ops.softsign) @@ -1060,9 +780,7 @@ def acc_ops_softsign( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.SOFTSIGN - return add_activation_layer(network, input_val, operation_type, target, name) + return add_softsign(network, target, kwargs, name) @tensorrt_converter(acc_ops.sin) @@ -1138,9 +856,7 @@ def acc_ops_tanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.TANH - return add_activation_layer(network, input_val, operation_type, target, name) + return 
add_tanh(network, target, kwargs, name) @tensorrt_converter(acc_ops.asin) @@ -1216,9 +932,7 @@ def acc_ops_sqrt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.UnaryOperation.SQRT - return add_unary_layer(network, input_val, operation_type, target, name) + return add_sqrt(network, target, kwargs, name) @tensorrt_converter(acc_ops.reciprocal) @@ -1458,14 +1172,7 @@ def acc_ops_maximum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MAX, - target, - name, - ) + return add_maximum(network, target, kwargs, name) @tensorrt_converter(acc_ops.minimum) @@ -1476,14 +1183,7 @@ def acc_ops_minimum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MIN, - target, - name, - ) + return add_minimum(network, target, kwargs, name) @tensorrt_converter(acc_ops.dtype) @@ -1536,12 +1236,7 @@ def acc_ops_logical_not( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.UnaryOperation.NOT - # cast to bool type - if input_val.dtype in (trt.float32, trt.float16, trt.int32): - input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool) - return add_unary_layer(network, input_val, operation_type, target, name) + return add_logical_not(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True) @@ -1553,43 +1248,7 @@ def acc_ops_logical_and( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_and` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - # we only support both inputs are bool type - if target == acc_ops.bitwise_and: - - def check_is_bool(input_t): - if isinstance(input_t, TRTTensor): - assert ( - input_t.dtype == trt.bool - ), "We currently do not support input is non-bool" - elif isinstance(input_t, torch.Tensor): - assert ( - input_t.dtype == torch.bool - ), "We currently do not support input is non-bool" - else: - assert isinstance( - input_t.bool - ), "We currently do not support input is non-bool" - - check_is_bool(input_t) - check_is_bool(other_t) - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - if input_t.dtype != trt.bool: - input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool) - if other_t.dtype != trt.bool: - other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.AND, target, name - ) + return add_logical_and(network, target, kwargs, name) @tensorrt_converter(acc_ops.ne, no_implicit_batch_dim=True) @@ -1600,23 +1259,7 @@ def acc_ops_ne( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `ne` function should be called with explicit batch dimension." 
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - eq_t = add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name - ) - - return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + return add_ne(network, target, kwargs, name) @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) @@ -1627,21 +1270,7 @@ def acc_ops_eq( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `eq` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name - ) + return add_eq(network, target, kwargs, name) @tensorrt_converter(acc_ops.gt, no_implicit_batch_dim=True) @@ -1652,21 +1281,7 @@ def acc_ops_gt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `gt` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name - ) + return add_gt(network, target, kwargs, name) @tensorrt_converter(acc_ops.lt, no_implicit_batch_dim=True) @@ -1677,21 +1292,7 @@ def acc_ops_lt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `le` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - - input_t = get_trt_tensor(network, input_t, f"{name}_input_t") - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - - input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name - ) + return add_lt(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_or, no_implicit_batch_dim=True) @@ -1702,33 +1303,7 @@ def acc_ops_logical_or( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_or` function should be called with explicit batch dimension." 
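+    # Casting non-bool operands to trt.bool is assumed to move into
+    # operator.py's add_logical_or.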
- ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - if isinstance(other_t, (torch.Tensor, bool)): - if isinstance(other_t, bool): - other_t = int(other_t) - elif other_t.dtype == torch.bool: - other_t = other_t.to(torch.int32) - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - if input_t.dtype != trt.bool: - layer_i = network.add_identity(input_t) - layer_i.set_output_type(0, trt.bool) - set_layer_name(layer_i, target, f"{name}_input_dtype_change") - input_t = layer_i.get_output(0) - if other_t.dtype != trt.bool: - layer_o = network.add_identity(other_t) - layer_o.set_output_type(0, trt.bool) - set_layer_name(layer_o, target, f"{name}_other_dtype_change") - other_t = layer_o.get_output(0) - - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.OR, target, name - ) + return add_logical_or(network, target, kwargs, name) @tensorrt_converter(acc_ops.logical_xor, no_implicit_batch_dim=True) @@ -1739,33 +1314,7 @@ def acc_ops_logical_xor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if network.has_implicit_batch_dimension: - raise RuntimeError( - "The `logical_xor` function should be called with explicit batch dimension." - ) - - input_t = kwargs["input"] - other_t = kwargs["other"] - if isinstance(other_t, (torch.Tensor, bool)): - if isinstance(other_t, bool): - other_t = int(other_t) - elif other_t.dtype == torch.bool: - other_t = other_t.to(torch.int32) - other_t = get_trt_tensor(network, other_t, f"{name}_other_t") - if input_t.dtype != trt.bool: - layer_i = network.add_identity(input_t) - layer_i.set_output_type(0, trt.bool) - set_layer_name(layer_i, target, f"{name}_input_dtype_change") - input_t = layer_i.get_output(0) - if other_t.dtype != trt.bool: - layer_o = network.add_identity(other_t) - layer_o.set_output_type(0, trt.bool) - set_layer_name(layer_o, target, f"{name}_other_dtype_change") - other_t = layer_o.get_output(0) - - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name - ) + return add_logical_xor(network, target, kwargs, name) # T113156424 Have some accuracy problems in hf_T5. @@ -1859,27 +1408,7 @@ def acc_ops_fmod( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it - trunc_div_value = trunc_div( - kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" - ) - prod_value = add_binary_elementwise_layer( - network, - trunc_div_value, - kwargs["other"], - trt.ElementWiseOperation.PROD, - target, - name + "_prod", - ) - sub_value = add_binary_elementwise_layer( - network, - kwargs["input"], - prod_value, - trt.ElementWiseOperation.SUB, - target, - name + "_sub", - ) - return sub_value + return add_fmod(network, target, kwargs, name) # T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64. @@ -2064,40 +1593,7 @@ def acc_ops_squeeze( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"squeeze received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) - # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic - # dim, which is a very rare case. 
For now we just claim not supporting dim=None. - assert dim is not None, "We don't support dim=None right now for squeeze." - - dim = get_positive_dim( - dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - ) - if network.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 - - assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." - assert ( - len(get_dynamic_dims(input_val.shape)) <= 1 - ), "Currently more than one dynamic dim for input to squeeze is not supported." - - output_shape = [] - for i, s in enumerate(input_val.shape): - if i == dim and s == 1: - continue - output_shape.append(s) - layer = network.add_shuffle(input_val) - layer.reshape_dims = tuple(output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_squeeze(network, target, kwargs, name) @tensorrt_converter(acc_ops.add) @@ -2108,14 +1604,7 @@ def acc_ops_add( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUM, - target, - name, - ) + return add_add(network, target, kwargs, name) @tensorrt_converter(acc_ops.sub) @@ -2126,14 +1615,7 @@ def acc_ops_sub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUB, - target, - name, - ) + return add_sub(network, target, kwargs, name) @tensorrt_converter(acc_ops.div) @@ -2144,14 +1626,7 @@ def acc_ops_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.DIV, - target, - name, - ) + return add_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.floor_div) @@ -2162,14 +1637,7 @@ def acc_ops_floor_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.FLOOR_DIV, - target, - name, - ) + return add_floor_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.trunc_div) @@ -2180,7 +1648,7 @@ def acc_ops_trunc_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return trunc_div(kwargs["input"], kwargs["other"], network, target, name) + return add_trunc_div(network, target, kwargs, name) @tensorrt_converter(acc_ops.mul) @@ -2191,14 +1659,7 @@ def acc_ops_mul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.PROD, - target, - name, - ) + return add_mul(network, target, kwargs, name) @tensorrt_converter(acc_ops.pow) @@ -2209,14 +1670,7 @@ def acc_ops_pow( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( - network, - kwargs["input"], - kwargs["exponent"], - trt.ElementWiseOperation.POW, - target, - name, - ) + return add_pow(network, target, kwargs, name) @tensorrt_converter(acc_ops.unsqueeze) @@ -2487,52 +1941,7 @@ def acc_ops_slice_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not 
isinstance(input_val, TRTTensor): - raise RuntimeError( - f"slice_tensor received input {input_val} that is not part " - "of the TensorRT region!" - ) - - ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - dim = get_positive_dim(cast(int, kwargs["dim"]), ranks) - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" - ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" - - start_int = cast(int, kwargs["start"]) - stop_int = cast(int, kwargs["stop"]) - step_int = cast(int, kwargs["step"]) - start = [0] * len(input_val.shape) - start[dim] = start_int - stride = [1] * len(start) - stride[dim] = step_int - output_shape = list(input_val.shape) - output_shape[dim] = (stop_int - start_int) // step_int - - if dynamic_shape > 0: - output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name - ) - layer = network.add_slice( - input_val, - start=start, - shape=[] if dynamic_shape else output_shape, - stride=stride, - ) - if dynamic_shape: - layer.set_input(2, output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_slice(network, target, kwargs, name) @tensorrt_converter(acc_ops.expand) @@ -2543,28 +1952,7 @@ def acc_ops_expand_tensor( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_t = kwargs["input"] - shape = list(kwargs["sizes"]) - - input_val = get_trt_tensor(network, input_t, f"{name}_input") - - if network.has_implicit_batch_dimension: - shape = shape[1:] - - ranks = len(input_val.shape) - # TRT does not support different dimension size - assert len(shape) == ranks - shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] - - inshape = tuple(input_val.shape) - shape = tuple(shape) - start = tuple([0] * ranks) - stride = tuple( - [int(i == o) for i, o in zip(inshape, shape)] - ) # stride == 1 if dimensions match, 0 otherwise - layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_expand(network, target, kwargs, name) @tensorrt_converter(acc_ops.where) @@ -2575,89 +1963,7 @@ def acc_ops_where( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - - condition_t = kwargs["condition"] - x_t = kwargs["x"] - y_t = kwargs["y"] - - if type(x_t) != TRTTensor: - assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" - - if type(y_t) != TRTTensor: - assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" 
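+    # Broadcasting condition, x and y to a common output shape is assumed to
+    # happen inside operator.py's add_where.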
- - # get output shape - - x_shape = list(x_t.shape) - y_shape = list(y_t.shape) - condition_shape = list(condition_t.shape) - output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) - - # expand shape - if type(condition_t) != TRTTensor: - assert condition_t.dtype == torch.bool, "condition dtype is not bool" - if condition_shape != output_shape: - condition_t.expand(output_shape) - condition_t = condition_t.to(torch.int32) - condition_const = get_trt_tensor(network, condition_t, f"{name}_condition") - condition_layer = network.add_identity(condition_const) - condition_layer.set_output_type(0, trt.bool) - set_layer_name(condition_layer, target, f"{name}_condition") - condition_val = condition_layer.get_output(0) - else: - assert condition_t.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != output_shape: - condition_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": condition_t, "sizes": output_shape}, - name=f"{name}_expand", - ) - else: - condition_val = condition_t - - if type(x_t) != TRTTensor: - if x_shape != output_shape: - # special case where 1 element in x_t - if len(x_t.shape) == 0: - x_t = x_t.unsqueeze(0) - x_t = x_t.expand(output_shape) - x_val = get_trt_tensor(network, x_t, f"{name}_x") - else: - x_val = x_t - if x_shape != output_shape: - x_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": x_val, "sizes": output_shape}, - name=f"{name}_x_expand", - ) - - if type(y_t) != TRTTensor: - if y_shape != output_shape: - # special case where 1 element in y_t - if len(y_t.shape) == 0: - y_t = y_t.unsqueeze(0) - y_t = y_t.expand(output_shape) - y_val = get_trt_tensor(network, y_t, f"{name}_y") - else: - y_val = y_t - if y_shape != output_shape: - y_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": y_val, "sizes": output_shape}, - name=f"{name}_y_expand", - ) - - select_layer = network.add_select(condition_val, x_val, y_val) - - set_layer_name(select_layer, target, f"{name}_select") - - return select_layer.get_output(0) + return add_where(network, target, kwargs, name) @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True) @@ -2786,58 +2092,7 @@ def acc_ops_linear( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Linear received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dynamic_dims = get_dynamic_dims(input_val.shape) - assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, ( - "Currently we only support one dynmaic " - "dim for linear and it can't be the last dim." 
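+    # The matrix-multiply plus optional bias lowering removed here is assumed
+    # to be preserved by operator.py's add_linear.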
- ) - - if isinstance(kwargs["weight"], torch.Tensor): - weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") - if target not in (acc_ops.linear, torch.ops.aten.linear): - weight_op = trt.MatrixOperation.TRANSPOSE - else: - weight_op = trt.MatrixOperation.NONE - else: - assert isinstance( - kwargs["weight"], TRTTensor - ), f"Expect weight to be trt tensor but got {type(kwargs['weight'])}" - weight = kwargs["weight"] - weight_op = trt.MatrixOperation.TRANSPOSE - - preset_diff = 0 - if len(input_val.shape) == 1: - preset_diff -= 1 - input_op = trt.MatrixOperation.VECTOR - else: - input_op = trt.MatrixOperation.NONE - - input_val, weight = broadcast( - network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff - ) - matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op) - set_layer_name(matmul_layer, target, f"{name}_matmul") - res = matmul_layer.get_output(0) - - if kwargs["bias"] is not None: - bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias") # type: ignore[arg-type] - res = add_binary_elementwise_layer( - network, - matmul_layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, - target, - f"{name}_add", - ) - return res + return add_linear(network, target, kwargs, name) def add_clamp(network, input, val, op, name): @@ -3091,34 +2346,7 @@ def acc_ops_matmul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input") - other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other") - - for i in [input_val, other_val]: - if not isinstance(i, TRTTensor): - raise RuntimeError( - f"matmul received input {i} that is not part of the TensorRT region!" - ) - - input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE - preset_diff = 0 - - if len(input_val.shape) == 1: - preset_diff -= 1 - input_matrix_op = trt.MatrixOperation.VECTOR - - if len(other_val.shape) == 1: - preset_diff += 1 - other_matrix_op = trt.MatrixOperation.VECTOR - - input_val, other_val = broadcast( - network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff - ) - layer = network.add_matrix_multiply( - input_val, input_matrix_op, other_val, other_matrix_op - ) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_matmul(network, target, kwargs, name) @tensorrt_converter(acc_ops.hardsigmoid) @@ -3129,23 +2357,7 @@ def acc_ops_hard_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Hard sigmoid received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return add_activation_layer( - network, - input_val, - trt.ActivationType.HARD_SIGMOID, - target, - name, - alpha=1 / 6, - beta=0.5, - ) + return add_hard_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.sigmoid) @@ -3156,17 +2368,7 @@ def acc_ops_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Sigmoid received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - return add_activation_layer( - network, input_val, trt.ActivationType.SIGMOID, target, name - ) + return add_sigmoid(network, target, kwargs, name) @tensorrt_converter(acc_ops.permute) @@ -3367,33 +2569,7 @@ def acc_ops_gelu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - approximate = kwargs["approximate"] - if approximate != "none": - raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"GELU received input {input_val} that is not part " - "of the TensorRT region!" - ) - if network.has_implicit_batch_dimension: - raise RuntimeError( - "GeLU converter currently doesn't support implicit batch dimension" - ) - - plugin_name = "CustomGeluPluginDynamic" - # type_id 0 for float32, 1 for float16 - type_id = trt.PluginField( - "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 - ) - field_collection = TRTPluginFieldCollection([type_id]) - plugin_version = "1" - - plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) - - layer = network.add_plugin_v2([input_val], plugin) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_gelu(network, target, kwargs, name) @tensorrt_converter(acc_ops.chunk) @@ -3404,62 +2580,7 @@ def acc_ops_chunk( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - chunks = cast(int, kwargs["chunks"]) - dim = cast(int, kwargs["dim"]) - input_dim_size = len(input_val.shape) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"chunk received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - input_dim_size += 1 - dim = get_positive_dim(dim, input_dim_size) - assert dim != 0, "Can't chunk on batch dim when it's implicit!" - dim -= 1 - else: - if dynamic_shape: - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
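+    # The slice-per-chunk lowering removed here is assumed to be preserved by
+    # operator.py's add_chunk.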
- dim = get_positive_dim(dim, input_dim_size) - - if chunks > input_val.shape[dim]: - warnings.warn( - f"Asked for {chunks} chunks along dimention " - f"{dim} on tensor with size {input_val.shape}, chunks " - f"will default to {input_val.shape[dim]}", - RuntimeWarning, - ) - chunks = input_val.shape[dim] - - start = [0] * len(input_val.shape) - stride = [1] * len(start) - offset = 0 - split_size = (input_val.shape[dim] + chunks - 1) // chunks - - max_offset = input_val.shape[dim] - # add slice layers - output = [] - for i in range(chunks): - shape = list(input_val.shape) - shape[dim] = min(split_size, max_offset - offset) - if dynamic_shape: - shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_{i}" - ) - start[dim] = offset - layer = network.add_slice( - input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride - ) - if dynamic_shape: - layer.set_input(2, shape) - offset += split_size - set_layer_name(layer, target, f"{name}_{i}") - output.append(layer.get_output(0)) - return output + return add_chunk(network, target, kwargs, name) @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True) @@ -3470,75 +2591,7 @@ def acc_ops_cumsum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - dim = cast(int, kwargs["dim"]) - input_shape = input_val.shape # type: ignore[union-attr] - input_dim_size = len(input_val.shape) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"cumsum received input {input_val} that is not part " - "of the TensorRT region!" - ) - if network.has_implicit_batch_dimension: - raise RuntimeError( - "cumsum converter currently doesn't support implicit batch dimension" - ) - dim = get_positive_dim(dim, input_dim_size) - loop = network.add_loop() - trip_limit = None - if input_shape[dim] > 0: - axis = torch.tensor(input_shape[dim], dtype=torch.int32) - trip_limit_layer = network.add_constant(axis.shape, to_numpy(axis)) - else: - input_shape = network.add_shape(input_val).get_output(0) - dim_value = torch.tensor(dim, dtype=torch.int32) - axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) - trip_limit_layer = network.add_gather(input_shape, axis, 0) - set_layer_name(trip_limit_layer, target, f"{name}_trip_limit") - trip_limit = trip_limit_layer.get_output(0) - - loop.add_trip_limit(trip_limit, trt.TripLimit(0)) - iterator = loop.add_iterator(input_val, dim, False) - data = iterator.get_output(0) - new_dims = tuple(data.shape) - zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype)) - zero_tensor = network.add_constant( - zero_tensor.shape, to_numpy(zero_tensor) - ).get_output(0) - - running_sum = loop.add_recurrence(zero_tensor) - set_layer_name(running_sum, target, f"{name}_running_sum_1") - running_sum_tensor = running_sum.get_output(0) - - current_sum = add_binary_elementwise_layer( - network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, - target, - f"{name}_sum_1", - ) - running_sum.set_input(1, current_sum) - - running_sum = loop.add_recurrence(zero_tensor) - set_layer_name(running_sum, target, f"{name}_running_sum_2") - running_sum_tensor = running_sum.get_output(0) - - current_sum = add_binary_elementwise_layer( - network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, - target, - f"{name}_sum_2", - ) - running_sum.set_input(1, current_sum) - - loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim) 
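+    # The TRT loop-based running sum removed here is assumed to be preserved
+    # by operator.py's add_cumsum.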
-    set_layer_name(loop_output, target, f"{name}_loop_output")
-    loop_output.set_input(1, trip_limit)
-    return loop_output.get_output(0)
+    return add_cumsum(network, target, kwargs, name)


 @tensorrt_converter(acc_ops.hardtanh)
@@ -3549,23 +2602,7 @@ def acc_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"hardtanh received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network,
-        input_val,
-        trt.ActivationType.CLIP,
-        target,
-        name,
-        alpha=kwargs["min_val"],
-        beta=kwargs["max_val"],
-    )
+    return add_hard_tanh(network, target, kwargs, name)


 @tensorrt_converter(acc_ops.interpolate)
diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py
index a7ab25152c..e27da49f1d 100644
--- a/py/torch_tensorrt/fx/converters/activation.py
+++ b/py/torch_tensorrt/fx/converters/activation.py
@@ -1,76 +1,197 @@
 import numpy as np
+import operator
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+from torch.fx.node import Argument, Target

-from ..converter_registry import tensorrt_converter
+from ..utils import torch_dtype_from_trt
 from .converter_utils import mark_as_int8_layer
+from .converter_utils import set_layer_name
+from .converter_utils import get_trt_plugin
+
+from ..types import (
+    Shape,
+    TRTDataType,
+    TRTElementWiseOp,
+    TRTLayer,
+    TRTNetwork,
+    TRTPlugin,
+    TRTPluginFieldCollection,
+    TRTTensor,
+)
+
+
+def add_activation_layer(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    operation_type: trt.ActivationType,
+    target: Target,
+    name: str,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): Input to the activation op.
+            Must be a TensorRT tensor.
+        operation_type (trt.ActivationType): Type of the TensorRT activation
+            operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+        alpha (Optional[Any]): If not None, we will use it to set the alpha
+            attribute of the created TensorRT activation layer.
+        beta (Optional[Any]): If not None, we will use it to set the beta
+            attribute of the created TensorRT activation layer.
+        dyn_range_fn (Optional[Callable[[float, float], Any]]): A function
+            which takes the dynamic range of a TensorRT Tensor and returns the
+            output dynamic range.
+
+    Returns:
+        The output of TensorRT Activation layer.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name)
+
+    # Only mark the layer as int8 when a dynamic-range mapping was supplied;
+    # the plain activation wrappers below call this helper without one.
+    if input_val.dynamic_range is not None and dyn_range_fn is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)

     return layer.get_output(0)


-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
+def add_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.RELU
+    return add_activation_layer(network, input_val, operation_type, target, name)

-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)

-    return layer.get_output(0)
+def add_leaky_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    negative_slope = kwargs["negative_slope"]
+    operation_type = trt.ActivationType.LEAKY_RELU
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, negative_slope
+    )

-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
+def add_elu(network, target, kwargs, name):
     input_val = kwargs["input"]
+    alpha = kwargs["alpha"]
+    operation_type = trt.ActivationType.ELU
+    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+

-    if not isinstance(input_val, trt.tensorrt.ITensor):
+def add_selu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_softsign(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SOFTSIGN
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.TANH
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_gelu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if "approximate" in kwargs.keys():
+        approximate = kwargs["approximate"]
+        if approximate != "none":
+            raise RuntimeError(
+                "GeLU converter currently doesn't support fast gelu compute"
+            )
+    if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
+            f"GELU received input {input_val} that is not part "
             "of the TensorRT region!"
) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "GeLU converter currently doesn't support implicit batch dimension" + ) + + plugin_name = "CustomGeluPluginDynamic" + # type_id 0 for float32, 1 for float16 + type_id = trt.PluginField( + "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 + ) + field_collection = TRTPluginFieldCollection([type_id]) + plugin_version = "1" + + plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) + + layer = network.add_plugin_v2([input_val], plugin) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_hard_sigmoid(network, target, kwargs, name): + input_val = kwargs["input"] - def activation_dyn_range_fn(dyn_range): - return max(0, dyn_range[0]), max(0, dyn_range[1]) + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"Hard sigmoid received input {input_val} that is not part " + "of the TensorRT region!" + ) - return common_activation( + return add_activation_layer( network, - submod, input_val, - trt.ActivationType.RELU, - activation_dyn_range_fn, - layer_name, + trt.ActivationType.HARD_SIGMOID, + target, + name, + alpha=1 / 6, + beta=0.5, ) -@tensorrt_converter(torch.nn.modules.activation.Sigmoid) -def sigmoid(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 +def add_sigmoid(network, target, kwargs, name): input_val = kwargs["input"] - if not isinstance(input_val, trt.tensorrt.ITensor): + if not isinstance(input_val, TRTTensor): raise RuntimeError( f"Sigmoid received input {input_val} that is not part " "of the TensorRT region!" ) - def activation_dyn_range_fn(dyn_range): - def sigmoid_fn(x): - return 1 / (1 + np.exp(-x)) + return add_activation_layer( + network, input_val, trt.ActivationType.SIGMOID, target, name + ) - return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1]) - return common_activation( - network, - submod, - input_val, - trt.ActivationType.SIGMOID, - activation_dyn_range_fn, - layer_name, +def add_hard_tanh(network, target, kwargs, name): + input_val = kwargs["input"] + alpha = kwargs["min_val"] + beta = kwargs["max_val"] + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"hardtanh received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + operation_type = trt.ActivationType.CLIP + return add_activation_layer( + network, input_val, operation_type, target, name, alpha, beta ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index c86f2bd228..b6e770cf0a 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -21,8 +21,12 @@ from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt from .converter_utils import * # noqa: F403 +from .activation import * +from .operator import * + import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils + _LOGGER: logging.Logger = logging.getLogger(__name__) ## converter list in alphabetic order @@ -38,7 +42,7 @@ def aten_ops_add( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name) + return add_add(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) @@ -154,15 +158,11 @@ def aten_ops_div( } rounding_mode = kwargs.get("rounding_mode") if rounding_mode is None: - return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) + return add_div(network, target, kwargs_new, name) elif rounding_mode == "floor": - return acc_ops_converters.acc_ops_floor_div( - network, target, None, kwargs_new, name - ) + return add_floor_div(network, target, kwargs_new, name) elif rounding_mode == "trunc": - return acc_ops_converters.acc_ops_trunc_div( - network, target, None, kwargs_new, name - ) + return add_trunc_div(network, target, kwargs_new, name) else: raise RuntimeError( f"Target {target} does not support rounding mode {rounding_mode}" @@ -181,7 +181,7 @@ def aten_ops_floor_div( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name) + return add_floor_div(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.fmod.Scalar) @@ -197,7 +197,7 @@ def aten_ops_fmod( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name) + return add_fmod(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.linear) @@ -214,7 +214,7 @@ def aten_ops_linear( "bias": args[2], } - return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name) + return add_linear(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.max_pool3d) @@ -263,7 +263,23 @@ def aten_ops_mul( "input": args[0], "other": args[1], } - return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name) + return add_mul(network, target, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.matmul) +@tensorrt_converter(torch.ops.aten.mm.default) +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return add_matmul(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @@ -279,7 +295,7 @@ def aten_ops_pow( "input": args[0], "exponent": args[1], } - return acc_ops_converters.acc_ops_pow(network, target, None, kwargs_new, name) + return add_pow(network, target, kwargs_new, name) @tensorrt_converter(torch.ops.aten.relu.default) @@ -293,7 +309,7 @@ def aten_ops_relu( kwargs_new = { "input": args[0], } - return 
acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
+    return add_relu(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.sub.Tensor)
@@ -308,7 +324,7 @@ def aten_ops_sub(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.view.default)
@@ -375,9 +391,7 @@ def aten_ops_expand(
         "input": args[0],
         "sizes": args[1],
     }
-    return acc_ops_converters.acc_ops_expand_tensor(
-        network, target, None, kwargs_new, name
-    )
+    return add_expand(network, target, kwargs_new, name)


 @tensorrt_converter(operator.floordiv)
@@ -392,7 +406,7 @@ def aten_ops_operator_floordiv(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, kwargs_new, name)


 @tensorrt_converter(operator.mul)
@@ -407,7 +421,7 @@ def aten_ops_operator_mul(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, kwargs_new, name)


 @tensorrt_converter(operator.add)
@@ -422,7 +436,7 @@ def aten_ops_operator_add(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, kwargs_new, name)


 @tensorrt_converter(operator.sub)
@@ -437,7 +451,7 @@ def aten_ops_operator_sub(
         "input": args[0],
         "other": args[1],
     }
-    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, kwargs_new, name)


 @tensorrt_converter(torch.ops.aten.sym_numel)
@@ -479,3 +493,242 @@ def aten_ops_sym_size(
     )
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
+
+
+@tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {"input": args[0], "negative_slope": args[1]}
+    return add_leaky_relu(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # More than two positional args means alpha/scale/input_scale were
+    # materialized, which corresponds to a decomposed selu call.
+    if len(args) > 2:
+        kwargs_new = {
+            "input": args[0],
+        }
+        return add_selu(network, target, kwargs_new, name)
+    else:
+        kwargs_new = {
+            "input": args[0],
+            # aten.elu defaults alpha to 1.0 when it is not materialized
+            "alpha": args[1] if len(args) > 1 else 1.0,
+        }
+        return add_elu(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_gelu(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.tanh.default)
+def aten_ops_tanh(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_tanh(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.hardtanh.default)
+def aten_ops_hard_tanh(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "min_val": args[1],
+        "max_val": args[2],
+    }
+    return add_hard_tanh(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_sigmoid(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.select.int)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "index": args[2],
+    }
+    return add_select(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+        "start": args[2],
+        "stop": args[3],
+        "step": args[4],
+    }
+    return add_slice(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "normalized_shape": args[1],
+        "weight": args[2],
+        "bias": args[3],
+        "eps": args[4],
+    }
+    return add_layer_norm(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_softmax(network, target, kwargs_new, name)
+
+
+# FIXME: need to look at case where dim is tuple
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_squeeze(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "condition": args[0],
+        "x": args[1],
+        "y": args[2],
+    }
+    return add_where(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if "alpha" in kwargs:
kwargs:
+        alpha = kwargs["alpha"]
+    else:
+        # rsub's alpha defaults to 1 when the kwarg is absent; without this
+        # branch both `alpha` and `kwargs_new` would be undefined below
+        alpha = 1
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+        "alpha": alpha,
+    }
+    return add_rsub(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_rsqrt(network, target, kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 17a0cef456..2d79014ebd 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -77,7 +77,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim
 
 
-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc: bool = True) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
 
@@ -87,7 +87,13 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
         the node represents.
         name (str): Consists of fx node.name with optional suffix.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
+    target_name = (
+        target
+        if isinstance(target, str)
+        else f"acc_ops.{target.__name__}"
+        if is_acc
+        else f"aten_ops.{target.__name__}"
+    )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
 
 
@@ -120,30 +126,6 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
     return extend_attr_to_tuple(val, size)
 
 
-def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
-    """
-    Convert a PyTorch Tensor to a Numpy Array. If the tensor is
-    quantized it will be dequantized first.
-
-    Args:
-        tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
-
-    Returns:
-        A Numpy array.
-    """
-
-    if tensor is None:
-        return tensor
-
-    assert isinstance(
-        tensor, torch.Tensor
-    ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
-    if tensor.is_quantized:
-        tensor = tensor.dequantize()
-
-    return tensor.cpu().detach().contiguous().numpy()
-
-
 def has_dynamic_shape(shape: Shape) -> bool:
     """
     Determine if the given shape has dynamic dim. i.e. if there're -1 in shape.
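
A quick sketch of the names the new `is_acc` flag produces (illustrative only; the layer type and node name here are hypothetical):

```python
# acc_ops target (default, is_acc=True):
#   set_layer_name(layer, acc_ops.add, "add_1")
#   -> layer.name == "[ELEMENTWISE]-[acc_ops.add]-[add_1]"
# aten target routed through the new aten_ops converters:
#   set_layer_name(layer, torch.ops.aten.add, "add_1", is_acc=False)
#   -> layer.name == "[ELEMENTWISE]-[aten_ops.add]-[add_1]"
```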
@@ -312,6 +294,39 @@ def prepend_ones(
     return layer.get_output(0)
 
 
+def broadcastable(
+    a: TRTTensor,
+    b: TRTTensor,
+) -> bool:
+    "Check if two tensors are broadcastable according to torch rules"
+    # e.g. (5, 1, 4) and (3, 1) -> True; (2, 3) and (4, 3) -> False
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+    # align both shapes at the trailing dimension and compare right to left
+    if len(a_shape) >= len(b_shape):
+        greater_shape, lesser_shape = a_shape, b_shape
+    else:
+        greater_shape, lesser_shape = b_shape, a_shape
+    diff = len(greater_shape) - len(lesser_shape)
+    for j in range(len(lesser_shape) - 1, -1, -1):
+        # a pair of dims is compatible if they match or either one is 1
+        if not (
+            greater_shape[j + diff] == lesser_shape[j]
+            or greater_shape[j + diff] == 1
+            or lesser_shape[j] == 1
+        ):
+            return False
+    return True
+
+
 def broadcast(
     network: TRTNetwork,
     a: TRTTensor,
     b: TRTTensor,
@@ -352,171 +367,6 @@ def broadcast(
     return a, b
 
 
-def get_shape_with_dynamic_shape(
-    network: TRTNetwork,
-    shape: Union[list, tuple, torch.Tensor],
-    input_val: TRTTensor,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Prepare the real output tensor shape for dynamic shape mode tensor input.
-    How this functions works:
-    Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation
-    output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
-    reduce operation output shape. Steps of calculations are:
-    1. get the actual tensor shape of input_val via add_shape layer;
-    2. create a all 0 tensor [0, 0, 0];
-    3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False];
-    4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace
-    all -1 dynamic shape dimensions with actual batch_size value;
-    5. output shape with actual batch_size as [2048, 128, 256]
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        shape: calculated shape of the expected output tensor
-        input_val (TRTTensor): A TensorRT ITensor.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-    Returns:
-        TensorRT ITensors that represents the actual shape of the input_val
-    """
-    # Ger real shape info for input_val
-    input_shape = network.add_shape(input_val).get_output(0)
-
-    scale_layer = network.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
-    )
-    set_layer_name(scale_layer, target, f"{name}_scale")
-    scale_res = scale_layer.get_output(0)
-
-    length = input_shape.shape[0]
-    zero_layer = network.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
-    )
-    set_layer_name(zero_layer, target, f"{name}_zeros")
-
-    condition_val = add_binary_elementwise_layer(
-        network,
-        scale_res,
-        zero_layer.get_output(0),
-        trt.ElementWiseOperation.LESS,
-        target,
-        f"{name}_shape",
-    )
-    select_layer = network.add_select(condition_val, input_shape, scale_res)
-    set_layer_name(select_layer, target, f"{name}_select")
-    return select_layer.get_output(0)
-
-
-def add_binary_elementwise_layer(
-    network: TRTNetwork,
-    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    op_type: trt.ElementWiseOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    This function adds a TensorRT elementwise layer.
We allow both operands to be - constant (not a trt tensor) because in implicit batch dimension mode, we could - introduce constant via .size() op. Other scenario should be const folded first. - If any operand is not a trt tensor, we make it a trt constant layer while preserve - its dtype. Then we broadcast these two inputs to have the same number of dimensions. - - Limitation: - If we are using implicit batch dim mode, the operand that is not a trt - tensor are not allowed to have larger ranks than the trt tensor operand. - - Args: - network (TRTNetwork): TensorRT network object. - lhs_val (TRTTensor): Left operand of the binary operation. Could - be a TensorRT tensor, a PyTorch tensor or a simple value. - rhs_val (TRTTensor): Right operand of the binary operation. Similar - to lhs_val. - op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Elementwise layer. - """ - lhs_dtype = None - rhs_dtype = None - is_lhs_trt_tensor = False - is_rhs_trt_tensor = False - - if isinstance(lhs_val, TRTTensor): - lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) - is_lhs_trt_tensor = True - if isinstance(rhs_val, TRTTensor): - rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) - is_rhs_trt_tensor = True - - if not is_lhs_trt_tensor and not is_rhs_trt_tensor: - warnings.warn( - f"Both operands of the binary elementwise op {name} " - "are constant. In this case, please consider constant fold the model first." - ) - return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) - - # If the following conditions are true: - # 1. the network has implicit batch dimension, - # 2. one operand has shape [] (real shape is [batch_size]), - # 3. another operand is a scalar, - # then the result should also have shape [] (real shape is [batch_size]). - # - # In such case, we need to convert the scalar operand to tensor, because - # this way the shape will become [1], and then will be properly squeezed - # into [], meaning that the result will have shape [], which is what we - # expect. - # - # Note that the dtype here is supposed to be the same as the scalar - # dtype but we don't have a way to detect whether it makes sense for the - # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) - - # When lhs is scalar, and rhs has shape [1,], then currently the assert - # will fail because lhs shape has fewer dimensions than rhs shape. This - # happens when using implicit batch dimension, when we removed the 1st - # dimension from input tensor, causing it to have shape [] - a scalar. We - # fix it by reducing the rhs constant with a squeeze_left, so it becomes a - # scalar too. More generally, we squeeze_left on input if it's a constant - # tensor. This is safe because broadcast will pad dimensions on the left - # (prepend) to make lhs and rhs shape compatible. 
- if network.has_implicit_batch_dimension: - if isinstance(lhs_val, torch.Tensor): - lhs_val = squeeze_left(lhs_val) - if isinstance(rhs_val, torch.Tensor): - rhs_val = squeeze_left(rhs_val) - - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) - - # Check the limitation in the doc string. - if network.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - - lhs_val, rhs_val = broadcast( - network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" - ) - layer = network.add_elementwise(lhs_val, rhs_val, op_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return output - - def squeeze_left(const: torch.Tensor): """ Squeeze the size-1 dimensions on the left side of the shape tuple. @@ -528,80 +378,6 @@ def squeeze_left(const: torch.Tensor): return const -def add_unary_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.UnaryOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - Add a TensorRT Unary layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Unary layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" - ) - layer = network.add_unary(input_val, operation_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return layer.get_output(0) - - -def add_activation_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.ActivationType, - target: Target, - name: str, - alpha: Optional[Any] = None, - beta: Optional[Any] = None, -) -> TRTTensor: - """ - Add a TensorRT Activation layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the activation op. - Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT activation - operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - alpha (Optional[Any]): If not None, we will use it to set the alpha - attribute of the created TensorRT activation layer. - beta (Optional[Any]): If not None, we will use it to set the beta - attribute of the created TensorRT activation layer. - - Returns: - The output of TensorRT Activation layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - layer = network.add_activation(input_val, operation_type) - if alpha is not None: - layer.alpha = alpha - if beta is not None: - layer.beta = beta - set_layer_name(layer, target, name) - return layer.get_output(0) - - def add_reduce_layer( network: TRTNetwork, target: Target, @@ -706,139 +482,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names): return inputs -def sign( - network: TRTNetwork, input_val: TRTTensor, target: Target, name: str -) -> TRTTensor: - """ - Sign is calculated as below: - x = input - sign = (exp(x) // exp(abs(x))) * 2 - 1 - For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. - With multiply 2, the value become 2(for pos and 0) and 0(for neg). - Finally minus 1, the value become 1(for pos and 0) and -1(for neg). - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): The input tensor. - target (Target): fx node target. - name (str): Name of the fx node with optional suffix. - - Returns: - A TensorRT tensor represent the result of sign operator. - """ - input_exp_output = add_unary_layer( - network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp" - ) - input_abs_output = add_unary_layer( - network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs" - ) - input_abs_exp_output = add_unary_layer( - network, - input_abs_output, - trt.UnaryOperation.EXP, - target, - f"{name}_prod_abs_exp", - ) - floor_div_output = add_binary_elementwise_layer( - network, - input_exp_output, - input_abs_exp_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_exp_floor_div", - ) - double_floor_div_output = add_binary_elementwise_layer( - network, - floor_div_output, - 2, - trt.ElementWiseOperation.PROD, - target, - f"{name}_floor_div*2", - ) - return add_binary_elementwise_layer( - network, - double_floor_div_output, - 1, - trt.ElementWiseOperation.SUB, - target, - f"{name}_sign", - ) - - -def trunc_div( - input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str -) -> TRTTensor: - """ - Perform trunc divide on Tensor, result of divide will be round toward zero. - This means for positive number, it will be floor round; for negative number, - it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. - - Args: - input: divisor. - other: dividend. - network: INetworkDefinition. - target: node target. - name: namespace for the op - - Returns: - A TensorRT tensor represent the result of trunc divide. 
- """ - prod_output = add_binary_elementwise_layer( - network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" - ) - sign_output = sign(network, prod_output, target, name) - - # Convert constant input into ITensor for UnaryOperation - if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): - other = get_trt_tensor( - network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) - ) - - abs_input_output = add_unary_layer( - network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" - ) - abs_other_output = add_unary_layer( - network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" - ) - abs_floor_output = add_binary_elementwise_layer( - network, - abs_input_output, - abs_other_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_floor_div", - ) - output = add_binary_elementwise_layer( - network, - abs_floor_output, - sign_output, - trt.ElementWiseOperation.PROD, - target, - f"{name}_output", - ) - - return output - - -def get_python_op_from_trt_elementwise_op( - trt_op: TRTElementWiseOp, -) -> Callable[[Any, Any], Any]: - if trt_op == trt.ElementWiseOperation.SUM: - return operator.add - elif trt_op == trt.ElementWiseOperation.PROD: - return operator.mul - elif trt_op == trt.ElementWiseOperation.SUB: - return operator.sub - elif trt_op == trt.ElementWiseOperation.DIV: - return operator.truediv - elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: - return operator.floordiv - else: - raise RuntimeError(f"{trt_op} is not supported yet!") - - def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): @@ -868,27 +511,25 @@ def dtype_uniform( return input, other -def type_cast( - network: TRTNetwork, - target: Target, - name: str, - input: TRTTensor, - cast_type: TRTDataType, -): +def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: """ - This function helps to cast the input type to cast_type + Convert a PyTorch Tensor to a Numpy Array. If the tensor is + quantized it will be dequantized first. + + Args: + tensor (Optional[torch.Tensor]): A PyTorch tensor or None. + + Returns: + A Numpy array. 
""" - layer_i = network.add_identity(input) - layer_i.set_output_type(0, cast_type) - set_layer_name(layer_i, target, f"{name}_dtype_change") - return layer_i.get_output(0) + if tensor is None: + return tensor + + assert isinstance( + tensor, torch.Tensor + ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}" + if tensor.is_quantized: + tensor = tensor.dequantize() -def trt_dtype_to_torch_dtype(trt_dtype): - table = { - trt.bool: torch.bool, - trt.int32: torch.int32, - trt.float16: torch.float16, - trt.float32: torch.float32, - } - return table[trt_dtype] + return tensor.cpu().detach().contiguous().numpy() diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py new file mode 100644 index 0000000000..ffd6a1bab5 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -0,0 +1,1562 @@ +import numpy as np +import operator +import warnings +import logging +import math +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union + +import tensorrt as trt +import torch +from torch.fx.node import Argument, Target +from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt + +from ..tracer.acc_tracer import acc_ops + +from .converter_utils import get_trt_tensor +from .converter_utils import set_layer_name +from .converter_utils import get_trt_tensor +from .converter_utils import broadcast +from .converter_utils import broadcastable +from .converter_utils import squeeze_left +from .converter_utils import dtype_uniform +from .converter_utils import get_trt_plugin +from .converter_utils import get_positive_dim +from .converter_utils import prepend_ones +from .converter_utils import has_dynamic_shape +from .converter_utils import to_numpy + +from ..types import ( + Shape, + TRTDataType, + TRTElementWiseOp, + TRTLayer, + TRTNetwork, + TRTPlugin, + TRTPluginFieldCollection, + TRTTensor, +) + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + + +def add_binary_elementwise_layer( + network: TRTNetwork, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], + op_type: trt.ElementWiseOperation, + target: Target, + name: str, +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. 
+        rhs_val (TRTTensor): Right operand of the binary operation. Similar
+        to lhs_val.
+        op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+
+    Returns:
+        The output of TensorRT Elementwise layer.
+    """
+    lhs_dtype = None
+    rhs_dtype = None
+    is_lhs_trt_tensor = False
+    is_rhs_trt_tensor = False
+
+    if isinstance(lhs_val, TRTTensor):
+        lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
+        is_lhs_trt_tensor = True
+    if isinstance(rhs_val, TRTTensor):
+        rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
+        is_rhs_trt_tensor = True
+
+    if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
+        warnings.warn(
+            f"Both operands of the binary elementwise op {name} "
+            "are constant. In this case, please consider constant folding the model first."
+        )
+        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
+
+    # If the following conditions are true:
+    # 1. the network has implicit batch dimension,
+    # 2. one operand has shape [] (real shape is [batch_size]),
+    # 3. another operand is a scalar,
+    # then the result should also have shape [] (real shape is [batch_size]).
+    #
+    # In such a case, we need to convert the scalar operand to a tensor, because
+    # this way the shape will become [1], and then will be properly squeezed
+    # into [], meaning that the result will have shape [], which is what we
+    # expect.
+    #
+    # Note that the dtype here is supposed to be the same as the scalar
+    # dtype but we don't have a way to detect whether it makes sense for the
+    # scalar to be float or half. Hence we go with the lhs dtype.
+    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
+        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
+        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+
+    # When lhs is scalar, and rhs has shape [1,], then currently the assert
+    # will fail because lhs shape has fewer dimensions than rhs shape. This
+    # happens when using implicit batch dimension, when we removed the 1st
+    # dimension from input tensor, causing it to have shape [] - a scalar. We
+    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
+    # scalar too. More generally, we squeeze_left on input if it's a constant
+    # tensor. This is safe because broadcast will pad dimensions on the left
+    # (prepend) to make lhs and rhs shape compatible.
+    if network.has_implicit_batch_dimension:
+        if isinstance(lhs_val, torch.Tensor):
+            lhs_val = squeeze_left(lhs_val)
+        if isinstance(rhs_val, torch.Tensor):
+            rhs_val = squeeze_left(rhs_val)
+
+    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
+    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
+
+    # Check the limitation in the doc string.
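+    # (Editorial, illustrative) With implicit batch, a TRT tensor of shape
+    # (3,) really describes (batch_size, 3); a constant operand must not
+    # out-rank the TRT tensor, since broadcast() would have to left-pad onto
+    # the batch axis.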
+    if network.has_implicit_batch_dimension:
+        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
+            assert len(lhs_val.shape) >= len(
+                rhs_val.shape
+            ), f"{lhs_val.shape} >= {rhs_val.shape}"
+        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
+            assert len(rhs_val.shape) >= len(
+                lhs_val.shape
+            ), f"{rhs_val.shape} >= {lhs_val.shape}"
+
+    lhs_val, rhs_val = broadcast(
+        network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+    )
+    layer = network.add_elementwise(lhs_val, rhs_val, op_type)
+    set_layer_name(layer, target, name)
+    output = layer.get_output(0)
+    output.name = output.name + "_" + target.__name__
+    return output
+
+
+def sign(
+    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
+) -> TRTTensor:
+    """
+    Sign is calculated as below:
+       x = input
+       sign = (exp(x) // exp(abs(x))) * 2 - 1
+    For a positive number and 0, (exp(x) // exp(abs(x))) yields 1; for a negative number it yields 0.
+    Multiplying by 2 gives 2 (for pos and 0) and 0 (for neg).
+    Finally, subtracting 1 gives 1 (for pos and 0) and -1 (for neg).
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): The input tensor.
+        target (Target): fx node target.
+        name (str): Name of the fx node with optional suffix.
+
+    Returns:
+        A TensorRT tensor representing the result of the sign operator.
+    """
+    input_exp_output = add_unary_layer(
+        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
+    )
+    input_abs_output = add_unary_layer(
+        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
+    )
+    input_abs_exp_output = add_unary_layer(
+        network,
+        input_abs_output,
+        trt.UnaryOperation.EXP,
+        target,
+        f"{name}_prod_abs_exp",
+    )
+    floor_div_output = add_binary_elementwise_layer(
+        network,
+        input_exp_output,
+        input_abs_exp_output,
+        trt.ElementWiseOperation.FLOOR_DIV,
+        target,
+        f"{name}_exp_floor_div",
+    )
+    double_floor_div_output = add_binary_elementwise_layer(
+        network,
+        floor_div_output,
+        2,
+        trt.ElementWiseOperation.PROD,
+        target,
+        f"{name}_floor_div*2",
+    )
+    return add_binary_elementwise_layer(
+        network,
+        double_floor_div_output,
+        1,
+        trt.ElementWiseOperation.SUB,
+        target,
+        f"{name}_sign",
+    )
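+
+# Editorial note: a quick plain-Python check of the identity `sign` builds
+# above (illustrative only):
+#   (math.exp(3.0) // math.exp(3.0)) * 2 - 1 == 1    # positive -> 1
+#   (math.exp(0.0) // math.exp(0.0)) * 2 - 1 == 1    # zero -> 1
+#   (math.exp(-3.0) // math.exp(3.0)) * 2 - 1 == -1  # negative -> -1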
+ """ + prod_output = add_binary_elementwise_layer( + network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" + ) + sign_output = sign(network, prod_output, target, name) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) + ) + + abs_input_output = add_unary_layer( + network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" + ) + abs_other_output = add_unary_layer( + network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" + ) + abs_floor_output = add_binary_elementwise_layer( + network, + abs_input_output, + abs_other_output, + trt.ElementWiseOperation.FLOOR_DIV, + target, + f"{name}_floor_div", + ) + output = add_binary_elementwise_layer( + network, + abs_floor_output, + sign_output, + trt.ElementWiseOperation.PROD, + target, + f"{name}_output", + ) + + return output + + +def trt_dtype_to_torch_dtype(trt_dtype): + table = { + trt.bool: torch.bool, + trt.int32: torch.int32, + trt.float16: torch.float16, + trt.float32: torch.float32, + } + return table[trt_dtype] + + +def get_shape_with_dynamic_shape( + network: TRTNetwork, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, + target: Target, + name: str, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. 
+    Returns:
+        A TensorRT ITensor that represents the actual shape of the input_val
+    """
+    # Get the real shape info for input_val
+    input_shape = network.add_shape(input_val).get_output(0)
+
+    scale_layer = network.add_constant(
+        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
+    )
+    set_layer_name(scale_layer, target, f"{name}_scale")
+    scale_res = scale_layer.get_output(0)
+
+    length = input_shape.shape[0]
+    zero_layer = network.add_constant(
+        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
+    )
+    set_layer_name(zero_layer, target, f"{name}_zeros")
+
+    condition_val = add_binary_elementwise_layer(
+        network,
+        scale_res,
+        zero_layer.get_output(0),
+        trt.ElementWiseOperation.LESS,
+        target,
+        f"{name}_shape",
+    )
+    select_layer = network.add_select(condition_val, input_shape, scale_res)
+    set_layer_name(select_layer, target, f"{name}_select")
+    return select_layer.get_output(0)
+
+
+def type_cast(
+    network: TRTNetwork,
+    target: Target,
+    name: str,
+    input: TRTTensor,
+    cast_type: TRTDataType,
+):
+    """
+    This function helps to cast the input type to cast_type
+    """
+    layer_i = network.add_identity(input)
+    layer_i.set_output_type(0, cast_type)
+    set_layer_name(layer_i, target, f"{name}_dtype_change")
+    return layer_i.get_output(0)
+
+
+def add_tile(network, target, kwargs, name):
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
+
+    dims = tuple(cast(Sequence[int], kwargs["dims"]))
+    n_input_dims = len(input_val.shape) + (
+        1 if network.has_implicit_batch_dimension else 0
+    )
+
+    if len(dims) > n_input_dims:
+        assert not network.has_implicit_batch_dimension
+        layer = network.add_shuffle(input_val)
+        layer.name = f"{name}_reshape"
+        num_preceding_ones = len(dims) - n_input_dims
+
+        if len(get_dynamic_dims(input_val.shape)) > 1:
+            input_shape_layer = network.add_shape(input_val)
+            input_shape_layer.name = f"{name}_input_shape"
+            preceding_ones = network.add_constant(
+                (num_preceding_ones,),
+                np.ascontiguousarray([1] * num_preceding_ones, np.int32),
+            ).get_output(0)
+            reshape_layer = network.add_concatenation(
+                [preceding_ones, input_shape_layer.get_output(0)]
+            )
+            reshape_layer.axis = 0
+            reshape_layer.name = f"{name}_reshape_dims"
+            layer.set_input(1, reshape_layer.get_output(0))
+        else:
+            layer.reshape_dims = (1,) * (len(dims) - n_input_dims) + tuple(
+                input_val.shape
+            )
+        input_val = layer.get_output(0)
+    else:
+        dims = (1,) * (n_input_dims - len(dims)) + dims
+
+    if network.has_implicit_batch_dimension:
+        assert dims[0] == 1, "Can't tile the batch dim when it's implicit."
+        dims = dims[1:]
+    starts = [0] * len(dims)
+    shapes = []
+    if all(isinstance(d, int) for d in dims):
+        shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    else:
+        shape = []
+        for i, (s, d) in enumerate(zip(input_val.shape, dims)):
+            if isinstance(d, TRTTensor) and len(d.shape) == 0:
+                d = prepend_ones(network, d, f"{name}_{i}", 1)
+            else:
+                d = get_trt_tensor(network, d, f"{name}_{i}")
+            shape.append(d)
+            mul = add_binary_elementwise_layer(
+                network,
+                s,
+                d,
+                trt.ElementWiseOperation.PROD,
+                target,
+                f"{name}_mul_{i}",
+            )
+            shapes.append(mul)
+        dims = shape
+    # If there's a dynamic dim then shapes would contain negative dims, which is not allowed.
+    # Here we build a dummy shapes array instead.
+    if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
+        shapes = [1] * len(dims)
+    strides = [1] * len(dims)
+    layer = network.add_slice(input_val, starts, shapes, strides)
+    layer.mode = trt.SliceMode.WRAP
+    set_layer_name(layer, target, name)
+
+    if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
+        starts_tensor = network.add_constant(
+            (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
+        ).get_output(0)
+        if all(isinstance(d, int) for d in dims):
+            dims_tensor = network.add_constant(
+                (len(dims),), np.ascontiguousarray(dims, np.int32)
+            ).get_output(0)
+        else:
+            assert all(isinstance(d, TRTTensor) for d in dims)
+            concat_dims_layer = network.add_concatenation(inputs=dims)
+            concat_dims_layer.axis = 0
+            concat_dims_layer.name = f"{name}_tile_dim"
+            dims_tensor = concat_dims_layer.get_output(0)
+        input_shape_layer = network.add_shape(input_val)
+        input_shape_layer.name = f"{name}_slice_input_shape"
+        slice_shapes_tensor = add_binary_elementwise_layer(
+            network,
+            input_shape_layer.get_output(0),
+            dims_tensor,
+            trt.ElementWiseOperation.PROD,
+            target,
+            f"{name}_slice_shapes",
+        )
+        layer.set_input(1, starts_tensor)
+        layer.set_input(2, slice_shapes_tensor)
+
+    return layer.get_output(0)
+
+
+def add_linear(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"Linear received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dynamic_dims = get_dynamic_dims(input_val.shape)
+    assert len(dynamic_dims) < 2 and input_val.shape[-1] != -1, (
+        "Currently we only support one dynamic "
+        "dim for linear and it can't be the last dim."
+    )
+
+    if isinstance(kwargs["weight"], torch.Tensor):
+        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
+        if target not in (acc_ops.linear, torch.ops.aten.linear):
+            weight_op = trt.MatrixOperation.TRANSPOSE
+        else:
+            weight_op = trt.MatrixOperation.NONE
+    else:
+        assert isinstance(
+            kwargs["weight"], TRTTensor
+        ), f"Expect weight to be trt tensor but got {type(kwargs['weight'])}"
+        weight = kwargs["weight"]
+        weight_op = trt.MatrixOperation.TRANSPOSE
+
+    preset_diff = 0
+    if len(input_val.shape) == 1:
+        preset_diff -= 1
+        input_op = trt.MatrixOperation.VECTOR
+    else:
+        input_op = trt.MatrixOperation.NONE
+
+    input_val, weight = broadcast(
+        network, input_val, weight, f"{name}_input", f"{name}_weight", preset_diff
+    )
+    matmul_layer = network.add_matrix_multiply(input_val, input_op, weight, weight_op)
+    set_layer_name(matmul_layer, target, f"{name}_matmul")
+    res = matmul_layer.get_output(0)
+
+    if kwargs["bias"] is not None:
+        bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias")  # type: ignore[arg-type]
+        res = add_binary_elementwise_layer(
+            network,
+            matmul_layer.get_output(0),
+            bias,
+            trt.ElementWiseOperation.SUM,
+            target,
+            f"{name}_add",
+        )
+    return res
+
+
+def add_unary_layer(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    operation_type: trt.UnaryOperation,
+    target: Target,
+    name: str,
+) -> TRTTensor:
+    """
+    Add a TensorRT Unary layer to `network`.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
+        operation_type (trt.UnaryOperation): Type of the TensorRT unary operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+
+    Returns:
+        The output of TensorRT Unary layer.
+ """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) + + +def layer_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_val = kwargs["input"] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"LayerNorm received input {input_val} that is not part " + "of the TensorRT region!" + ) + + shape = kwargs["weight"].shape # type: ignore[union-attr] + broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape + gamma = to_numpy(kwargs["weight"].reshape(*shape)) # type: ignore[union-attr] + beta = to_numpy(kwargs["bias"].reshape(*shape)) # type: ignore[union-attr] + eps = kwargs["eps"] + + axes = 0 + for d in range(len(shape)): + axes |= 1 << (len(input_val.shape) - d - 1) + + # E[x] + mean_expected_layer = network.add_reduce( + input_val, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") + + # X-E[x] + sub_trt = add_binary_elementwise_layer( + network, + input_val, + mean_expected_layer.get_output(0), + trt.ElementWiseOperation.SUB, + target, + f"{name}_sub", + ) + # Variance = mean(pow(x_sub_mean,2)) + pow_tensor = network.add_constant( + (1,) * len(input_val.shape), + trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + ) + pow_tensor.name = f"{name}_power" + pow_var = add_binary_elementwise_layer( + network, + sub_trt, + pow_tensor.get_output(0), + trt.ElementWiseOperation.POW, + target, + f"{name}_pow_var", + ) + mean_trt_layer = network.add_reduce( + pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_trt_layer, target, f"{name}_mean") + # Variance + eps + eps_tensor = network.add_constant( + (1,) * len(input_val.shape), + trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + ) + eps_tensor.name = f"{name}_eps" + add_trt = add_binary_elementwise_layer( + network, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), + trt.ElementWiseOperation.SUM, + target, + f"{name}_add", + ) + # SQRT((Var + eps)) + sqrt_trt = add_unary_layer( + network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" + ) + # (x - E[x]) / sqrt((var + eps)) + div_trt = add_binary_elementwise_layer( + network, + sub_trt, + sqrt_trt, + trt.ElementWiseOperation.DIV, + target, + f"{name}_div_trt", + ) + + assert gamma is not None + gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] + gamma_tensor.name = f"{name}_gamma" + assert beta is not None + beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] + beta_tensor.name = f"{name}_beta" + # y * gamma + beta + scale_layer = add_binary_elementwise_layer( + network, + div_trt, + gamma_tensor.get_output(0), + trt.ElementWiseOperation.PROD, + target, + f"{name}_scale", + ) + return add_binary_elementwise_layer( + network, + scale_layer, + beta_tensor.get_output(0), + trt.ElementWiseOperation.SUM, + target, + name, + ) + + +def add_add(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + 
+
+
+def add_add(network, target, kwargs, name):
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        kwargs["other"],
+        trt.ElementWiseOperation.SUM,
+        target,
+        name,
+    )
+
+
+def add_matmul(network, target, kwargs, name):
+    input_val = get_trt_tensor(network, kwargs["input"], f"{name}_input")
+    other_val = get_trt_tensor(network, kwargs["other"], f"{name}_other")
+
+    for i in [input_val, other_val]:
+        if not isinstance(i, TRTTensor):
+            raise RuntimeError(
+                f"matmul received input {i} that is not part of the TensorRT region!"
+            )
+
+    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
+    preset_diff = 0
+
+    if len(input_val.shape) == 1:
+        preset_diff -= 1
+        input_matrix_op = trt.MatrixOperation.VECTOR
+
+    if len(other_val.shape) == 1:
+        preset_diff += 1
+        other_matrix_op = trt.MatrixOperation.VECTOR
+
+    input_val, other_val = broadcast(
+        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
+    )
+    layer = network.add_matrix_multiply(
+        input_val, input_matrix_op, other_val, other_matrix_op
+    )
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_layer_norm(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"LayerNorm received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    gamma = kwargs["weight"].detach().cpu().float().numpy()
+    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
+    beta = kwargs["bias"].detach().cpu().float().numpy()
+    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
+    eps_field = trt.PluginField(
+        "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
+    )
+    try:
+        normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
+    except TypeError:
+        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
+        normalized_shape = np.array([], dtype=np.int32)
+
+    normalized_shape_field = trt.PluginField(
+        "normalized_shape", normalized_shape, trt.PluginFieldType.INT32
+    )
+    field_collection = trt.PluginFieldCollection(
+        [gamma_field, beta_field, eps_field, normalized_shape_field]
+    )
+
+    try:
+        if network.has_implicit_batch_dimension:
+            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
+        else:
+            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
+    except AssertionError:
+        _LOGGER.error(
+            "Unable to find layer norm plugin, fall back to TensorRT implementation."
+        )
+        args = []
+        return layer_norm(network, target, args, kwargs, name)
+    layer = network.add_plugin_v2([input_val], plugin)
+    layer.name = name
+    return layer.get_output(0)
+
+
+def add_cumsum(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    dim = cast(int, kwargs["dim"])
+    input_shape = input_val.shape  # type: ignore[union-attr]
+    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"cumsum received input {input_val} that is not part "
+            "of the TensorRT region!"
+ ) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "cumsum converter currently doesn't support implicit batch dimension" + ) + dim = get_positive_dim(dim, input_dim_size) + loop = network.add_loop() + trip_limit = None + if input_shape[dim] > 0: + axis = torch.tensor(input_shape[dim], dtype=torch.int32) + trip_limit_layer = network.add_constant(axis.shape, to_numpy(axis)) + else: + input_shape = network.add_shape(input_val).get_output(0) + dim_value = torch.tensor(dim, dtype=torch.int32) + axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0) + trip_limit_layer = network.add_gather(input_shape, axis, 0) + set_layer_name(trip_limit_layer, target, f"{name}_trip_limit") + trip_limit = trip_limit_layer.get_output(0) + + loop.add_trip_limit(trip_limit, trt.TripLimit(0)) + iterator = loop.add_iterator(input_val, dim, False) + data = iterator.get_output(0) + new_dims = tuple(data.shape) + zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype)) + zero_tensor = network.add_constant( + zero_tensor.shape, to_numpy(zero_tensor) + ).get_output(0) + + running_sum = loop.add_recurrence(zero_tensor) + set_layer_name(running_sum, target, f"{name}_running_sum_1") + running_sum_tensor = running_sum.get_output(0) + + current_sum = add_binary_elementwise_layer( + network, + data, + running_sum_tensor, + trt.ElementWiseOperation.SUM, + target, + f"{name}_sum_1", + ) + running_sum.set_input(1, current_sum) + + running_sum = loop.add_recurrence(zero_tensor) + set_layer_name(running_sum, target, f"{name}_running_sum_2") + running_sum_tensor = running_sum.get_output(0) + + current_sum = add_binary_elementwise_layer( + network, + data, + running_sum_tensor, + trt.ElementWiseOperation.SUM, + target, + f"{name}_sum_2", + ) + running_sum.set_input(1, current_sum) + + loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim) + set_layer_name(loop_output, target, f"{name}_loop_output") + loop_output.set_input(1, trip_limit) + return loop_output.get_output(0) + + +def add_maximum(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.MAX, + target, + name, + ) + + +def add_mul(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.PROD, + target, + name, + ) + + +def add_pow(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["exponent"], + trt.ElementWiseOperation.POW, + target, + name, + ) + + +def add_floor_div(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.FLOOR_DIV, + target, + name, + ) + + +def add_div(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.DIV, + target, + name, + ) + + +def add_sub(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.SUB, + target, + name, + ) + + +def add_minimum(network, target, kwargs, name): + return add_binary_elementwise_layer( + network, + kwargs["input"], + kwargs["other"], + trt.ElementWiseOperation.MIN, + target, + name, + ) + + +def add_logical_not(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.NOT + # 
cast to bool type
+    if input_val.dtype in (trt.float32, trt.float16, trt.int32):
+        input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool)
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
+def add_logical_and(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `logical_and` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    # we only support the case where both inputs are of bool type
+    if target == acc_ops.bitwise_and:
+
+        def check_is_bool(input_t):
+            if isinstance(input_t, TRTTensor):
+                assert (
+                    input_t.dtype == trt.bool
+                ), "We currently do not support non-bool inputs"
+            elif isinstance(input_t, torch.Tensor):
+                assert (
+                    input_t.dtype == torch.bool
+                ), "We currently do not support non-bool inputs"
+            else:
+                assert isinstance(
+                    input_t, bool
+                ), "We currently do not support non-bool inputs"
+
+        check_is_bool(input_t)
+        check_is_bool(other_t)
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    if input_t.dtype != trt.bool:
+        input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool)
+    if other_t.dtype != trt.bool:
+        other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.AND, target, name
+    )
+
+
+def add_ne(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `ne` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
+    eq_t = add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
+    )
+
+    return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name)
+
+
+def add_eq(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `eq` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name
+    )
+
+
+def add_gt(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `gt` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name
+    )
+
+
+def add_lt(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `lt` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+
+    input_t = get_trt_tensor(network, input_t, f"{name}_input_t")
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+
+    input_t, other_t = dtype_uniform(network, target, name, input_t, other_t)
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name
+    )
+
+
+def add_logical_or(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `logical_or` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    if isinstance(other_t, (torch.Tensor, bool)):
+        if isinstance(other_t, bool):
+            other_t = int(other_t)
+        elif other_t.dtype == torch.bool:
+            other_t = other_t.to(torch.int32)
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+    if input_t.dtype != trt.bool:
+        layer_i = network.add_identity(input_t)
+        layer_i.set_output_type(0, trt.bool)
+        set_layer_name(layer_i, target, f"{name}_input_dtype_change")
+        input_t = layer_i.get_output(0)
+    if other_t.dtype != trt.bool:
+        layer_o = network.add_identity(other_t)
+        layer_o.set_output_type(0, trt.bool)
+        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
+        other_t = layer_o.get_output(0)
+
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.OR, target, name
+    )
+
+
+def add_logical_xor(network, target, kwargs, name):
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "The `logical_xor` function should be called with explicit batch dimension."
+        )
+
+    input_t = kwargs["input"]
+    other_t = kwargs["other"]
+    if isinstance(other_t, (torch.Tensor, bool)):
+        if isinstance(other_t, bool):
+            other_t = int(other_t)
+        elif other_t.dtype == torch.bool:
+            other_t = other_t.to(torch.int32)
+    other_t = get_trt_tensor(network, other_t, f"{name}_other_t")
+    if input_t.dtype != trt.bool:
+        layer_i = network.add_identity(input_t)
+        layer_i.set_output_type(0, trt.bool)
+        set_layer_name(layer_i, target, f"{name}_input_dtype_change")
+        input_t = layer_i.get_output(0)
+    if other_t.dtype != trt.bool:
+        layer_o = network.add_identity(other_t)
+        layer_o.set_output_type(0, trt.bool)
+        set_layer_name(layer_o, target, f"{name}_other_dtype_change")
+        other_t = layer_o.get_output(0)
+
+    return add_binary_elementwise_layer(
+        network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name
+    )
+
+
+def add_fmod(network, target, kwargs, name):
+    # NOTE: TRT doesn't currently implement fmod, so we need multiple operations to perform it
+    trunc_div_value = trunc_div(
+        kwargs["input"], kwargs["other"], network, target, name + "_trunc_div"
+    )
+    prod_value = add_binary_elementwise_layer(
+        network,
+        trunc_div_value,
+        kwargs["other"],
+        trt.ElementWiseOperation.PROD,
+        target,
+        name + "_prod",
+    )
+    sub_value = add_binary_elementwise_layer(
+        network,
+        kwargs["input"],
+        prod_value,
+        trt.ElementWiseOperation.SUB,
+        target,
+        name + "_sub",
+    )
+    return sub_value
+
+
+def add_trunc_div(network, target, kwargs, name):
+    return trunc_div(kwargs["input"], kwargs["other"], network, target, name)
+
+
+def add_expand(network, target, kwargs, name):
+    input_t = kwargs["input"]
+    shape = list(kwargs["sizes"])
+
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
+
+    if network.has_implicit_batch_dimension:
+        shape = shape[1:]
+
+    ranks = len(input_val.shape)
+    # TRT's slice layer requires the input and the target shape to have the
+    # same rank; the ranks can differ e.g. in the bmm case, where input_t and
+    # shape do not have the same number of dimensions, so broadcast first
+    assert len(shape) >= ranks
+    if len(shape) != ranks:
+        shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
+        input_val, shape_tensor = broadcast(
+            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+        )
+        ranks = len(shape)
+
+    shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
+    inshape = tuple(input_val.shape)
+    shape = tuple(shape)
+    start = tuple([0] * ranks)
+    stride = tuple(
+        [int(i == o) for i, o in zip(inshape, shape)]
+    )  # stride == 1 if dimensions match, 0 otherwise
+    layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_select(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"select received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        if dim == 0:
+            raise RuntimeError(
+                f"We do not support select at batch dim when it's implicit, got {dim}!"
+            )
+        dim = dim - 1
+    else:
+        if dynamic_shape:
+            # Check whether the select target dim is a dynamic shape dim
+            assert (
+                input_val.shape[dim] != -1
+            ), "Can't select on dynamic shape dimension!"
+    index = kwargs["index"]
+
+    if index >= input_val.shape[dim]:
+        raise RuntimeError(
+            f"cannot have index greater than the dimension length! {input_val.shape[dim]}"
+        )
+    output_shape = list(input_val.shape)
+    output_shape[dim] = 1
+    if dynamic_shape > 0:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    index_value = torch.tensor(index, dtype=torch.int32)
+    indices_tensor = network.add_constant(
+        index_value.shape, to_numpy(index_value)
+    ).get_output(0)
+    layer = network.add_gather(input_val, indices_tensor, dim)
+    out = layer.get_output(0)
+    if len(out.shape) != 1:
+        layer = network.add_shuffle(out)
+    return layer.get_output(0)
+
+
+def add_slice(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"slice_tensor received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        if dim == 0:
+            raise RuntimeError(
+                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
+            )
+        dim = dim - 1
+    else:
+        if dynamic_shape:
+            # Check whether the slice target dim is a dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
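+
+    # Editorial note (illustrative): e.g. dim=1, start=2, stop=8, step=3 on an
+    # input of shape (4, 10) gives start=[0, 2], stride=[1, 3] and
+    # output_shape=[4, ceil((8 - 2) / 3)] = [4, 2] below.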
+
+    start_int = cast(int, kwargs["start"])
+    stop_int = cast(int, kwargs["stop"])
+    step_int = cast(int, kwargs["step"])
+    start = [0] * len(input_val.shape)
+    start[dim] = start_int
+    stride = [1] * len(start)
+    stride[dim] = step_int
+    output_shape = list(input_val.shape)
+    output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
+
+    if dynamic_shape > 0:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_softmax(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if kwargs["dim"] is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, kwargs["dim"])
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input_val)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_squeeze(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    dims = []
+    if "dim" in kwargs:
+        if isinstance(kwargs["dim"], int):
+            dims.append(cast(Optional[int], kwargs["dim"]))
+        else:
+            for dim in kwargs["dim"]:
+                dims.append(cast(Optional[int], dim))
+
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert len(dims) != 0, "We don't support dim=None right now for squeeze."
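+
+    # Editorial note (illustrative): e.g. an input of shape (2, 1, 3, 1) with
+    # dims=[1, 3] reshapes to (2, 3); dims whose runtime size is not 1 are
+    # left untouched by the size check below.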
+
+
+def add_squeeze(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    dims = []
+    if "dim" in kwargs:
+        if isinstance(kwargs["dim"], int):
+            dims.append(cast(Optional[int], kwargs["dim"]))
+        else:
+            for dim in kwargs["dim"]:
+                dims.append(cast(Optional[int], dim))
+
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert len(dims) > 0, "We don't support dim=None right now for squeeze."
+
+    # Normalize the dims up front so the `i in new_dims` check below uses the
+    # same indexing as `input_val.shape`.
+    new_dims = []
+    for dim in dims:
+        dim = get_positive_dim(
+            dim,
+            len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0),
+        )
+        if network.has_implicit_batch_dimension:
+            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+            dim -= 1
+
+        assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
+        assert (
+            len(get_dynamic_dims(input_val.shape)) <= 1
+        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        new_dims.append(dim)
+
+    output_shape = []
+    for i, s in enumerate(input_val.shape):
+        if (i in new_dims) and s == 1:
+            continue
+        output_shape.append(s)
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_chunk(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    chunks = cast(int, kwargs["chunks"])
+    dim = cast(int, kwargs["dim"])
+    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"chunk received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        input_dim_size += 1
+        dim = get_positive_dim(dim, input_dim_size)
+        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
+        dim -= 1
+    else:
+        if dynamic_shape:
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+        dim = get_positive_dim(dim, input_dim_size)
+
+    if chunks > input_val.shape[dim]:
+        warnings.warn(
+            f"Asked for {chunks} chunks along dimension "
+            f"{dim} on tensor with size {input_val.shape}, chunks "
+            f"will default to {input_val.shape[dim]}",
+            RuntimeWarning,
+        )
+        chunks = input_val.shape[dim]
+
+    start = [0] * len(input_val.shape)
+    stride = [1] * len(start)
+    offset = 0
+    split_size = (input_val.shape[dim] + chunks - 1) // chunks
+
+    max_offset = input_val.shape[dim]
+    # add slice layers
+    output = []
+    for i in range(chunks):
+        shape = list(input_val.shape)
+        shape[dim] = min(split_size, max_offset - offset)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_{i}"
+            )
+        start[dim] = offset
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
+        offset += split_size
+        set_layer_name(layer, target, f"{name}_{i}")
+        output.append(layer.get_output(0))
+    return output
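+
+
+# Editor's note (illustrative only): add_chunk mirrors torch.chunk sizing,
+# split_size = ceil(dim_size / chunks), with a shorter trailing chunk. E.g.
+# dim_size=10, chunks=3 gives slice lengths 4, 4, 2.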
+
+
+def add_where(network, target, kwargs, name):
+    condition_t = kwargs["condition"]
+    x_t = kwargs["x"]
+    y_t = kwargs["y"]
+
+    if type(x_t) != TRTTensor:
+        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
+
+    if type(y_t) != TRTTensor:
+        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
+
+    assert broadcastable(x_t, y_t), "The two torch tensors should be broadcastable"
+
+    # get output shape
+    # purpose of this is to bring x_t and y_t rank same as
+    # output_shape to input it to the add_expand operation
+    # condition_t will have dimension of either x_t or y_t
+    x_t, y_t = broadcast(network, x_t, y_t, f"{name}_x", f"{name}_y")
+    if len(tuple(condition_t.shape)) != len(tuple(x_t.shape)):
+        condition_t, x_t = broadcast(
+            network, condition_t, x_t, f"{name}_condition", f"{name}_x"
+        )
+
+    x_shape = list(x_t.shape)
+    y_shape = list(y_t.shape)
+    condition_shape = list(condition_t.shape)
+    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
+
+    # expand shape
+    if type(condition_t) != TRTTensor:
+        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
+        if condition_shape != output_shape:
+            condition_t = condition_t.expand(output_shape)
+        condition_t = condition_t.to(torch.int32)
+        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
+        condition_layer = network.add_identity(condition_const)
+        condition_layer.set_output_type(0, trt.bool)
+        set_layer_name(condition_layer, target, f"{name}_condition")
+        condition_val = condition_layer.get_output(0)
+    else:
+        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
+        if condition_shape != output_shape:
+            condition_val = add_expand(
+                network,
+                target,
+                {"input": condition_t, "sizes": output_shape},
+                name=f"{name}_expand",
+            )
+        else:
+            condition_val = condition_t
+
+    if type(x_t) != TRTTensor:
+        if x_shape != output_shape:
+            # special case where 1 element in x_t
+            if len(x_t.shape) == 0:
+                x_t = x_t.unsqueeze(0)
+            x_t = x_t.expand(output_shape)
+        x_val = get_trt_tensor(network, x_t, f"{name}_x")
+    else:
+        x_val = x_t
+        if x_shape != output_shape:
+            x_val = add_expand(
+                network,
+                target,
+                {"input": x_val, "sizes": output_shape},
+                name=f"{name}_x_expand",
+            )
+
+    if type(y_t) != TRTTensor:
+        if y_shape != output_shape:
+            # special case where 1 element in y_t
+            if len(y_t.shape) == 0:
+                y_t = y_t.unsqueeze(0)
+            y_t = y_t.expand(output_shape)
+        y_val = get_trt_tensor(network, y_t, f"{name}_y")
+    else:
+        y_val = y_t
+        if y_shape != output_shape:
+            y_val = add_expand(
+                network,
+                target,
+                {"input": y_val, "sizes": output_shape},
+                name=f"{name}_y_expand",
+            )
+
+    select_layer = network.add_select(condition_val, x_val, y_val)
+
+    set_layer_name(select_layer, target, f"{name}_select")
+
+    return select_layer.get_output(0)
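+
+
+# Editor's note (illustrative only): add_where broadcasts all three operands
+# to a common shape before ISelectLayer, mirroring torch.where. E.g. condition
+# (2, 2), x (2, 2), y (2, 1) all expand to (2, 2); test_where_aten.py below
+# exercises exactly these shapes.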
+
+
+def add_scale(network, target, kwargs, name):
+    other = kwargs["other"]
+    scale = kwargs["scale"]
+    is_other_trt_tensor = isinstance(other, TRTTensor)
+
+    if not is_other_trt_tensor:
+        warnings.warn(
+            "The value to be scaled is a constant. In this case, please "
+            "consider constant-folding the model first.",
+        )
+        return other * scale
+    # shift/scale/power must be Weights-compatible (numpy arrays), not Python scalars.
+    layer = network.add_scale(
+        other,
+        trt.ScaleMode.UNIFORM,
+        np.array([0.0], dtype=np.float32),
+        np.array([scale], dtype=np.float32),
+        np.array([1.0], dtype=np.float32),
+    )
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_rsub(network, target, kwargs, name):
+    # torch.rsub(input, other, alpha) computes other - alpha * input.
+    if "alpha" in kwargs:
+        scaled_tensor = add_mul(
+            network,
+            target,
+            {"input": kwargs["input"], "other": kwargs["alpha"]},
+            name + "_mul",
+        )
+    else:
+        scaled_tensor = kwargs["input"]
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["other"],
+        scaled_tensor,
+        trt.ElementWiseOperation.SUB,
+        target,
+        name + "_sub",
+    )
+
+
+def add_sqrt(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SQRT
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
+def add_rsqrt(network, target, kwargs, name):
+    sqrt_trt = add_sqrt(network, target, kwargs, name)
+    return add_binary_elementwise_layer(
+        network,
+        1,
+        sqrt_trt,
+        trt.ElementWiseOperation.DIV,
+        target,
+        f"{name}_div_trt",
+    )
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
index d265def896..c9b4776dd3 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
@@ -12,7 +12,12 @@
         def forward(self, x):
             return x.squeeze(2)
 
         inputs = [torch.randn(1, 2, 1)]
-        self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze})
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops={acc_ops.squeeze},
+            test_implicit_batch_dim=False,
+        )
 
     # Testing with shape=(-1, -1, -1, -1) results in error:
     # AssertionError: We don't support squeeze dynamic dim.
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py
new file mode 100644
index 0000000000..cd8ef1b48a
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestELUConverter(DispatchTestCase):
+    def test_elu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_elu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_elu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py
new file mode 100644
index 0000000000..bcc9cbb761
--- /dev/null
+++ 
b/py/torch_tensorrt/fx/test/converters/aten_op/test_gelu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestGeLUConverter(DispatchTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default}) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py new file mode 100644 index 0000000000..644c09345f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_hard_tanh_aten.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestHardTanHConverter(DispatchTestCase): + def test_hardtanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py new file mode 100644 index 0000000000..fab398ac0f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -0,0 +1,44 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class 
TestLayerNormConverter(DispatchTestCase):
+    def test_layer_norm(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
+        )
+
+    def test_layernorm_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py
new file mode 100644
index 0000000000..7cdce77092
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLeakyReLUConverter(DispatchTestCase):
+    def test_leaky_relu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
new file mode 100644
index 0000000000..e0dc05fded
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py
@@ -0,0 +1,114 @@
+import unittest
+
+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
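+# Editor's note (illustrative only): torch.matmul decomposes to different aten
+# ops depending on operand ranks (2D x 2D -> aten.mm, matrix x vector ->
+# aten.mv, batched -> aten.bmm), which is why several cases below are disabled
+# with FIXMEs -- expected_ops={torch.ops.aten.mm.default} only holds for the
+# plain 2D cases.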
+class TestMatMulConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("2_2", (2, 3), (3, 1)),
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.default
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
+        ]
+    )
+    def test_matmul_other_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.matmul(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=(len(input_shape) >= 1),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("1_2", (1, 3), (3, 2)),
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.default
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
+        ]
+    )
+    def test_matmul_input_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input = nn.Parameter(torch.randn(*input_shape))
+
+            def forward(self, other):
+                return torch.matmul(self.input, other)
+
+        inputs = [torch.randn(*other_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=True,
+            # test_explicit_batch_dim=(len(other_shape) <= 2),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            # ("2_3", (2, 3), (2, 3, 4)),
+            # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)),
+            # ("4_2", (2, 1, 2, 3), (3, 2)),
+            # ("2_1", (2, 3), (3,)),
+            # ("1_2", (3,), (3, 2)),
+            # ("1_1", (3,), (3,)),
+        ]
+    )
+    def test_matmul(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def forward(self, input, other):
+                return torch.matmul(input, other)
+
+        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
+        test_explicit_batch_dim = not (
+            input_shape[0] == other_shape[0]
+            and len(input_shape) > 2
+            and len(other_shape) > 2
+        )
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=test_explicit_batch_dim,
+        )
+
+    # FIXME: dynamic shape is giving bmm
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
new file mode 100644
index 0000000000..c80216654c
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSqrtConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsqrt(self, _, x, alpha):
+        class rsqrt(nn.Module):
+            def forward(self, input):
+                return torch.rsqrt(input)
+
+        inputs = [torch.randn(x) + 1]
+        self.run_test(
+            rsqrt(),
+            inputs,
+            expected_ops={torch.ops.aten.rsqrt.default},
+        )
+
+
+if __name__ == "__main__":
+    
run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py new file mode 100644 index 0000000000..268df8ccfd --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsub(self, _, x, alpha): + class rsub(nn.Module): + def forward(self, input): + return torch.rsub(input, input, alpha=alpha) + + inputs = [torch.randn(x)] + self.run_test( + rsub(), + inputs, + expected_ops={torch.ops.aten.rsub.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py new file mode 100644 index 0000000000..1d5cb84f31 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py @@ -0,0 +1,83 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverterOne(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(1, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + test_explicit_precision=True, + ) + + +class TestSelectConverterTwo(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(4, 4, 4, 4)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + test_explicit_precision=True, + ) + + +class TestSelectConverterWithDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select_with_dynamic_shape(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input_spec = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_spec, expected_ops={torch.ops.aten.select.int} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py new file mode 100644 index 0000000000..a6e501daa0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from 
torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSeLUConverter(DispatchTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_selu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py new file mode 100644 index 0000000000..d9557d271b --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSigmoidConverter(DispatchTestCase): + def test_sigmoid(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py new file mode 100644 index 0000000000..6ddc082657 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_slice_aten.py @@ -0,0 +1,89 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 0, 
0, 7, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 2, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step_exact", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + test_explicit_precision=True, + ) + + +class TestSelectConverterDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input_specs = [ + InputTensorSpec( + shape=(1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py new file mode 100644 index 0000000000..31e293fc91 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py @@ -0,0 +1,44 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSoftMaxConverter(DispatchTestCase): + def test_softmax(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(1) + + def forward(self, x): + return self.softmax(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default} + ) + + def test_softmax_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(2) + + def forward(self, x): + return self.softmax(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py new file mode 100644 index 0000000000..5c655422de --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests 
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSqueezeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (0), (2, 1)),
+            ("3d_one_dim", (0), (2, 2, 1)),
+            ("3d_two_dim", (0, 1), (2, 1, 1)),
+            ("4d_dim", (0, 1, 2), (2, 2, 1, 1)),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        inputs = [torch.randn(*init_size)]
+        if isinstance(dim, int):
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops=expected_op,
+        )
+
+
+class TestSqueezeConverterDynamicShape(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
+            ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+            # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]),
+        ]
+    )
+    def test_squeeze(self, _, dim, init_size, shape_range):
+        class Squeeze(nn.Module):
+            def forward(self, x):
+                return torch.squeeze(x, dim)
+
+        if isinstance(dim, int):
+            expected_op = {torch.ops.aten.squeeze.dim}
+        else:
+            expected_op = {torch.ops.aten.squeeze.dims}
+        input_specs = [
+            InputTensorSpec(
+                shape=init_size,
+                dtype=torch.float32,
+                shape_ranges=shape_range,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Squeeze(),
+            input_specs,
+            expected_ops=expected_op,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py
new file mode 100644
index 0000000000..581a5e589f
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestTanhConverter(DispatchTestCase):
+    def test_tanh(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.tanh(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default})
+
+    def test_tanh_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.tanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default}
+        )
+
+    def test_tanh_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.tanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
new file mode 100644
index 0000000000..0d4849c21f
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from 
torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), + ] + ) + def test_(self, _, x_size, y_size): + class Where(nn.Module): + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 + self.run_test( + Where(), + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, + ) + + +# FIXME: How to specify condition for dynamic shape +# InputTensorSpec like case below where one input is dynamic another is not +# class TestWhereConverter(DispatchTestCase): +# @parameterized.expand( +# [ +# ("2d_dim", (-1, 2), [((1, 2), (2, 2), (2, 2))], (2,2)) +# #("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), +# ] +# ) +# def test_where(self, _, x_size, x_size_range, y_size): +# class Where(nn.Module): +# def forward(self, condition, x, y): +# return torch.where(condition, x, y) +# inputX = InputTensorSpec( +# shape=x_size, +# dtype=torch.float32, +# shape_ranges=x_size_range, +# ) +# inputOther = torch.randn(*y_size) +# condition = (inputOther < 0) +# input_specs = [ +# inputX, inputOther, condition +# ] +# self.run_test_with_dynamic_shape( +# Where(), +# input_specs, +# expected_ops=torch.ops.aten.where.self, +# ) + +# if __name__ == "__main__": +# run_tests() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py new file mode 100644 index 0000000000..e53f0bc64e --- /dev/null +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -0,0 +1,114 @@ +import torch +import traceback +import torch._dynamo as td + +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt +from torch_tensorrt.fx.tools.trt_splitter import ( + TRTSplitter, + TRTSplitterSetting, +) +from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +from torch._dynamo.backends.common import fake_tensor_unsupported + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + +from torch._inductor.decomposition import decompositions + +DECOMPOSITIONS = decompositions.copy() +MAX_SPLITS_THRESHOLD = 100 + + +def tensorrt_backend(gm, sample_inputs): + # Invoke AOTAutograd to compile model + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(fx2trt_compiler), + decompositions=DECOMPOSITIONS, + ) + + +def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): + model = gm + inputs = example_inputs + + # Perform lowering pass on model + model = aten_tracer.opt_trace(model, inputs, perform_trace=False) + + # Split out unsupported ops --> Needs rewrite/revision for ATEN + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + + splitter.node_support_preview() + split_mod = splitter() + num_pieces = 0 + + for name, _ in split_mod.named_children(): + 
print(f"Graph is split into {name}") + num_pieces += 1 + + # Select threshold above which segmentation is not beneficial and run graph in Torch + if num_pieces > MAX_SPLITS_THRESHOLD: + raise AssertionError( + f"The graph module is split into {num_pieces} which is large than the \ + threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." + ) + + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + logger_level=trt.Logger.VERBOSE, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, + ) + + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + + return split_mod + + +@td.register_backend +@fake_tensor_unsupported +def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): + try: + trt_compiled = fx2trt(gm, example_inputs) + return trt_compiled + except Exception: + traceback.print_exc() + print( + "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead" + ) + return gm.forward