Skip to content

Commit

Permalink
Merge pull request #2500 from pytorch/view_slice_bugfixes_cherry_pick
Browse files Browse the repository at this point in the history
cherry-pick: View and slice bugfixes
  • Loading branch information
gs-olive authored Nov 29, 2023
2 parents 73fefbb + b5dc751 commit 5b0e5fc
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 30 deletions.
16 changes: 13 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ def aten_ops_select(


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_slice(
ctx: ConversionContext,
target: Target,
Expand All @@ -700,9 +705,9 @@ def aten_ops_slice(
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args_bounds_check(args, 1, replacement=0),
args_bounds_check(args, 2, replacement=None),
args_bounds_check(args, 3, replacement=None),
args_bounds_check(args, 4, replacement=1),
)

Expand Down Expand Up @@ -877,6 +882,11 @@ def aten_ops_clone_copy_placeholder(


@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_expand(
ctx: ConversionContext,
target: Target,
Expand Down
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def get_positive_dim(
) -> Union[int, Tuple[int, ...]]:
"""
Given an integer number or tuple that represents dimension(s) in the array,
transform it to a positive integer dim if it's negative. Otherwise, do
nothing.
transform it to a positive integer dim if it's negative.
Otherwise, truncate it to the dimension size
Args:
dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array.
Expand All @@ -353,7 +353,8 @@ def get_positive_dim(
def positive_dim(d: int) -> int:
if d < 0:
return d % dim_size
return d
else:
return min(d, dim_size)

return (
positive_dim(dim)
Expand Down
24 changes: 9 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ def slice_op( # TODO: This should be slice not whatever is in base
name: str,
input: TRTTensor,
dim: int,
start: int,
stop: int,
start: Optional[int],
stop: Optional[int],
step: int,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"slice_tensor received input {input} that is not part "
"of the TensorRT region!"
)
# Special case for start being None
if start is None:
start = 0

# Special case for stop being None
if stop is None:
stop = input.shape[dim]

dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
Expand All @@ -39,9 +41,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"

if stop == 2**63 - 1:
stop = input.shape[dim]

start_slice = [0] * len(input.shape)
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
Expand All @@ -62,11 +61,6 @@ def expand(
input_t: TRTTensor,
shape: Shape,
) -> TRTTensor:
if not isinstance(input_t, TRTTensor):
raise RuntimeError(
f"expand received input {input_t} that is not a TensorRT ITensor"
)

shape_rank = len(shape)
initial_tensor_rank = len(input_t.shape)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
Expand All @@ -19,6 +20,7 @@
lower_efficient_attention,
fuse_prims_broadcast,
replace_max_pool_with_indices,
view_to_reshape,
]
)

Expand Down
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from typing import Callable, List, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]
):
"""Constructs the original and replacement functions for view"""

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)

return orig, replacement
35 changes: 26 additions & 9 deletions tests/py/dynamo/conversion/test_slice_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from .harness import DispatchTestCase


class TestSelectConverter(DispatchTestCase):
class TestSliceConverter(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 0, 0, 7, 2),
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step", 0, 0, 7, 2),
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
("slice_dim_start_stop_step_none", 2, None, None, 1),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand All @@ -32,12 +34,27 @@ def forward(self, input):
input,
)

def test_slice_empty(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
out = torch.ops.aten.slice.Tensor(input)
return out

input = [torch.randn(10, 10, 3, 1)]
self.run_test(
TestModule(),
input,
)


class TestSelectConverterDynamicShape(DispatchTestCase):
class TestSliceConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
("select_dim_start_stop_step", 1, 0, 7, 2),
("select_dim_start_stop_step", 1, 0, 10, 2),
("slice_dim_start_stop_step", 1, 0, 7, 2),
("slice_dim_start_stop_step", 1, 0, 10, 2),
]
)
def test_slice(self, _, dim, start, stop, step):
Expand Down
65 changes: 65 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,70 @@ def forward(self, q, k, v):
torch._dynamo.reset()


class TestLowerViewToReshape(TestCase):
def test_view_to_reshape(self):
class ViewToReshape(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.view.default(input, (1, 1, -1))
return out

inputs = [
torch.rand((3, 4, 5, 32)).cuda(),
]

fx_graph = torch.fx.symbolic_trace(ViewToReshape())
expected_ops = {torch.ops.aten.reshape.default}
unexpected_ops = {
torch.ops.aten.view.default,
}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"ViewToReshape TRT outputs don't match with the original model.",
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()

0 comments on commit 5b0e5fc

Please sign in to comment.