Commit 5b77d3f

add finalcast for cat case
1 parent 3d3a8ee commit 5b77d3f

1 file changed: +40 -35 lines

py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -11,7 +11,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     get_positive_dim,
-    get_trt_tensor,
     set_layer_name,
 )
 
@@ -60,13 +59,47 @@ def unify_and_concat_trt_tensors(
     if not has_dynamic and not force_trt_output:
         return trt_tensors  # all ints
 
+    final_dtype = None
+    if cast_dtype:
+        # Explicit cast requested
+        if isinstance(cast_dtype, _enums.dtype):
+            final_dtype = cast_dtype.to(trt.DataType)
+        elif isinstance(cast_dtype, np.dtype):
+            final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
+        else:
+            final_dtype = cast_dtype  # already trt.DataType
+    else:
+        # Automatic promotion
+        promoted_type = None
+        for t in trt_tensors:
+            if isinstance(t, TRTTensor):
+                if promoted_type is None:
+                    promoted_type = t.dtype
+                else:
+                    promoted_type = _enums.dtype._from(
+                        torch.promote_types(
+                            _enums.dtype._from(promoted_type).to(torch.dtype),
+                            _enums.dtype._from(t.dtype).to(torch.dtype),
+                        )
+                    ).to(trt.DataType)
+        final_dtype = promoted_type
+
     # promote remaining ints to TRT consts before concat
     for i, t in enumerate(trt_tensors):
         if isinstance(t, int):
             const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
             set_layer_name(const, target, f"{name}_static_{i}_const")
             trt_tensors[i] = const.get_output(0)
 
+    # final cast
+    if final_dtype is not None:
+        casted = []
+        for i, t in enumerate(trt_tensors):
+            if isinstance(t, TRTTensor):
+                t = cast_trt_tensor(ctx, t, final_dtype, f"{name}_cast_{i}")
+            casted.append(t)
+        trt_tensors = casted
+
     concat = ctx.net.add_concatenation(trt_tensors)
     concat.axis = concat_axis
     set_layer_name(concat, target, f"{name}_concat")
@@ -82,45 +115,17 @@ def cat(
     dim: int,
     cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    trt_inputs = []
-    for i, each_input in enumerate(input):
-        if not isinstance(each_input, TRTTensor):
-            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
-            if cast_dtype:
-                each_input = cast_trt_tensor(
-                    ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}"
-                )
-        trt_inputs.append(each_input)
-
-    if len(trt_inputs) > 1:
-        # Cast to promoted type for all inputs
-        promoted_type = trt_inputs[0].dtype
-        for each_input in trt_inputs[1:]:
-            promoted_type = _enums.dtype._from(
-                torch.promote_types(
-                    _enums.dtype._from(promoted_type).to(torch.dtype),
-                    _enums.dtype._from(each_input.dtype).to(torch.dtype),
-                )
-            )
-        trt_promoted_type = promoted_type.to(trt.DataType)
-
-        trt_casted_inputs = []
-        for i, each_input in enumerate(trt_inputs):
-            casted_input = cast_trt_tensor(
-                ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}"
-            )
-            trt_casted_inputs.append(casted_input)
-        trt_inputs = trt_casted_inputs
+    # int is only when cat called in other ops like pad
+    if not isinstance(input[0], int):
+        dim = get_positive_dim(dim, len(input[0].shape))
     else:
-        trt_promoted_type = None
-
-    dim = get_positive_dim(dim, len(trt_inputs[0].shape))
+        dim = 0
     return unify_and_concat_trt_tensors(
         ctx,
         target,
         name,
-        trt_inputs,
+        input,
         concat_axis=dim,
-        cast_dtype=trt_promoted_type,
+        cast_dtype=cast_dtype,
         force_trt_output=True,
     )
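
For reference, a minimal sketch of the automatic-promotion rule that the new branch in unify_and_concat_trt_tensors applies when no explicit cast_dtype is passed: it folds torch.promote_types over the dtypes of the TRTTensor inputs and then casts every input to the result right before add_concatenation. The promote_dtypes helper below is hypothetical (not part of this commit) and works on plain torch dtypes, eliding the _enums.dtype / trt.DataType round-trips.

import torch

# Hypothetical helper (not in the commit): fold torch.promote_types over a
# list of torch dtypes, mirroring the "Automatic promotion" branch above.
def promote_dtypes(dtypes):
    promoted = None
    for dt in dtypes:
        promoted = dt if promoted is None else torch.promote_types(promoted, dt)
    return promoted

# Example: mixed float16 / int32 / float32 inputs would be concatenated as float32.
print(promote_dtypes([torch.float16, torch.int32, torch.float32]))  # torch.float32

With this change the promoted (or explicitly requested) dtype is applied once, as a final cast_trt_tensor over each TRTTensor input inside unify_and_concat_trt_tensors, rather than being computed and applied inside cat itself.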
