fx2trt converters aten::slice, aten::select and aten::matmul #1770

Merged
merged 34 commits on Apr 21, 2023
Changes from all commits
34 commits
96bdead
Implementation of slice and select operations
apbose Mar 22, 2023
85c755b
merging the bose_fx2trt_converter
apbose Mar 23, 2023
c8811cd
select test implementation
apbose Mar 23, 2023
4f18c0f
select aten test
apbose Mar 23, 2023
8303cd5
aten::matmul, aten::slice, aten::select converters
apbose Mar 24, 2023
f1098f2
feat: Add sample torch.compile backend for tensorrt aten path
gs-olive Mar 20, 2023
243bf9b
Add decompositions to aot call
gs-olive Mar 21, 2023
76fd3c8
Mark FX2TRT converter as fake tensor unsupported
gs-olive Mar 27, 2023
6a8102c
Minor naming bugfix
gs-olive Mar 29, 2023
35cf89d
Merge branch 'main' into bose_fx2trt_converters_slice_select
apbose Mar 29, 2023
4d7f83d
Merge branch 'main' into bose_fx2trt_converters_transformer_encoder
apbose Mar 29, 2023
205b321
Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c…
apbose Apr 7, 2023
e97ed50
Implementing aten::chunk, aten::layer_norm, aten::softmax, aten::wher…
apbose Apr 7, 2023
c5a4744
Transformer operator changes
apbose Apr 10, 2023
8d4e4b4
Fixing acc split test
apbose Apr 11, 2023
1ab9af5
Bug fix for add_slice
apbose Apr 11, 2023
8de6c9d
dynamic test for slice
apbose Apr 11, 2023
ab89d2b
Correct the output_shape dimension for add_slice
apbose Apr 14, 2023
dba4988
Merge branch 'main' into bose_fx2trt_converters_slice_select
apbose Apr 19, 2023
09a52b9
matmul changes, bmm changes and adding broadcastable
apbose Apr 19, 2023
d1fd1d7
Correcting pre-commit hooks
apbose Apr 19, 2023
1d78f43
feat: Add ts converter support for aten::all.dim (#1840)
mfeliz-cruise Apr 19, 2023
39e4da6
Resolving merge conflicts with bose_fx2trt_converters
apbose Apr 20, 2023
d9545be
Merge branch 'main' into bose_fx2trt_converters_transformer_encoder
apbose Apr 20, 2023
89f2fcf
Merge branch 'bose_fx2trt_converters_slice_select' into bose_fx2trt_c…
apbose Apr 20, 2023
ce7f122
Correcting rsqrt and rsub operator
apbose Apr 20, 2023
30c5fd6
python linting issues and removing chunk test
apbose Apr 20, 2023
7ab071d
Correcting acc squeeze test
apbose Apr 20, 2023
36ac0cf
test_reshape expected ops aten.reshape since aten.view has been remov…
apbose Apr 21, 2023
eb851b1
removing aten.view in lowering pass
apbose Apr 21, 2023
6b234e0
layer_norm test
apbose Apr 21, 2023
95c1ada
correcting linting error
apbose Apr 21, 2023
1a1b809
correcting dynamic shape layer norm
apbose Apr 21, 2023
f721ed1
Merge pull request #1814 from pytorch/bose_fx2trt_converters_transfor…
apbose Apr 21, 2023
76 changes: 54 additions & 22 deletions core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,36 @@ namespace converters {
namespace impl {
namespace {

nvinfer1::ITensor* anyDimImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* in_tensor,
    int dim,
    bool keepdim) {
  auto in_dims = in_tensor->getDimensions();
  LOG_DEBUG("Dim to reduce (original): " << dim);
  dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
  LOG_DEBUG("Dim to reduce (converted): " << dim);

  uint32_t axis_mask = 1 << dim;
  LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
  LOG_DEBUG("Keep dims: " << keepdim);

  // Reduce does not work on bool inputs
  if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
    in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
  }
  auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);

  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

  sum_layer->setName(util::node_info(n).c_str());
  auto out_tensor =
      castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
  return out_tensor;
}
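
Aside: the helper sidesteps TensorRT's lack of boolean reductions by casting to INT32, sum-reducing, and casting the sum back to kBOOL, so any nonzero count becomes true. A minimal PyTorch sketch of that identity, for illustration only (not part of this diff):

    import torch

    x = torch.tensor([[True, False], [False, False]])
    dim, keepdim = 1, False

    # cast -> sum-reduce -> cast back, as anyDimImplementation does in TensorRT
    emulated = x.to(torch.int32).sum(dim=dim, keepdim=keepdim).to(torch.bool)
    assert torch.equal(emulated, x.any(dim=dim, keepdim=keepdim))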

auto reduce_registrations TORCHTRT_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
{"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensorOrFreeze(ctx);
auto in_dims = in_tensor->getDimensions();
auto dim = args[1].unwrapToInt();
LOG_DEBUG("Dim to reduce (original): " << dim);
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
LOG_DEBUG("Dim to reduce (converted): " << dim);

uint32_t axis_mask = 1 << dim;
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

auto keepdim = args[2].unwrapToBool();
LOG_DEBUG("Keep dims: " << keepdim);

// Reduce does not work on bool inputs
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
in_tensor =
castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
}
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);

TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

sum_layer->setName(util::node_info(n).c_str());
auto out_tensor = castITensor(
ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}})
        .pattern(
            {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // use Not(Any(Not(input))) to calculate all without a direct all reduction
               auto in_tensor = args[0].ITensorOrFreeze(ctx);
               auto dim = args[1].unwrapToInt();
               auto keepdim = args[2].unwrapToBool();
               if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
                 // unary not layer only supports bool inputs
                 in_tensor = castITensor(
                     ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
               }
               auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
               TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
               not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
               auto not_in = not_input_layer->getOutput(0);
               auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
               auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
               TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
               return true;
             }});
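
Aside: aten::all.dim is reduced via De Morgan's law — all(x) == not(any(not(x))) — which is why the shared anyDimImplementation helper suffices for both reductions. A quick PyTorch check of the identity (illustration only, not part of this diff):

    import torch

    x = torch.tensor([[True, True], [True, False]])
    dim = 1

    # Not(Any(Not(input))), as the converter builds it with unary kNOT layers
    emulated = ~((~x).any(dim=dim))
    assert torch.equal(emulated, x.all(dim=dim))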
} // namespace
} // namespace impl
220 changes: 12 additions & 208 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -678,7 +678,13 @@ def acc_ops_batch_norm(


@tensorrt_converter(acc_ops.layer_norm)
def acc_ops_layer_norm(network, target, args, kwargs, name):
def acc_ops_layer_norm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return add_layer_norm(network, target, kwargs, name)


@@ -690,37 +696,7 @@ def acc_ops_softmax(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"softmax received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
    def get_softmax_dim(ndim: int) -> int:
        if ndim == 0 or ndim == 1 or ndim == 3:
            ret = 0
        else:
            ret = 1
        return ret

    if kwargs["dim"] is None:
        dim = get_softmax_dim(input_ranks)
    else:
        dim = cast(int, kwargs["dim"])

    dim = get_positive_dim(dim, input_ranks)
    if network.has_implicit_batch_dimension:
        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
        dim -= 1

    layer = network.add_softmax(input_val)
    layer.axes = 1 << dim
    set_layer_name(layer, target, name)
    return layer.get_output(0)
    return add_softmax(network, target, kwargs, name)
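
Aside: the removed get_softmax_dim mirrors PyTorch's legacy default-dim rule for softmax when dim is None — dim 0 for 0-D, 1-D, and 3-D inputs, dim 1 otherwise — and the converter then encoded the chosen dim as the bitmask 1 << dim for TensorRT's softmax axes. A hedged sketch of that rule (illustration only, not part of this diff):

    import torch
    import torch.nn.functional as F

    def get_softmax_dim(ndim: int) -> int:
        # legacy PyTorch rule, as copied in the removed converter body
        return 0 if ndim in (0, 1, 3) else 1

    x = torch.randn(2, 3, 4)   # ndim == 3 -> softmax over dim 0
    dim = get_softmax_dim(x.dim())
    axes_bitmask = 1 << dim    # what the converter assigned to layer.axes
    out = F.softmax(x, dim=dim)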


@tensorrt_converter(acc_ops.tile)
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    operation_type = trt.UnaryOperation.SQRT
    return add_unary_layer(network, input_val, operation_type, target, name)
    return add_sqrt(network, target, kwargs, name)


@tensorrt_converter(acc_ops.reciprocal)
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"squeeze received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
    # dim, which is a very rare case. For now we just claim not supporting dim=None.
    assert dim is not None, "We don't support dim=None right now for squeeze."

    dim = get_positive_dim(
        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    )
    if network.has_implicit_batch_dimension:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1

    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
    assert (
        len(get_dynamic_dims(input_val.shape)) <= 1
    ), "Currently more than one dynamic dim for input to squeeze is not supported."

    output_shape = []
    for i, s in enumerate(input_val.shape):
        if i == dim and s == 1:
            continue
        output_shape.append(s)
    layer = network.add_shuffle(input_val)
    layer.reshape_dims = tuple(output_shape)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
    return add_squeeze(network, target, kwargs, name)
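
Aside: the removed body emulated squeeze with a shuffle (reshape) layer, building the target shape by dropping the requested dim only when its extent is 1. A small PyTorch sketch of that shape computation (illustration only, not part of this diff):

    import torch

    def squeeze_shape(shape, dim):
        # drop `dim` only when its extent is 1, as the removed converter did
        return tuple(s for i, s in enumerate(shape) if not (i == dim and s == 1))

    x = torch.randn(2, 1, 3)
    assert x.squeeze(1).shape == squeeze_shape(x.shape, 1)  # (2, 3)
    assert x.squeeze(0).shape == squeeze_shape(x.shape, 0)  # unchanged: (2, 1, 3)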


@tensorrt_converter(acc_ops.add)
@@ -2022,89 +1963,7 @@ def acc_ops_where(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

    condition_t = kwargs["condition"]
    x_t = kwargs["x"]
    y_t = kwargs["y"]

    if type(x_t) != TRTTensor:
        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"

    if type(y_t) != TRTTensor:
        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"

    # get output shape

    x_shape = list(x_t.shape)
    y_shape = list(y_t.shape)
    condition_shape = list(condition_t.shape)
    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))

    # expand shape
    if type(condition_t) != TRTTensor:
        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
        if condition_shape != output_shape:
            condition_t.expand(output_shape)
        condition_t = condition_t.to(torch.int32)
        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
        condition_layer = network.add_identity(condition_const)
        condition_layer.set_output_type(0, trt.bool)
        set_layer_name(condition_layer, target, f"{name}_condition")
        condition_val = condition_layer.get_output(0)
    else:
        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
        if condition_shape != output_shape:
            condition_val = acc_ops_expand_tensor(
                network,
                target,
                None,
                {"input": condition_t, "sizes": output_shape},
                name=f"{name}_expand",
            )
        else:
            condition_val = condition_t

    if type(x_t) != TRTTensor:
        if x_shape != output_shape:
            # special case where 1 element in x_t
            if len(x_t.shape) == 0:
                x_t = x_t.unsqueeze(0)
            x_t = x_t.expand(output_shape)
        x_val = get_trt_tensor(network, x_t, f"{name}_x")
    else:
        x_val = x_t
        if x_shape != output_shape:
            x_val = acc_ops_expand_tensor(
                network,
                target,
                None,
                {"input": x_val, "sizes": output_shape},
                name=f"{name}_x_expand",
            )

    if type(y_t) != TRTTensor:
        if y_shape != output_shape:
            # special case where 1 element in y_t
            if len(y_t.shape) == 0:
                y_t = y_t.unsqueeze(0)
            y_t = y_t.expand(output_shape)
        y_val = get_trt_tensor(network, y_t, f"{name}_y")
    else:
        y_val = y_t
        if y_shape != output_shape:
            y_val = acc_ops_expand_tensor(
                network,
                target,
                None,
                {"input": y_val, "sizes": output_shape},
                name=f"{name}_y_expand",
            )

    select_layer = network.add_select(condition_val, x_val, y_val)

    set_layer_name(select_layer, target, f"{name}_select")

    return select_layer.get_output(0)
    return add_where(network, target, kwargs, name)
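
Aside: the where converter first computes the broadcast output shape of condition, x, and y, expands each operand to that shape, and only then adds the select layer. A short PyTorch sketch of the shape logic (illustration only, not part of this diff):

    import torch

    cond = torch.tensor([[True], [False]])   # shape (2, 1)
    x = torch.randn(2, 3)                    # shape (2, 3)
    y = torch.tensor(0.0)                    # shape ()

    # same computation the converter used to size its expand/select layers
    output_shape = torch.broadcast_shapes(cond.shape, x.shape, y.shape)
    assert output_shape == (2, 3)
    assert torch.where(cond, x, y).shape == output_shape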


@tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
    chunks = cast(int, kwargs["chunks"])
    dim = cast(int, kwargs["dim"])
    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]

    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"chunk received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    dynamic_shape = has_dynamic_shape(input_val.shape)
    if network.has_implicit_batch_dimension:
        input_dim_size += 1
        dim = get_positive_dim(dim, input_dim_size)
        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
        dim -= 1
    else:
        if dynamic_shape:
            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
        dim = get_positive_dim(dim, input_dim_size)

    if chunks > input_val.shape[dim]:
        warnings.warn(
            f"Asked for {chunks} chunks along dimention "
            f"{dim} on tensor with size {input_val.shape}, chunks "
            f"will default to {input_val.shape[dim]}",
            RuntimeWarning,
        )
        chunks = input_val.shape[dim]

    start = [0] * len(input_val.shape)
    stride = [1] * len(start)
    offset = 0
    split_size = (input_val.shape[dim] + chunks - 1) // chunks

    max_offset = input_val.shape[dim]
    # add slice layers
    output = []
    for i in range(chunks):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, max_offset - offset)
        if dynamic_shape:
            shape = get_shape_with_dynamic_shape(
                network, shape, input_val, target, f"{name}_{i}"
            )
        start[dim] = offset
        layer = network.add_slice(
            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
        )
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_size
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
    return output
    return add_chunk(network, target, kwargs, name)
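
Aside: the removed converter reproduced torch.chunk with slice layers, using split_size = ceil(dim_size / chunks); every chunk gets split_size elements except a possibly smaller final one, and chunks is clamped to the dimension size. A quick PyTorch check (illustration only, not part of this diff):

    import torch

    x = torch.arange(10)
    chunks = 3

    split_size = (x.shape[0] + chunks - 1) // chunks  # ceil(10 / 3) == 4
    sizes = [c.shape[0] for c in torch.chunk(x, chunks)]
    assert split_size == 4
    assert sizes == [4, 4, 2]  # last chunk takes the remainder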


@tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)