Skip to content

Commit 426562f

Browse files
authored
Use IUnsqueezeLayer in unsqueeze impl (#3366)
1 parent 1fb79d0 commit 426562f

File tree

2 files changed

+13
-65
lines changed

2 files changed

+13
-65
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import torch
99
from torch.fx.node import Argument, Node, Target
10-
1110
from torch_tensorrt.dynamo._settings import CompilationSettings
1211
from torch_tensorrt.dynamo._SourceIR import SourceIR
1312
from torch_tensorrt.dynamo.conversion import impl
@@ -650,16 +649,19 @@ def aten_ops_erf(
650649
@dynamo_tensorrt_converter(
651650
torch.ops.aten.unsqueeze.default, supports_dynamic_shapes=True
652651
)
652+
@enforce_tensor_types(
653+
{
654+
0: (TRTTensor,),
655+
}
656+
)
653657
def aten_ops_unsqueeze(
654658
ctx: ConversionContext,
655659
target: Target,
656660
args: Tuple[Argument, ...],
657661
kwargs: Dict[str, Argument],
658662
name: str,
659663
) -> Union[TRTTensor, Sequence[TRTTensor]]:
660-
return impl.unsqueeze.unsqueeze(
661-
ctx, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1]
662-
)
664+
return impl.unsqueeze.unsqueeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])
663665

664666

665667
@dynamo_tensorrt_converter(

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

+7-61
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,25 @@
1-
from typing import List, Optional, Sequence, cast
1+
from typing import List, Optional, Sequence
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
55
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
66
from torch_tensorrt.dynamo.conversion.converter_utils import (
7-
get_positive_dim,
87
get_trt_tensor,
8+
set_layer_name,
99
)
10-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
11-
from torch_tensorrt.fx.types import Shape, TRTTensor
10+
from torch_tensorrt.dynamo.types import TRTTensor
1211

1312

1413
def unsqueeze(
1514
ctx: ConversionContext,
1615
target: Target,
1716
source_ir: Optional[SourceIR],
1817
name: str,
19-
input_t: TRTTensor,
20-
dim: Shape,
18+
input: TRTTensor,
19+
dim: int,
2120
) -> TRTTensor:
22-
input_val = get_trt_tensor(ctx, input_t, f"{name}_input_t")
23-
if not isinstance(input_val, TRTTensor):
24-
raise RuntimeError(
25-
f"unsqueeze received input {input_val} that is not part "
26-
"of the TensorRT region!"
27-
)
28-
29-
dim = cast(int, dim)
30-
31-
input_shape_size = len(input_val.shape)
32-
dim = get_positive_dim(dim, input_shape_size + 1)
33-
34-
intermediate_dim = 0
35-
dynamic_shape_cnt = 0
36-
# if unsqueeze the last dimensions, we can directly append to the shape
37-
if dim == input_shape_size:
38-
intermediate_dim = dim
39-
else:
40-
# since maximum of one dimension is permitted to be specified as -1
41-
# find the intermediate_dim which has only 1 dynamic_shape_cnt
42-
# and then we can add a transpose after reshape if it is not the final shape we want
43-
for i, s in reversed(list(enumerate(input_val.shape))):
44-
if i >= dim:
45-
if s == -1:
46-
dynamic_shape_cnt += 1
47-
if dynamic_shape_cnt > 1:
48-
intermediate_dim = i + 1
49-
break
50-
if i == dim:
51-
intermediate_dim = i
52-
break
53-
# calculate the new_shape for the shuffle layer's reshape_dims
54-
new_shape = list(
55-
tuple(input_val.shape)[:intermediate_dim]
56-
+ (1,)
57-
+ tuple(input_val.shape)[intermediate_dim:]
58-
)
59-
for i, s in enumerate(new_shape):
60-
if i < intermediate_dim and s == -1:
61-
new_shape[i] = 0
62-
layer = ctx.net.add_shuffle(input_val)
63-
layer.reshape_dims = tuple(new_shape)
64-
# if the intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after reshape
65-
if intermediate_dim != dim:
66-
# calculate the second_transpose for the shuffle layer
67-
permutation = [*range(0, len(new_shape))]
68-
# for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5)
69-
# here intermediate_dim=3, dim=1, we need to move intermediate_dim before [dim: intermediate_dim)
70-
new_permutation = (
71-
tuple(permutation[:dim])
72-
+ (intermediate_dim,)
73-
+ tuple(permutation[dim:intermediate_dim])
74-
+ tuple(permutation[intermediate_dim + 1 :])
75-
)
76-
layer.second_transpose = new_permutation
21+
axes = get_trt_tensor(ctx, dim, f"{name}_axes")
22+
layer = ctx.net.add_unsqueeze(input, axes)
7723
set_layer_name(layer, target, name, source_ir)
7824
return layer.get_output(0)
7925

0 commit comments

Comments
 (0)