fx2trt converters - change of prototype and addition of activation operation #1745

Closed
wants to merge 51 commits
Changes from all commits
51 commits
840de6b
fx2trt converters - reorg of the existing converters, addition of new…
apbose Mar 10, 2023
3f3a925
aten converter- matmul, tanh, gelu, slice select
apbose Mar 17, 2023
8c8e897
Fixing matmul, select, tanh tests
apbose Mar 17, 2023
4f742a9
Modifications to matmul and select tests
apbose Mar 17, 2023
e8c2786
Fixing aten::select test
apbose Mar 17, 2023
038520d
Removing matmul and select operator
apbose Mar 17, 2023
e8a8e38
fx2trt fixing add converter and python linting changes
apbose Mar 20, 2023
d7c82ab
binary operator changes in aten
apbose Mar 21, 2023
36f9a3f
Removing the aten.slice and add_slice
apbose Mar 21, 2023
979ab42
correcting selu, hard_tanh ops, adding tests for sigmoid, selu, elu a…
apbose Mar 22, 2023
79286e7
Moving funcs to_numpy and trt_dtype_to_torch_dtype from converter_uti…
apbose Mar 22, 2023
2bcf5f4
Move to_numpy implementation to converter_util
apbose Mar 22, 2023
96bdead
Implementation of slice and select operations
apbose Mar 22, 2023
a1d94c1
Fixing the acc tests, logical_and operator and the leaky_relu test
apbose Mar 22, 2023
85c755b
merging the bose_fx2trt_converter
apbose Mar 23, 2023
c8811cd
select test implementation
apbose Mar 23, 2023
4f18c0f
select aten test
apbose Mar 23, 2023
cf96dec
Adding add_slice function in operator.py
apbose Mar 23, 2023
8303cd5
aten::matmul, aten::slice, aten::select converters
apbose Mar 24, 2023
f1098f2
feat: Add sample torch.compile backend for tensorrt aten path
gs-olive Mar 20, 2023
243bf9b
Add decompositions to aot call
gs-olive Mar 21, 2023
76fd3c8
Mark FX2TRT converter as fake tensor unsupported
gs-olive Mar 27, 2023
6a8102c
Minor naming bugfix
gs-olive Mar 29, 2023
35cf89d
Merge branch 'main' into bose_fx2trt_converters_slice_select
apbose Mar 29, 2023
4d7f83d
Merge branch 'main' into bose_fx2trt_converters_transformer_encoder
apbose Mar 29, 2023
205b321
Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c…
apbose Apr 7, 2023
e97ed50
Implementing aten::chunk, aten::layer_norm, aten::softmax, aten::wher…
apbose Apr 7, 2023
c5a4744
Transformer operator changes
apbose Apr 10, 2023
8d4e4b4
Fixing acc split test
apbose Apr 11, 2023
1ab9af5
Bug fix for add_slice
apbose Apr 11, 2023
8de6c9d
dynamic test for slice
apbose Apr 11, 2023
ab89d2b
Correct the output_shape dimension for add_slice
apbose Apr 14, 2023
bde4860
Merge branch 'main' into bose_fx2trt_converters
apbose Apr 19, 2023
dba4988
Merge branch 'main' into bose_fx2trt_converters_slice_select
apbose Apr 19, 2023
09a52b9
matmul changes, bmm changes and adding broadcastable
apbose Apr 19, 2023
d1fd1d7
Correcting pre-commit hooks
apbose Apr 19, 2023
39e4da6
Resolving merge conflicts with bose_fx2trt_converters
apbose Apr 20, 2023
d9545be
Merge branch 'main' into bose_fx2trt_converters_transformer_encoder
apbose Apr 20, 2023
89f2fcf
Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c…
apbose Apr 20, 2023
ce7f122
Correcting rsqrt and rsub operator
apbose Apr 20, 2023
30c5fd6
python linting issues and removing chunk test
apbose Apr 20, 2023
7ab071d
Correcting acc squeeze test
apbose Apr 20, 2023
36ac0cf
test_reshape expected ops aten.reshape since aten.view has been remov…
apbose Apr 21, 2023
eb851b1
removing aten.view in lowering pass
apbose Apr 21, 2023
6b234e0
layer_norm test
apbose Apr 21, 2023
95c1ada
correcting linting error
apbose Apr 21, 2023
1a1b809
correcting dynamic shape layer norm
apbose Apr 21, 2023
f721ed1
Merge pull request #1814 from pytorch/bose_fx2trt_converters_transfor…
apbose Apr 21, 2023
ae2706d
Merge pull request #1770 from pytorch/bose_fx2trt_converters_slice_se…
apbose Apr 21, 2023
34268e7
merging main
apbose Apr 21, 2023
01e5aa1
removing aten_tracer and lower_basic_pass_aten changes
apbose Apr 24, 2023
1,051 changes: 44 additions & 1,007 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py

Large diffs are not rendered by default.
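
The 1,051-line acc_ops_converters.py diff is collapsed here, but it carries the "change of prototype" named in the PR title: converter bodies move out of acc_ops_converters.py and into shared helpers such as add_relu in the activation.py diff below. The following is a minimal sketch of how a registered acc-ops converter could delegate to those helpers; the wrapper name and exact wiring are assumptions, since that file's changes are not rendered on this page.

import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.fx.node import Target
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import activation  # module changed in this PR


# Hypothetical wrapper: illustrates the new (network, target, args, kwargs, name)
# converter prototype delegating layer construction to the shared helper.
@tensorrt_converter(acc_ops.relu)
def acc_ops_relu(network, target: Target, args, kwargs, name: str):
    return activation.add_relu(network, target, kwargs, name)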

201 changes: 161 additions & 40 deletions py/torch_tensorrt/fx/converters/activation.py
@@ -1,76 +1,197 @@
 import numpy as np
+import operator
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+from torch.fx.node import Argument, Target
 
 from ..converter_registry import tensorrt_converter
+from ..utils import torch_dtype_from_trt
 
 from .converter_utils import mark_as_int8_layer
+from .converter_utils import set_layer_name
+from .converter_utils import get_trt_plugin
 
+from ..types import (
+    Shape,
+    TRTDataType,
+    TRTElementWiseOp,
+    TRTLayer,
+    TRTNetwork,
+    TRTPlugin,
+    TRTPluginFieldCollection,
+    TRTTensor,
+)
 
 
+def add_activation_layer(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    operation_type: trt.ActivationType,
+    target: Target,
+    name: str,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): Input to the activation op.
+            Must be a TensorRT tensor.
+        operation_type (trt.ActivationType): Type of the TensorRT activation
+            operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+        alpha (Optional[Any]): If not None, we will use it to set the alpha
+            attribute of the created TensorRT activation layer.
+        beta (Optional[Any]): If not None, we will use it to set the beta
+            attribute of the created TensorRT activation layer.
+        dyn_range_fn (Optional[Callable[[float, float], Any]]): A function which
+            takes the dynamic range of a TensorRT Tensor and returns the output
+            dynamic range.
+
+    Returns:
+        The output of TensorRT Activation layer.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name)
+
+    if input_val.dynamic_range is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    return layer.get_output(0)
+
+
-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
+def add_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.RELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
 
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
 
-    return layer.get_output(0)
+def add_leaky_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    negative_slope = kwargs["negative_slope"]
+    operation_type = trt.ActivationType.LEAKY_RELU
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, negative_slope
+    )
 
 
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
+def add_elu(network, target, kwargs, name):
     input_val = kwargs["input"]
+    alpha = kwargs["alpha"]
+    operation_type = trt.ActivationType.ELU
+    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
 
 
-    if not isinstance(input_val, trt.tensorrt.ITensor):
+def add_selu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_softsign(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SOFTSIGN
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.TANH
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_gelu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if "approximate" in kwargs.keys():
+        approximate = kwargs["approximate"]
+        if approximate != "none":
+            raise RuntimeError(
+                "GeLU converter currently doesn't support fast gelu compute"
+            )
+    if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
+            f"GELU received input {input_val} that is not part "
             "of the TensorRT region!"
         )
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support implicit batch dimension"
+        )
 
+    plugin_name = "CustomGeluPluginDynamic"
+    # type_id 0 for float32, 1 for float16
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    field_collection = TRTPluginFieldCollection([type_id])
+    plugin_version = "1"
+
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
+
+    layer = network.add_plugin_v2([input_val], plugin)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_hard_sigmoid(network, target, kwargs, name):
+    input_val = kwargs["input"]
 
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"Hard sigmoid received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
 
-    return common_activation(
+    return add_activation_layer(
         network,
-        submod,
         input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
+        trt.ActivationType.HARD_SIGMOID,
+        target,
+        name,
+        alpha=1 / 6,
+        beta=0.5,
     )
 
 
-@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
-def sigmoid(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
+def add_sigmoid(network, target, kwargs, name):
     input_val = kwargs["input"]
 
-    if not isinstance(input_val, trt.tensorrt.ITensor):
+    if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"Sigmoid received input {input_val} that is not part "
             "of the TensorRT region!"
         )
 
-    def activation_dyn_range_fn(dyn_range):
-        def sigmoid_fn(x):
-            return 1 / (1 + np.exp(-x))
+    return add_activation_layer(
+        network, input_val, trt.ActivationType.SIGMOID, target, name
+    )
 
-        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
 
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.SIGMOID,
-        activation_dyn_range_fn,
-        layer_name,
+def add_hard_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    alpha = kwargs["min_val"]
+    beta = kwargs["max_val"]
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"hardtanh received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    operation_type = trt.ActivationType.CLIP
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, alpha, beta
     )
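
One note on the new helper above: the removed relu converter computed the INT8 output dynamic range inline via activation_dyn_range_fn, while add_relu as committed does not pass anything for dyn_range_fn. A small usage sketch, not part of this diff, of how that logic could be threaded through the new hook:

def relu_dyn_range_fn(dyn_range):
    # ReLU clamps negative values to zero, so both bounds are clamped at 0
    # (same logic as the removed activation_dyn_range_fn).
    return max(0, dyn_range[0]), max(0, dyn_range[1])


def add_relu_with_dyn_range(network, target, kwargs, name):
    # Hypothetical variant of add_relu that forwards the dynamic-range hook,
    # so mark_as_int8_layer is applied when input_val carries a dynamic range.
    input_val = kwargs["input"]
    return add_activation_layer(
        network,
        input_val,
        trt.ActivationType.RELU,
        target,
        name,
        dyn_range_fn=relu_dyn_range_fn,
    )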