1111from torch_tensorrt .dynamo .conversion .converter_utils import (
1212 cast_trt_tensor ,
1313 get_positive_dim ,
14- get_trt_tensor ,
1514 set_layer_name ,
1615)
1716
@@ -60,13 +59,47 @@ def unify_and_concat_trt_tensors(
6059 if not has_dynamic and not force_trt_output :
6160 return trt_tensors # all ints
6261
62+ final_dtype = None
63+ if cast_dtype :
64+ # Explicit cast requested
65+ if isinstance (cast_dtype , _enums .dtype ):
66+ final_dtype = cast_dtype .to (trt .DataType )
67+ elif isinstance (cast_dtype , np .dtype ):
68+ final_dtype = _enums .dtype ._from (cast_dtype ).to (trt .DataType )
69+ else :
70+ final_dtype = cast_dtype # already trt.DataType
71+ else :
72+ # Automatic promotion
73+ promoted_type = None
74+ for t in trt_tensors :
75+ if isinstance (t , TRTTensor ):
76+ if promoted_type is None :
77+ promoted_type = t .dtype
78+ else :
79+ promoted_type = _enums .dtype ._from (
80+ torch .promote_types (
81+ _enums .dtype ._from (promoted_type ).to (torch .dtype ),
82+ _enums .dtype ._from (t .dtype ).to (torch .dtype ),
83+ )
84+ ).to (trt .DataType )
85+ final_dtype = promoted_type
86+
6387 # promote remaining ints to TRT consts before concat
6488 for i , t in enumerate (trt_tensors ):
6589 if isinstance (t , int ):
6690 const = ctx .net .add_constant ((1 ,), np .array ([t ], dtype = np .int32 ))
6791 set_layer_name (const , target , f"{ name } _static_{ i } _const" )
6892 trt_tensors [i ] = const .get_output (0 )
6993
94+ # final cast
95+ if final_dtype is not None :
96+ casted = []
97+ for i , t in enumerate (trt_tensors ):
98+ if isinstance (t , TRTTensor ):
99+ t = cast_trt_tensor (ctx , t , final_dtype , f"{ name } _cast_{ i } " )
100+ casted .append (t )
101+ trt_tensors = casted
102+
70103 concat = ctx .net .add_concatenation (trt_tensors )
71104 concat .axis = concat_axis
72105 set_layer_name (concat , target , f"{ name } _concat" )
@@ -82,45 +115,17 @@ def cat(
82115 dim : int ,
83116 cast_dtype : Union [_enums .dtype , trt .DataType , np .dtype ] = None ,
84117) -> Union [TRTTensor , Sequence [TRTTensor ]]:
85- trt_inputs = []
86- for i , each_input in enumerate (input ):
87- if not isinstance (each_input , TRTTensor ):
88- each_input = get_trt_tensor (ctx , each_input , f"{ name } _tensor_{ i } " )
89- if cast_dtype :
90- each_input = cast_trt_tensor (
91- ctx , each_input , cast_dtype , f"{ name } _tensor_int32_cast_{ i } "
92- )
93- trt_inputs .append (each_input )
94-
95- if len (trt_inputs ) > 1 :
96- # Cast to promoted type for all inputs
97- promoted_type = trt_inputs [0 ].dtype
98- for each_input in trt_inputs [1 :]:
99- promoted_type = _enums .dtype ._from (
100- torch .promote_types (
101- _enums .dtype ._from (promoted_type ).to (torch .dtype ),
102- _enums .dtype ._from (each_input .dtype ).to (torch .dtype ),
103- )
104- )
105- trt_promoted_type = promoted_type .to (trt .DataType )
106-
107- trt_casted_inputs = []
108- for i , each_input in enumerate (trt_inputs ):
109- casted_input = cast_trt_tensor (
110- ctx , each_input , trt_promoted_type , f"{ name } _input_casted_{ i } "
111- )
112- trt_casted_inputs .append (casted_input )
113- trt_inputs = trt_casted_inputs
118+ # int is only when cat called in other ops like pad
119+ if not isinstance (input [0 ], int ):
120+ dim = get_positive_dim (dim , len (input [0 ].shape ))
114121 else :
115- trt_promoted_type = None
116-
117- dim = get_positive_dim (dim , len (trt_inputs [0 ].shape ))
122+ dim = 0
118123 return unify_and_concat_trt_tensors (
119124 ctx ,
120125 target ,
121126 name ,
122- trt_inputs ,
127+ input ,
123128 concat_axis = dim ,
124- cast_dtype = trt_promoted_type ,
129+ cast_dtype = cast_dtype ,
125130 force_trt_output = True ,
126131 )
0 commit comments