22
33import logging
44import operator
5- from typing import Callable , Dict , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
66
77import numpy as np
88import torch
@@ -217,18 +217,50 @@ def aten_ops_native_group_norm(
217217 )
218218
219219
220+ def parse_cat_args (
221+ args : Tuple [Argument , ...], kwargs : Dict [str , Any ]
222+ ) -> Tuple [List [Any ], int ]:
223+ """
224+ Process inputs for torch.ops.aten.cat.default.
225+
226+ Handles these valid patterns:
227+ 1. args = ((t1, t2, ...), dim)
228+ 2. args = ((t1, t2, ...),), kwargs = {dim: X} with optional dim in kwargs
229+
230+ Returns:
231+ (input_tensors, dim)
232+ input_tensors: tuple of tensor arguments
233+ dim: integer concatenation dimension (default 0)
234+ """
235+
236+ if len (args ) > 1 and isinstance (args [0 ], (list , tuple )):
237+ input_tensors = list (args [0 ])
238+ dim = args_bounds_check (args , 1 , 0 )
239+
240+ else :
241+ # If single arg is itself a tuple/list, unwrap it
242+ if len (args ) == 1 and isinstance (args [0 ], (list , tuple )):
243+ input_tensors = list (args [0 ])
244+ else :
245+ input_tensors = list (args )
246+
247+ dim = kwargs .get ("dim" , 0 )
248+
249+ return input_tensors , dim
250+
251+
220252def cat_validator (node : Node , settings : Optional [CompilationSettings ] = None ) -> bool :
221- # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
222- for each_input in node . args [ 0 ] :
253+ inputs , dim = parse_cat_args ( node . args , node . kwargs )
254+ for each_input in inputs :
223255 if isinstance (each_input , TRTTensor ) and any (s == 0 for s in each_input .shape ):
224256 return False
225257 return True
226258
227259
228260@dynamo_tensorrt_converter (
229261 torch .ops .aten .cat .default ,
230- capability_validator = cat_validator ,
231262 supports_dynamic_shapes = True ,
263+ capability_validator = cat_validator ,
232264)
233265def aten_ops_cat (
234266 ctx : ConversionContext ,
@@ -237,13 +269,14 @@ def aten_ops_cat(
237269 kwargs : Dict [str , Argument ],
238270 name : str ,
239271) -> Union [TRTTensor , Sequence [TRTTensor ]]:
272+ inputs , dim = parse_cat_args (args , kwargs )
240273 return impl .cat .cat (
241274 ctx ,
242275 target ,
243276 SourceIR .ATEN ,
244277 name ,
245- input = args [ 0 ] ,
246- dim = args_bounds_check ( args , 1 , 0 ) ,
278+ input = inputs ,
279+ dim = dim ,
247280 )
248281
249282
0 commit comments