Skip to content

Commit 3d9171d

Browse files
committed
Using torch.export workflow since compile is showing error in tensor guard
1 parent 4fdc6d0 commit 3d9171d

File tree

2 files changed

+24
-62
lines changed

2 files changed

+24
-62
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def slice_scatter_decomposition(
201201
start = get_positive_dim(start, input_tensor.shape[dim])
202202
if end is None: # Ensure end is int
203203
end = dim_size
204-
end = get_positive_dim(end, input_tensor.shape[dim])
204+
end = (
205+
get_positive_dim(end, input_tensor.shape[dim]) if isinstance(end, int) else end
206+
)
205207
if step is None:
206208
step = 1
207209

@@ -212,7 +214,6 @@ def slice_scatter_decomposition(
212214
if start == 0 and end == dim_size and step == 1:
213215
return src_tensor
214216

215-
# Ensure start, end, and step are all integers
216217
# Ensure start, end, and step are all integers
217218
assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
218219
assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 21 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -817,72 +817,33 @@ class sliceScatter(torch.nn.Module):
817817
def __init__(self, *args, **kwargs) -> None:
818818
super().__init__(*args, **kwargs)
819819

820-
def forward(self, x, src, dim, start=None, end=None, step=1):
821-
y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
820+
def forward(self, x, src):
821+
y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)
822822
return y
823823

824-
# Operations expected to be removed in the traced graph after decompositions
825-
expected_ops = {
826-
torch.ops.aten.scatter.src,
827-
}
828-
unexpected_ops = {torch.ops.aten.select_scatter}
829-
830-
a = torch.zeros(8, 8).cuda()
831-
b = torch.ones(8, 2).cuda()
832-
833-
# 0-D tensors for dynamic scalar values
834-
start = torch.tensor(1, dtype=torch.int64).cuda()
835-
end = torch.tensor(6, dtype=torch.int64).cuda()
836-
step = torch.tensor(1, dtype=torch.int64).cuda()
837-
838-
# Mark scalar tensors as dynamic (note: shape = ())
839-
torch._dynamo.mark_dynamic(start, (), min=1, max=3)
840-
torch._dynamo.mark_dynamic(end, (), min=4, max=6)
841-
842-
inputs = (a, b, start, end, None, step)
843824
fx_graph = torch.fx.symbolic_trace(sliceScatter())
844-
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
845-
fx_graph,
846-
inputs,
847-
expected_ops=expected_ops,
848-
unexpected_ops=unexpected_ops,
849-
min_block_size=1,
850-
)
851825

852-
self.assertEqual(
853-
len(unexpected_ops_seen),
854-
0,
855-
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
856-
)
857-
858-
self.assertEqual(
859-
len(expected_ops_unseen),
860-
0,
861-
f"The following expected ops were not encountered: {expected_ops_unseen}",
826+
dim1 = torch.export.Dim("dim1", min=8, max=10)
827+
dynamic_shapes = {
828+
"x": [torch.export.Dim.STATIC, dim1],
829+
"src": [torch.export.Dim.STATIC, None],
830+
}
831+
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
832+
exported_program = torch.export.export(
833+
sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes
862834
)
863-
835+
fx_graph = exported_program.module()
836+
inputs = [
837+
torch_tensorrt.Input(
838+
min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]
839+
),
840+
torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
841+
]
864842
torch._dynamo.reset()
865-
866-
# Validate that the results between Torch and Torch-TRT are similar
867-
optimized_model = torch_tensorrt.compile(
868-
fx_graph,
869-
"torch_compile",
870-
inputs,
871-
min_block_size=1,
872-
truncate_double=True,
873-
pass_through_build_failures=True,
874-
)
875-
optimized_model_results = optimized_model(*inputs).detach().cpu()
876-
torch_model_results = fx_graph(*inputs).detach().cpu()
877-
878-
max_diff = float(
879-
torch.max(torch.abs(optimized_model_results - torch_model_results))
880-
)
881-
self.assertAlmostEqual(
882-
max_diff,
883-
0,
884-
DECIMALS_OF_AGREEMENT,
885-
f"Slice_scatter TRT outputs don't match with the original model.",
843+
trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs)
844+
inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
845+
torch.testing.assert_close(
846+
trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL
886847
)
887848

888849
def test_lowering_select_scatter_dimZero_module(self):

0 commit comments

Comments
 (0)