
Commit 2ee9299

support dds and nonzero op in _PythonTorchTensorRTModule
1 parent 43831dc commit 2ee9299

6 files changed: +90 −9 lines


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+8)

@@ -62,6 +62,7 @@ class TRTInterpreterResult(NamedTuple):
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
+    output_shapes: Optional[Sequence[Tuple[int]]]


 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]

@@ -132,6 +133,7 @@ def __init__(
         # Mapping of constants to shapes and dtypes
         self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
         self.weight_name_map: Optional[Dict[str, Any]] = None
+        self.output_shapes: Sequence[Tuple[int]] = []

         # Engine cache for storing and reusing TRT engines
         self.engine_cache = engine_cache

@@ -651,6 +653,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
+                self.output_shapes if self.output_shapes else None,
             )
         return None

@@ -731,11 +734,16 @@ def run(
         engine_bytes.write(serialized_engine)
         engine_str = engine_bytes.getvalue()

+        for node in self.module.graph.nodes:
+            if node.op == "output":
+                self.output_shapes.append(tuple(node.meta["tensor_meta"].shape))
+
         return TRTInterpreterResult(
             engine_str,
             self._input_names,
             self._output_names,
             self.weight_name_map,
+            self.output_shapes if self.output_shapes else None,
         )

     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
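Note: the new loop in run() records each graph output's static shape from FX tensor metadata. Below is a minimal sketch (not part of this commit) of where that metadata comes from, assuming shape propagation has been run on the graph, e.g. via torch.fx.passes.shape_prop.ShapeProp; the Toy module is a made-up example.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


gm = symbolic_trace(Toy())
ShapeProp(gm).propagate(torch.randn(2, 3))

for node in gm.graph.nodes:
    if node.op == "output":
        # Populated by shape propagation; prints torch.Size([2, 3])
        print(node.meta["tensor_meta"].shape)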

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+17)

@@ -3580,3 +3580,20 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+def aten_ops_nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
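For context, aten.nonzero is a data-dependent-shape (DDS) op: the number of rows in its result depends on the input values, not just the input shape, which is why this commit also threads explicit output shapes through the interpreter and runtime. A quick illustration (not from the commit):

import torch

x = torch.tensor([[0, 1], [2, 0]])
idx = torch.ops.aten.nonzero.default(x)
print(idx)        # tensor([[0, 1], [1, 0]])
print(idx.shape)  # torch.Size([2, 2]) -> (num_nonzero, input_rank), depends on the data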

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py (+15)

@@ -625,3 +625,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
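The shuffle layer is there because TensorRT's non-zero layer reports indices in a [input_rank, num_nonzero] layout, while aten.nonzero returns [num_nonzero, input_rank]; the [1, 0] permutation converts between the two. A small PyTorch-only layout check (illustration only, not from the commit):

import torch

x = torch.tensor([0.0, 3.0, 0.0, 5.0])
count_rank = torch.nonzero(x)  # shape (2, 1): one row per non-zero element
rank_count = count_rank.t()    # shape (1, 2): the layout the TRT layer emits before the shuffle
assert torch.equal(rank_count.t(), count_rank)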

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (+15 −9)

@@ -88,6 +88,7 @@ def __init__(
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
+        output_shapes: Optional[Sequence[Tuple[int]]] = None,
         *,
         name: str = "",
         settings: CompilationSettings = CompilationSettings(),

@@ -100,6 +101,7 @@ def __init__(
             serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray
             input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
             output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
+            output_shapes (Sequence[Tuple]): List of output shapes for the engine. In some cases (e.g. the NonZero op) the output shapes are dynamic and depend on the input data, so they must be provided explicitly

         Keyword Arguments:
             name (str): Name for module

@@ -147,6 +149,7 @@ def __init__(
         self.output_names = (
             output_binding_names if output_binding_names is not None else []
         )
+        self.output_shapes = output_shapes
         self.initialized = False
         self.target_device_id = (
             settings.device.gpu_id

@@ -233,10 +236,12 @@ def setup_engine(self) -> None:
             dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
-        self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
-            for output_name in self.output_names
-        ]
+
+        if self.output_shapes is None:
+            self.output_shapes = [
+                self.engine.get_tensor_shape(output_name)
+                for output_name in self.output_names
+            ]

         if torch_tensorrt.runtime.get_cudagraphs_mode():
             self.cudagraph = torch.cuda.CUDAGraph()

@@ -345,7 +350,7 @@ def setup_input_tensors(
     def create_output_tensors(self) -> List[torch.Tensor]:
         # create output tensors
         outputs: List[torch.Tensor] = []
-
+        assert self.output_shapes is not None
         for o, _ in enumerate(self.output_names):
             output = torch.empty(
                 size=self.output_shapes[o],

@@ -455,10 +460,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if can_use_pre_allocated_outputs:
                 outputs = self.pre_allocated_outputs
             else:
-                self.output_shapes = [
-                    tuple(self.context.get_tensor_shape(output_name))
-                    for output_name in self.output_names
-                ]
+                if self.output_shapes is None:
+                    self.output_shapes = [
+                        tuple(self.context.get_tensor_shape(output_name))
+                        for output_name in self.output_names
+                    ]
                 if DYNAMIC_DIM in self.output_shapes:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."

tests/py/dynamo/conversion/harness.py (+2)

@@ -206,6 +206,7 @@ def run_test(
             serialized_engine=interpreter_result.serialized_engine,
             input_binding_names=list(interpreter_result.input_names),
             output_binding_names=list(interpreter_result.output_names),
+            output_shapes=list(interpreter_result.output_shapes),
             name="test_engine",
         )
         mod = mod.cuda()

@@ -288,6 +289,7 @@ def run_test_custom_compare_results(
             serialized_engine=interpreter_result.serialized_engine,
             input_binding_names=list(interpreter_result.input_names),
             output_binding_names=list(interpreter_result.output_names),
+            output_shapes=list(interpreter_result.output_shapes),
             name="test_engine",
         )
         res_trt = trt_mod(*cuda_inputs).cpu()
New test file (+33)

@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestAtanConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.int),
+            ((1, 20), torch.int32),
+            ((5, 3), torch.int64),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_atan_float(self, input_shape, dtype):
+        class atan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.nonzero.default(input)
+
+        inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
+        self.run_test(
+            atan(),
+            inputs,
+            propagate_shapes=True,  # propagate_shapes=True is required so output shapes can be recorded
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
