
Commit 3b60296: use IOutputAllocator

Parent: 43831dc

File tree: 4 files changed, +142 -50 lines

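At a high level, this commit teaches the Python runtime to handle data-dependent output shapes (such as aten.nonzero) by letting TensorRT drive output allocation through IOutputAllocator instead of pre-allocating outputs from statically known shapes. A rough end-to-end sketch of the intended usage, assuming a CUDA machine with TensorRT and a build containing this change; the model and input below are made up for illustration:

import torch
import torch_tensorrt

class NonZeroModel(torch.nn.Module):
    def forward(self, x):
        # data-dependent output shape: [N, 3] where N is unknown until runtime
        return torch.nonzero(x)

model = NonZeroModel().cuda().eval()
x = torch.randint(0, 3, (2, 3, 4), device="cuda")

# min_block_size=1 so the single-op graph is still offloaded to TensorRT
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x], min_block_size=1)
print(trt_model(x).shape)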

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+17 -1)

@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
-
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -3580,3 +3579,20 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+def aten_ops_nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
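For reference, the operator this converter registers has a data-dependent output: aten.nonzero returns an [N, D] int64 tensor of indices, one row per non-zero element. A minimal eager-mode illustration in plain PyTorch, independent of the converter machinery above:

import torch

x = torch.tensor([[0, 1, 0],
                  [2, 0, 3]])

idx = torch.ops.aten.nonzero.default(x)
print(idx)        # tensor([[0, 1],
                  #         [1, 0],
                  #         [1, 2]])
print(idx.shape)  # torch.Size([3, 2]) -> N=3 non-zeros, D=2 input dims
print(idx.dtype)  # torch.int64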

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py (+15)

@@ -625,3 +625,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
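About the shuffle: TensorRT's non-zero layer produces indices laid out as [D, N] (one row per input dimension), whereas aten.nonzero returns [N, D] (one row per non-zero element), which is why first_transpose = trt.Permutation([1, 0]) is applied. A plain-PyTorch sketch of the same layout conversion, with torch.where standing in for the TensorRT layer output (no TensorRT calls here):

import torch

x = torch.tensor([[0, 1, 0],
                  [2, 0, 3]])

d_by_n = torch.stack(torch.where(x))  # [D, N], like the non-zero layer output
n_by_d = d_by_n.transpose(1, 0)       # [N, D], the layout aten.nonzero expects
assert torch.equal(n_by_d, torch.nonzero(x))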

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (+78 -49)

@@ -12,7 +12,6 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
@@ -23,6 +22,42 @@
 logger = logging.getLogger(__name__)
 
 
+class OutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
+    def __init__(self) -> None:
+        trt.IOutputAllocator.__init__(self)
+        self.buffers: Dict[str, torch.Tensor] = {}
+        self.shapes: Dict[str, Tuple[int, ...]] = {}
+
+    def reallocate_output(
+        self, tensor_name: str, memory: int, size: int, alignment: int
+    ) -> Any:
+        shape = (size,)
+        if tensor_name not in self.buffers:
+            self.buffers[tensor_name] = torch.empty(
+                shape, dtype=torch.float, device=torch.cuda.current_device()
+            )
+        else:
+            self.buffers[tensor_name] = self.resize_or_reallocate(
+                self.buffers[tensor_name], shape
+            )
+        return self.data_ptr(self.buffers[tensor_name])
+
+    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
+        self.shapes[tensor_name] = tuple(shape)
+
+    def resize_or_reallocate(
+        self, buffer: torch.Tensor, shape: Tuple[int, ...]
+    ) -> torch.Tensor:
+        if buffer.shape != shape:
+            buffer = torch.empty(
+                shape, dtype=torch.float, device=torch.cuda.current_device()
+            )
+        return buffer
+
+    def data_ptr(self, buffer: torch.Tensor) -> Any:
+        return buffer.data_ptr()
+
+
 class TorchTRTRuntimeStates:
     def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -147,6 +182,8 @@ def __init__(
         self.output_names = (
            output_binding_names if output_binding_names is not None else []
        )
+        self.output_allocator = OutputAllocator()
+
        self.initialized = False
        self.target_device_id = (
            settings.device.gpu_id
@@ -342,19 +379,6 @@ def setup_input_tensors(
                 input_name, contiguous_inputs[i].data_ptr()
             )
 
-    def create_output_tensors(self) -> List[torch.Tensor]:
-        # create output tensors
-        outputs: List[torch.Tensor] = []
-
-        for o, _ in enumerate(self.output_names):
-            output = torch.empty(
-                size=self.output_shapes[o],
-                dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
-            )
-            outputs.append(output)
-        return outputs
-
     def set_pre_allocated_outputs(self, enable: bool) -> None:
         self.use_pre_allocated_outputs = enable
 
@@ -445,47 +469,18 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     This could happen if the input tensor addresses/shapes haven't been configured correctly"
                 )
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:ProcessOutputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
-                if can_use_pre_allocated_outputs:
-                    outputs = self.pre_allocated_outputs
-                else:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
-                    if DYNAMIC_DIM in self.output_shapes:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                        )
-                    outputs = self.create_output_tensors()
-
-                for o, output_name in enumerate(self.output_names):
-
-                    if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
-
-                    if cudagraphs_enabled:
-                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[o].data_ptr()
-                        )
-                    else:
-                        self.context.set_tensor_address(
-                            output_name, outputs[o].data_ptr()
-                        )
-
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:TensorRTRuntime"
                 )
                 if self.profiling_enabled
                 else nullcontext()
             ):
+                for output_name in self.output_names:
+                    self.context.set_output_allocator(
+                        output_name, self.output_allocator
+                    )
+
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
@@ -526,8 +521,42 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
                 self._caller_stream.wait_stream(self._engine_stream)
 
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                if can_use_pre_allocated_outputs:
+                    outputs = self.pre_allocated_outputs
+                else:
+                    outputs = []
+                    for o, output_name in enumerate(self.output_names):
+                        shape = self.output_allocator.shapes.get(output_name, None)
+                        self.output_shapes[o] = shape
+                        dtype = self.output_dtypes[o]
+                        output = self.output_allocator.buffers.get(output_name, None)
+                        prod = int(torch.prod(torch.tensor(shape)))
+                        output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
+                        outputs.append(output)
+
+                for o, output_name in enumerate(self.output_names):
+
+                    if need_cudagraphs_record:
+                        self._output_buffers[o] = outputs[o].clone()
+
+                    if cudagraphs_enabled:
+                        self.context.set_tensor_address(
+                            output_name, self._output_buffers[o].data_ptr()
+                        )
+                    else:
+                        self.context.set_tensor_address(
+                            output_name, outputs[o].data_ptr()
+                        )
+
             if self.use_pre_allocated_outputs:
-                self.pre_allocated_outputs = self.create_output_tensors()
+                self.pre_allocated_outputs = outputs
 
             if cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
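The behavioral change in this file: rather than pre-computing output shapes (and raising on DYNAMIC_DIM, which made data-dependent outputs such as nonzero unsupported), the runtime now registers an output allocator per output tensor and lets TensorRT call back with the actual byte size (reallocate_output) and the final shape (notify_shape) during execution; forward() then reinterprets the flat staging buffer into the reported shape. A rough simulation of that callback sequence in plain PyTorch (no engine is built; the tensor name "out0" and the sizes are made up for illustration):

import torch

buffers: dict = {}
shapes: dict = {}

def reallocate_output(tensor_name, memory, size, alignment):
    # TensorRT asks for a buffer of `size` bytes once the real output size is known;
    # a flat byte tensor stands in for the device allocation here.
    buffers[tensor_name] = torch.empty((size,), dtype=torch.uint8)
    return buffers[tensor_name].data_ptr()

def notify_shape(tensor_name, shape):
    # TensorRT reports the final output shape after the kernel has run.
    shapes[tensor_name] = tuple(shape)

# What the runtime would observe for a data-dependent output such as nonzero:
reallocate_output("out0", 0, 3 * 2 * 8, 256)  # 3x2 int64 indices -> 48 bytes
notify_shape("out0", (3, 2))

# forward() then reinterprets the flat buffer into the reported shape/dtype:
shape, dtype = shapes["out0"], torch.int64
numel = int(torch.prod(torch.tensor(shape)))
out = buffers["out0"].view(dtype)[:numel].reshape(shape)
print(out.shape)  # torch.Size([3, 2])

In the module itself the same reinterpretation is done with output.reshape(-1).view(dtype)[:prod].reshape(shape) on the float staging buffer held by OutputAllocator.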
New test file (+32)

@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestNonZeroConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.int),
+            ((1, 20), torch.int32),
+            ((2, 3), torch.int64),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_non_zero_float(self, input_shape, dtype):
+        class NonZero(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.nonzero.default(input)
+
+        inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
+        self.run_test(
+            NonZero(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
