fx2trt converters aten::slice, aten::select and aten::matmul #1770
Merged
apbose merged 34 commits into bose_fx2trt_converters from bose_fx2trt_converters_slice_select on Apr 21, 2023
Changes from 21 commits

Commits (34):
96bdead - Implementation of slice and select operations (apbose)
85c755b - merging the bose_fx2trt_converter (apbose)
c8811cd - select test implementation (apbose)
4f18c0f - select aten test (apbose)
8303cd5 - aten::matmul, aten::slice, aten::select converters (apbose)
f1098f2 - feat: Add sample torch.compile backend for tensorrt aten path (gs-olive)
243bf9b - Add decompositions to aot call (gs-olive)
76fd3c8 - Mark FX2TRT converter as fake tensor unsupported (gs-olive)
6a8102c - Minor naming bugfix (gs-olive)
35cf89d - Merge branch 'main' into bose_fx2trt_converters_slice_select (apbose)
4d7f83d - Merge branch 'main' into bose_fx2trt_converters_transformer_encoder (apbose)
205b321 - Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c… (apbose)
e97ed50 - Implementing aten::chunk, aten::layer_norm, aten::softmax, aten::wher… (apbose)
c5a4744 - Transformer operator changes (apbose)
8d4e4b4 - Fixing acc split test (apbose)
1ab9af5 - Bug fix for add_slice (apbose)
8de6c9d - dynamic test for slice (apbose)
ab89d2b - Correct the output_shape dimension for add_slice (apbose)
dba4988 - Merge branch 'main' into bose_fx2trt_converters_slice_select (apbose)
09a52b9 - matmul changes, bmm changes and adding broadcastable (apbose)
d1fd1d7 - Correcting pre-commit hooks (apbose)
1d78f43 - feat: Add ts converter support for aten::all.dim (#1840) (mfeliz-cruise)
39e4da6 - Resolving merge conflicts with bose_fx2trt_converters (apbose)
d9545be - Merge branch 'main' into bose_fx2trt_converters_transformer_encoder (apbose)
89f2fcf - Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c… (apbose)
ce7f122 - Correcting rsqrt and rsub operator (apbose)
30c5fd6 - python linting issues and removing chunk test (apbose)
7ab071d - Correcting acc squeeze test (apbose)
36ac0cf - test_reshape expected ops aten.reshape since aten.view has been remov… (apbose)
eb851b1 - removing aten.view in lowering pass (apbose)
6b234e0 - layer_norm test (apbose)
95c1ada - correcting linting error (apbose)
1a1b809 - correcting dynamic shape layer norm (apbose)
f721ed1 - Merge pull request #1814 from pytorch/bose_fx2trt_converters_transfor… (apbose)
File diff:

@@ -1,76 +1,197 @@
-import numpy as np
-
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
-import torch
-
-from ..converter_registry import tensorrt_converter
-
-from .converter_utils import mark_as_int8_layer
-
-
-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
-
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-
-    return layer.get_output(0)
-
-
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
-
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
-    )
-
-
-@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
-def sigmoid(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def activation_dyn_range_fn(dyn_range):
-        def sigmoid_fn(x):
-            return 1 / (1 + np.exp(-x))
-
-        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
-
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.SIGMOID,
-        activation_dyn_range_fn,
-        layer_name,
-    )
+import numpy as np
+import operator
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Argument, Target
+
+from ..converter_registry import tensorrt_converter
+from ..utils import torch_dtype_from_trt
+
+from .converter_utils import mark_as_int8_layer
+from .converter_utils import set_layer_name
+from .converter_utils import get_trt_plugin
+
+from ..types import (
+    Shape,
+    TRTDataType,
+    TRTElementWiseOp,
+    TRTLayer,
+    TRTNetwork,
+    TRTPlugin,
+    TRTPluginFieldCollection,
+    TRTTensor,
+)
+
+
+def add_activation_layer(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    operation_type: trt.ActivationType,
+    target: Target,
+    name: str,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): Input to the activation op.
+            Must be a TensorRT tensor.
+        operation_type (trt.ActivationType): Type of the TensorRT activation
+            operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+        alpha (Optional[Any]): If not None, we will use it to set the alpha
+            attribute of the created TensorRT activation layer.
+        beta (Optional[Any]): If not None, we will use it to set the beta
+            attribute of the created TensorRT activation layer.
+        dyn_range_fn (Optional[Callable[[float, float], Any]]): A function which
+            takes the dynamic range of a TensorRT Tensor and returns the output
+            dynamic range.
+
+    Returns:
+        The output of TensorRT Activation layer.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name)
+
+    if input_val.dynamic_range is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    return layer.get_output(0)
+
+
+def add_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.RELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_leaky_relu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    negative_slope = kwargs["negative_slope"]
+    operation_type = trt.ActivationType.LEAKY_RELU
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, negative_slope
+    )
+
+
+def add_elu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    alpha = kwargs["alpha"]
+    operation_type = trt.ActivationType.ELU
+    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+
+
+def add_selu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_softsign(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SOFTSIGN
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.TANH
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_gelu(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if "approximate" in kwargs.keys():
+        approximate = kwargs["approximate"]
+        if approximate != "none":
+            raise RuntimeError(
+                "GeLU converter currently doesn't support fast gelu compute"
+            )
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"GELU received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support implicit batch dimension"
+        )
+
+    plugin_name = "CustomGeluPluginDynamic"
+    # type_id 0 for float32, 1 for float16
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    field_collection = TRTPluginFieldCollection([type_id])
+    plugin_version = "1"
+
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
+
+    layer = network.add_plugin_v2([input_val], plugin)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_hard_sigmoid(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"Hard sigmoid received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    return add_activation_layer(
+        network,
+        input_val,
+        trt.ActivationType.HARD_SIGMOID,
+        target,
+        name,
+        alpha=1 / 6,
+        beta=0.5,
+    )
+
+
+def add_sigmoid(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"Sigmoid received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    return add_activation_layer(
+        network, input_val, trt.ActivationType.SIGMOID, target, name
+    )
+
+
+def add_hard_tanh(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    alpha = kwargs["min_val"]
+    beta = kwargs["max_val"]
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"hardtanh received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    operation_type = trt.ActivationType.CLIP
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, alpha, beta
+    )
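For context, here is a minimal sketch of how one of the new helpers could be driven directly against a hand-built TensorRT network. The import path, the use of a plain string in place of the fx node Target, and the input shape are assumptions for illustration only; in the PR these helpers are invoked by the aten converters rather than by hand.

import tensorrt as trt

# Assumed import path for the helper shown in the diff above.
from torch_tensorrt.fx.converters.activation import add_relu

# Build a small explicit-batch TensorRT network with a single input tensor.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input(name="x", dtype=trt.float32, shape=(1, 3, 8, 8))

# The converter expects the normalized kwargs of the fx node; a plain string
# stands in for the real fx Target here (assumption, for illustration).
out = add_relu(network, "aten::relu", {"input": x}, name="relu_0")
network.mark_output(out)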
Review comment:
Is there a reason we want to drop this common_activation func and split it into many single functions?

Reply:
Hi Wei, the common_activation function is not split; instead, it is included in the add_activation function in activation.py. The idea is that the activation and the operator functions are split in two (activation.py and operator.py respectively), and converter_utils.py is for the common functions across the two and other operations.
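To make the intended layout concrete, here is a compact sketch (simplified, not the repository's actual code) of the pattern described above: converter_utils.py holds helpers shared by both files, activation.py folds the old common_activation logic into a shared add_activation_layer with thin per-op wrappers, and operator.py builds non-activation converters on the same utilities. The add_matmul example and the simplified set_layer_name are hypothetical.

import tensorrt as trt


# converter_utils.py: helpers shared by activation.py and operator.py
def set_layer_name(layer, target, name):
    # Simplified naming scheme for this sketch only.
    layer.name = f"[{target}]-[{name}]"


# activation.py: shared activation plumbing plus thin per-op wrappers
def add_activation_layer(network, input_val, operation_type, target, name):
    layer = network.add_activation(input_val, operation_type)
    set_layer_name(layer, target, name)
    return layer.get_output(0)


def add_relu(network, target, kwargs, name):
    return add_activation_layer(
        network, kwargs["input"], trt.ActivationType.RELU, target, name
    )


# operator.py: non-activation converters built on the same helpers (hypothetical)
def add_matmul(network, target, kwargs, name):
    layer = network.add_matrix_multiply(
        kwargs["input"], trt.MatrixOperation.NONE,
        kwargs["other"], trt.MatrixOperation.NONE,
    )
    set_layer_name(layer, target, name)
    return layer.get_output(0)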