12 | 12 | from torch_tensorrt._Device import Device
13 | 13 | from torch_tensorrt._enums import Platform, dtype
14 | 14 | from torch_tensorrt.dynamo._settings import CompilationSettings
15 | | -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
16 | 15 | from torch_tensorrt.logging import TRT_LOGGER
17 | 16 | from torch_tensorrt.runtime._utils import (
18 | 17 |     _is_switch_required,

23 | 22 | logger = logging.getLogger(__name__)
24 | 23 |
25 | 24 |
| 25 | +class OutputAllocator(trt.IOutputAllocator): # type: ignore[misc]
| 26 | +    def __init__(self) -> None:
| 27 | +        trt.IOutputAllocator.__init__(self)
| 28 | +        self.buffers: Dict[str, torch.Tensor] = {}
| 29 | +        self.shapes: Dict[str, Tuple[int, ...]] = {}
| 30 | +
| 31 | +    def reallocate_output(
| 32 | +        self, tensor_name: str, memory: int, size: int, alignment: int
| 33 | +    ) -> Any:
| 34 | +        shape = (size,)
| 35 | +        if tensor_name not in self.buffers:
| 36 | +            self.buffers[tensor_name] = torch.empty(
| 37 | +                shape, dtype=torch.float, device=torch.cuda.current_device()
| 38 | +            )
| 39 | +        else:
| 40 | +            self.buffers[tensor_name] = self.resize_or_reallocate(
| 41 | +                self.buffers[tensor_name], shape
| 42 | +            )
| 43 | +        return self.data_ptr(self.buffers[tensor_name])
| 44 | +
| 45 | +    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
| 46 | +        self.shapes[tensor_name] = tuple(shape)
| 47 | +
| 48 | +    def resize_or_reallocate(
| 49 | +        self, buffer: torch.Tensor, shape: Tuple[int, ...]
| 50 | +    ) -> torch.Tensor:
| 51 | +        if buffer.shape != shape:
| 52 | +            buffer = torch.empty(
| 53 | +                shape, dtype=torch.float, device=torch.cuda.current_device()
| 54 | +            )
| 55 | +        return buffer
| 56 | +
| 57 | +    def data_ptr(self, buffer: torch.Tensor) -> Any:
| 58 | +        return buffer.data_ptr()
| 59 | +
| 60 | +
26 | 61 | class TorchTRTRuntimeStates:
27 | 62 |     def __init__(self, new_cudagraphs: bool):
28 | 63 |         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
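The class above implements TensorRT's `IOutputAllocator` callback interface: the runtime calls `reallocate_output` with the required byte size once an output's size is known, and `notify_shape` with its final shape after execution. A minimal standalone sketch of how such an allocator is attached to an execution context (not part of this diff; the engine path is a placeholder):

```python
# Minimal sketch: wiring an OutputAllocator into a TensorRT execution context.
# "model.engine" is a placeholder path; OutputAllocator is the class added above.
import tensorrt as trt
import torch

trt_logger = trt.Logger(trt.Logger.WARNING)
with open("model.engine", "rb") as f:
    engine = trt.Runtime(trt_logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

allocator = OutputAllocator()
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
        # TensorRT calls allocator.reallocate_output(...) with the required byte
        # size, and allocator.notify_shape(...) once the final shape is known.
        context.set_output_allocator(name, allocator)

# After binding inputs, execute and read back the reported shape, e.g.:
#   context.set_tensor_address(input_name, input_tensor.data_ptr())
#   context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
#   torch.cuda.synchronize()
#   shape = allocator.shapes[output_name]
```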
@@ -147,6 +182,8 @@ def __init__(
147 | 182 |         self.output_names = (
148 | 183 |             output_binding_names if output_binding_names is not None else []
149 | 184 |         )
| 185 | +        self.output_allocator = OutputAllocator()
| 186 | +
150 | 187 |         self.initialized = False
151 | 188 |         self.target_device_id = (
152 | 189 |             settings.device.gpu_id
@@ -342,19 +379,6 @@ def setup_input_tensors(
342 | 379 |                     input_name, contiguous_inputs[i].data_ptr()
343 | 380 |                 )
344 | 381 |
345 | | -    def create_output_tensors(self) -> List[torch.Tensor]:
346 | | -        # create output tensors
347 | | -        outputs: List[torch.Tensor] = []
348 | | -
349 | | -        for o, _ in enumerate(self.output_names):
350 | | -            output = torch.empty(
351 | | -                size=self.output_shapes[o],
352 | | -                dtype=self.output_dtypes[o],
353 | | -                device=torch.cuda.current_device(),
354 | | -            )
355 | | -            outputs.append(output)
356 | | -        return outputs
357 | | -
358 | 382 |     def set_pre_allocated_outputs(self, enable: bool) -> None:
359 | 383 |         self.use_pre_allocated_outputs = enable
360 | 384 |
@@ -445,47 +469,18 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
445 | 469 |                                 This could happen if the input tensor addresses/shapes haven't been configured correctly"
446 | 470 |                         )
447 | 471 |
448 | | -            with (
449 | | -                torch.autograd.profiler.record_function(
450 | | -                    "PythonTorchTensorRTModule:ProcessOutputs"
451 | | -                )
452 | | -                if self.profiling_enabled
453 | | -                else nullcontext()
454 | | -            ):
455 | | -                if can_use_pre_allocated_outputs:
456 | | -                    outputs = self.pre_allocated_outputs
457 | | -                else:
458 | | -                    self.output_shapes = [
459 | | -                        tuple(self.context.get_tensor_shape(output_name))
460 | | -                        for output_name in self.output_names
461 | | -                    ]
462 | | -                    if DYNAMIC_DIM in self.output_shapes:
463 | | -                        raise ValueError(
464 | | -                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
465 | | -                        )
466 | | -                    outputs = self.create_output_tensors()
467 | | -
468 | | -                for o, output_name in enumerate(self.output_names):
469 | | -
470 | | -                    if need_cudagraphs_record:
471 | | -                        self._output_buffers[o] = outputs[o].clone()
472 | | -
473 | | -                    if cudagraphs_enabled:
474 | | -                        self.context.set_tensor_address(
475 | | -                            output_name, self._output_buffers[o].data_ptr()
476 | | -                        )
477 | | -                    else:
478 | | -                        self.context.set_tensor_address(
479 | | -                            output_name, outputs[o].data_ptr()
480 | | -                        )
481 | | -
482 | 472 |             with (
483 | 473 |                 torch.autograd.profiler.record_function(
484 | 474 |                     "PythonTorchTensorRTModule:TensorRTRuntime"
485 | 475 |                 )
486 | 476 |                 if self.profiling_enabled
487 | 477 |                 else nullcontext()
488 | 478 |             ):
| 479 | +                for output_name in self.output_names:
| 480 | +                    self.context.set_output_allocator(
| 481 | +                        output_name, self.output_allocator
| 482 | +                    )
| 483 | +
489 | 484 |                 self._caller_stream = torch.cuda.current_stream()
490 | 485 |                 if (
491 | 486 |                     self._engine_stream == torch.cuda.default_stream()
@@ -526,8 +521,42 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
526 | 521 |
527 | 522 |                 self._caller_stream.wait_stream(self._engine_stream)
528 | 523 |
| 524 | +            with (
| 525 | +                torch.autograd.profiler.record_function(
| 526 | +                    "PythonTorchTensorRTModule:ProcessOutputs"
| 527 | +                )
| 528 | +                if self.profiling_enabled
| 529 | +                else nullcontext()
| 530 | +            ):
| 531 | +                if can_use_pre_allocated_outputs:
| 532 | +                    outputs = self.pre_allocated_outputs
| 533 | +                else:
| 534 | +                    outputs = []
| 535 | +                    for o, output_name in enumerate(self.output_names):
| 536 | +                        shape = self.output_allocator.shapes.get(output_name, None)
| 537 | +                        self.output_shapes[o] = shape
| 538 | +                        dtype = self.output_dtypes[o]
| 539 | +                        output = self.output_allocator.buffers.get(output_name, None)
| 540 | +                        prod = int(torch.prod(torch.tensor(shape)))
| 541 | +                        output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
| 542 | +                        outputs.append(output)
| 543 | +
| 544 | +                for o, output_name in enumerate(self.output_names):
| 545 | +
| 546 | +                    if need_cudagraphs_record:
| 547 | +                        self._output_buffers[o] = outputs[o].clone()
| 548 | +
| 549 | +                    if cudagraphs_enabled:
| 550 | +                        self.context.set_tensor_address(
| 551 | +                            output_name, self._output_buffers[o].data_ptr()
| 552 | +                        )
| 553 | +                    else:
| 554 | +                        self.context.set_tensor_address(
| 555 | +                            output_name, outputs[o].data_ptr()
| 556 | +                        )
| 557 | +
529 | 558 |             if self.use_pre_allocated_outputs:
530 | | -                self.pre_allocated_outputs = self.create_output_tensors()
| 559 | +                self.pre_allocated_outputs = outputs
531 | 560 |
532 | 561 |             if cudagraphs_enabled:
533 | 562 |                 for idx, o in enumerate(outputs):
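The reshape/view chain in the new ProcessOutputs block recovers each logical output from the allocator's flat float32 staging buffer: the bytes are reinterpreted as the engine's output dtype, truncated to the element count implied by the shape reported through `notify_shape`, and reshaped. A standalone illustration with made-up example values (not part of this diff; it runs on CPU for simplicity, whereas the module uses CUDA buffers):

```python
# Standalone illustration of the buffer-reinterpretation step in the new
# ProcessOutputs block. Example values only.
import torch

shape = (2, 3)            # shape reported to the allocator via notify_shape
dtype = torch.float16     # the engine's actual output dtype
prod = int(torch.prod(torch.tensor(shape)))  # number of output elements

# Staging buffer as allocated in reallocate_output: a flat float32 tensor whose
# element count equals the byte size requested by TensorRT (over-sized here).
staging = torch.empty((64,), dtype=torch.float)

# view(dtype) reinterprets the underlying bytes without copying; slicing to
# `prod` elements and reshaping yields the logical output tensor.
output = staging.reshape(-1).view(dtype)[:prod].reshape(shape)
assert output.shape == shape and output.dtype == dtype
```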