Skip to content

Commit b7f2dee

Browse files
committed
Addressing the cat empty-tensor case. Fixes the gpt2 data distributed example
1 parent 1d038a1 commit b7f2dee

File tree

4 files changed

+75
-3
lines changed

4 files changed

+75
-3
lines changed

examples/distributed_inference/data_parallel_stable_diffusion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,5 @@
5353

5454
# Assume there are 2 processes (2 devices)
5555
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
56-
print("before \n")
5756
result = pipe(prompt).images[0]
58-
print("after ")
5957
result.save(f"result_{distributed_state.process_index}.png")

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,17 @@ def aten_ops_native_group_norm(
217217
)
218218

219219

220-
@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
220+
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
221+
# Return False when any TRTTensor input in node.args[0] has a zero-sized dimension; otherwise return True
222+
for each_input in node.args[0]:
223+
if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
224+
return False
225+
return True
226+
227+
228+
@dynamo_tensorrt_converter(
229+
torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator
230+
)
221231
def aten_ops_cat(
222232
ctx: ConversionContext,
223233
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Optional, Sequence, Union
23

34
import numpy as np
@@ -15,6 +16,8 @@
1516
set_layer_name,
1617
)
1718

19+
logger = logging.getLogger(__name__)
20+
1821

1922
def cat(
2023
ctx: ConversionContext,
@@ -27,6 +30,13 @@ def cat(
2730
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2831
trt_inputs = []
2932
for i, each_input in enumerate(input):
33+
if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
34+
logger.warning(
35+
f"Warning: empty tensor in cat input {i}, replacing with zeros"
36+
)
37+
# An ITensor input meeting the same (empty) condition leads to the error: "[RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed."
38+
# hence the validator
39+
continue
3040
if not isinstance(each_input, TRTTensor):
3141
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
3242
if cast_dtype:

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,60 @@ def forward(self, x, y, z):
2525
inputs,
2626
)
2727

28+
@parameterized.expand(
29+
[
30+
("pos", 0),
31+
("neg", -3),
32+
]
33+
)
34+
def test_cat_with_scalar_inputs(self, _, dim):
35+
# Ensure scalar tensor wrap works
36+
class Cat(nn.Module):
37+
def forward(self, x, y):
38+
# x and y are both tensors; y is a constant-filled tensor standing in for a scalar
39+
return torch.ops.aten.cat.default((x, y), dim)
40+
41+
x = torch.randn(1, 2, 3, device="cuda")
42+
y = torch.ones_like(x) * 5.0 # simulate scalar broadcast
43+
inputs = [x, y]
44+
self.run_test(Cat(), inputs)
45+
46+
@parameterized.expand(
47+
[
48+
("pos", 0),
49+
("neg", -3),
50+
]
51+
)
52+
def test_cat_with_empty_tensor(self, _, dim):
53+
# Handle empty tensor in concat
54+
class Cat(nn.Module):
55+
def forward(self, x):
56+
y = torch.empty(0, 2, 3, device="cuda")
57+
return torch.ops.aten.cat.default((x, y), dim)
58+
59+
inputs = [
60+
torch.randn(1, 2, 3, device="cuda"),
61+
]
62+
self.run_test(Cat(), inputs)
63+
64+
@parameterized.expand(
65+
[
66+
("pos", 2),
67+
("neg", -1),
68+
]
69+
)
70+
def test_cat_with_different_dtypes(self, _, dim):
71+
# check dtype promotion path in concat
72+
class Cat(nn.Module):
73+
def forward(self, x, y):
74+
return torch.ops.aten.cat.default((x, y), dim)
75+
76+
inputs = [
77+
torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
78+
torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
79+
]
80+
self.run_test(Cat(), inputs)
81+
2882
@parameterized.expand(
2983
[
3084
("pos", 1),

0 commit comments

Comments
 (0)