Skip to content

Commit

Permalink
feat: support for many padding dynamo converters (#2482)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Nov 30, 2023
1 parent 4f8eb56 commit 9b88e92
Show file tree
Hide file tree
Showing 5 changed files with 574 additions and 6 deletions.
122 changes: 122 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,3 +2304,125 @@ def aten_ops_addmm(
beta=kwargs.get("beta", 1),
alpha=kwargs.get("alpha", 1),
)


@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_constant_pad(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pad.constant_padNd(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args_bounds_check(args, 2, 0),
)


@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_reflection_pad(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pad.reflection_padNd(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_replication_pad(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pad.replication_padNd(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_circular_pad(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pad.circular_padNd(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.pad.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pad(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pad.pad(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
pad=args[1],
mode=args_bounds_check(args, 2, "constant"),
value=args_bounds_check(args, 3, None),
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
linear,
matmul,
normalization,
pad,
permutation,
pool,
reduce,
Expand Down
11 changes: 5 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/cat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Dict, Optional, Sequence, Union
from typing import Optional, Sequence, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
get_positive_dim,
get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.types import TRTTensor


def cat(
Expand All @@ -23,12 +22,12 @@ def cat(
dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
trt_inputs = []
for each_input in input:
for i, each_input in enumerate(input):
if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}")
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
trt_inputs.append(each_input)
concat_layer = ctx.net.add_concatenation(trt_inputs)
dim = get_positive_dim(dim, len(input[0].shape))
concat_layer.axis = dim
set_layer_name(concat_layer, target, name + "_gather", source_ir)
set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
return concat_layer.get_output(0)
205 changes: 205 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import Optional, Sequence, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTTensor

"""
Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding
mode and clamp, and supports padding output with dynamic shape.
"""


def constant_padNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
value: Union[int, float] = 0,
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(pad) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(pad) // 2):
start_list[-i - 1] = -pad[i * 2]
new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]

stride_list = [1] * rank
layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
)
value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
layer.set_input(4, value_const)
layer.mode = trt.SliceMode.FILL

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def reflection_padNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
padding: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(padding) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(padding) // 2):
start_list[-i - 1] = -padding[i * 2]
new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]

stride_list = [1] * rank
layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
)
layer.mode = trt.SliceMode.REFLECT

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def replication_padNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
padding: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(padding) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(padding) // 2):
start_list[-i - 1] = -padding[i * 2]
new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]

stride_list = [1] * rank
layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
)
layer.mode = trt.SliceMode.CLAMP

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def circular_padNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(pad) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(pad) // 2):
start_list[-i - 1] = -pad[i * 2]
new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]

stride_list = [1] * rank
layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
)
layer.mode = trt.SliceMode.WRAP

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def pad(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
mode: str = "constant",
value: Optional[float] = None,
) -> TRTTensor:
if mode == "constant":
return constant_padNd(
ctx,
target,
source_ir,
f"{name}_{mode}",
input,
pad,
value if value is not None else 0,
)
elif mode == "reflect":
return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
elif mode == "replicate":
return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
elif mode == "circular":
return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
else:
raise RuntimeError(
f'We currently only support for `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}'
)
Loading

0 comments on commit 9b88e92

Please sign in to comment.