Skip to content

Commit 6e386c1

Browse files
committed
expanding cat converter to address CI error
1 parent 88659a1 commit 6e386c1

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import operator
5-
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
5+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
66

77
import numpy as np
88
import torch
@@ -217,18 +217,50 @@ def aten_ops_native_group_norm(
217217
)
218218

219219

220+
def parse_cat_args(
221+
args: Tuple[Argument, ...], kwargs: Dict[str, Any]
222+
) -> Tuple[List[Any], int]:
223+
"""
224+
Process inputs for torch.ops.aten.cat.default.
225+
226+
Handles these valid patterns:
227+
1. args = ((t1, t2, ...), dim)
228+
2. args = ((t1, t2, ...),), kwargs = {dim: X} with optional dim in kwargs
229+
230+
Returns:
231+
(input_tensors, dim)
232+
input_tensors: tuple of tensor arguments
233+
dim: integer concatenation dimension (default 0)
234+
"""
235+
236+
if len(args) > 1 and isinstance(args[0], (list, tuple)):
237+
input_tensors = list(args[0])
238+
dim = args_bounds_check(args, 1, 0)
239+
240+
else:
241+
# If single arg is itself a tuple/list, unwrap it
242+
if len(args) == 1 and isinstance(args[0], (list, tuple)):
243+
input_tensors = list(args[0])
244+
else:
245+
input_tensors = list(args)
246+
247+
dim = kwargs.get("dim", 0)
248+
249+
return input_tensors, dim
250+
251+
220252
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
221-
# empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
222-
for each_input in node.args[0]:
253+
inputs, dim = parse_cat_args(node.args, node.kwargs)
254+
for each_input in inputs:
223255
if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
224256
return False
225257
return True
226258

227259

228260
@dynamo_tensorrt_converter(
229261
torch.ops.aten.cat.default,
230-
capability_validator=cat_validator,
231262
supports_dynamic_shapes=True,
263+
capability_validator=cat_validator,
232264
)
233265
def aten_ops_cat(
234266
ctx: ConversionContext,
@@ -237,13 +269,14 @@ def aten_ops_cat(
237269
kwargs: Dict[str, Argument],
238270
name: str,
239271
) -> Union[TRTTensor, Sequence[TRTTensor]]:
272+
inputs, dim = parse_cat_args(args, kwargs)
240273
return impl.cat.cat(
241274
ctx,
242275
target,
243276
SourceIR.ATEN,
244277
name,
245-
input=args[0],
246-
dim=args_bounds_check(args, 1, 0),
278+
input=inputs,
279+
dim=dim,
247280
)
248281

249282

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ def forward(self, x, y, z):
2525
inputs,
2626
)
2727

28+
@parameterized.expand(
29+
[
30+
("pos", 1),
31+
("neg", -2),
32+
]
33+
)
34+
def test_cat_dim_in_kwargs(self, _, dim):
35+
class Cat(nn.Module):
36+
def forward(self, x, y, z):
37+
return torch.ops.aten.cat.default((x, y, z), dim=dim)
38+
39+
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
40+
self.run_test(
41+
Cat(),
42+
inputs,
43+
)
44+
2845
@parameterized.expand(
2946
[
3047
("pos", 0),

0 commit comments

Comments
 (0)