Skip to content

Commit 2cae047

Browse files
committed
feat: support aten._local_scalar_dense converter
1 parent a8a0797 commit 2cae047

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+17
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,23 @@ def aten_ops_isnan(
17141714
)
17151715

17161716

1717+
@dynamo_tensorrt_converter(torch.ops.aten._local_scalar_dense.default)
1718+
def aten_ops_local_scalar_dense(
1719+
ctx: ConversionContext,
1720+
target: Target,
1721+
args: Tuple[Argument, ...],
1722+
kwargs: Dict[str, Argument],
1723+
name: str,
1724+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1725+
return impl.unary.local_scalar_dense(
1726+
ctx,
1727+
target,
1728+
SourceIR.ATEN,
1729+
name,
1730+
args[0],
1731+
)
1732+
1733+
17171734
@dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True)
17181735
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True)
17191736
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True)

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

+23
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,26 @@ def isnan(
571571
)
572572

573573
return nan_values_mask
574+
575+
576+
def local_scalar_dense(
577+
ctx: ConversionContext,
578+
target: Union[Target, str],
579+
source_ir: Optional[SourceIR],
580+
name: str,
581+
input: TRTTensor,
582+
) -> TRTTensor:
583+
start = [0] * len(input.shape)
584+
shape = [1] * len(input.shape) # Get one element from each dimension
585+
stride = [1] * len(input.shape) # Step through each dimension by 1
586+
587+
layer = ctx.net.add_slice(input=input, start=start, shape=shape, stride=stride)
588+
set_layer_name(layer, target, f"{name}_slice", source_ir)
589+
590+
reshape_layer = ctx.net.add_shuffle(layer.get_output(0))
591+
reshape_layer.reshape_dims = [
592+
1,
593+
] # Reshape to a single-element tensor
594+
set_layer_name(reshape_layer, target, f"{name}_reshape", source_ir)
595+
596+
return reshape_layer.get_output(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import torch.nn as nn
3+
from harness import DispatchTestCase
4+
from parameterized import parameterized
5+
from torch.testing._internal.common_utils import run_tests
6+
7+
8+
class TestLocalScalarDenseConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
(
12+
torch.randn((5, 10, 5), dtype=torch.float32),
13+
),
14+
(
15+
torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32),
16+
),
17+
(
18+
torch.randn((1), dtype=torch.float32),
19+
),
20+
(
21+
(torch.tensor([-2.4])),
22+
),
23+
(
24+
(torch.tensor([5.5, 3.5, 3.6])),
25+
),
26+
(
27+
(torch.tensor([True])),
28+
),
29+
(
30+
torch.tensor(
31+
[
32+
float("nan"),
33+
1.23,
34+
float("inf"),
35+
]
36+
),
37+
),
38+
(
39+
torch.tensor(
40+
[
41+
float("-inf"),
42+
1.23,
43+
float("nan"),
44+
]
45+
),
46+
),
47+
(
48+
(torch.tensor([float("inf")])),
49+
),
50+
]
51+
)
52+
def test_local_scalar_dense(self, data):
53+
class local_scalar_dense(nn.Module):
54+
def forward(self, input):
55+
return torch.ops.aten._local_scalar_dense.default(input)
56+
57+
inputs = [data]
58+
self.run_test(
59+
local_scalar_dense(),
60+
inputs,
61+
)
62+
63+
64+
if __name__ == "__main__":
65+
run_tests()

0 commit comments

Comments
 (0)