|
1 |
| -from typing import List, Optional, Sequence, cast |
| 1 | +from typing import List, Optional, Sequence |
2 | 2 |
|
3 | 3 | from torch.fx.node import Target
|
4 | 4 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
5 | 5 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
|
6 | 6 | from torch_tensorrt.dynamo.conversion.converter_utils import (
|
7 |
| - get_positive_dim, |
8 | 7 | get_trt_tensor,
|
| 8 | + set_layer_name, |
9 | 9 | )
|
10 |
| -from torch_tensorrt.fx.converters.converter_utils import set_layer_name |
11 |
| -from torch_tensorrt.fx.types import Shape, TRTTensor |
| 10 | +from torch_tensorrt.dynamo.types import TRTTensor |
12 | 11 |
|
13 | 12 |
|
def unsqueeze(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
) -> TRTTensor:
    """Insert a size-1 dimension into ``input`` at position ``dim``.

    Delegates to TensorRT's native unsqueeze layer
    (``INetworkDefinition.add_unsqueeze``), which expects the axes to
    insert as a tensor input rather than a plain Python integer.

    Args:
        ctx: Conversion context holding the TensorRT network under construction.
        target: FX node target being converted (used for layer naming).
        source_ir: Source IR of the op, recorded in the layer name metadata.
        name: Base name for the tensors/layers created by this converter.
        input: TensorRT tensor to unsqueeze.
        dim: Index at which the new unit dimension is inserted.

    Returns:
        The output tensor of the unsqueeze layer.
    """
    # add_unsqueeze consumes its axes as an ITensor, so lift the Python
    # int `dim` into a constant tensor first.
    axes_tensor = get_trt_tensor(ctx, dim, f"{name}_axes")
    unsqueeze_layer = ctx.net.add_unsqueeze(input, axes_tensor)
    set_layer_name(unsqueeze_layer, target, name, source_ir)
    return unsqueeze_layer.get_output(0)
|
79 | 25 |
|
|
0 commit comments