diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 52a9b47c12..135d940c2d 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -237,6 +237,14 @@ TRTEngine::TRTEngine( out_binding_names[pyt_idx] = binding_name; } num_io = std::make_pair(inputs_size, outputs); + + this->current_device_id = at::cuda::current_device(); + this->stream = c10::cuda::getCurrentCUDAStream(this->current_device_id); + this->io_size = this->cuda_engine->getNbIOTensors(); + for (int64_t i = 0; i < this->in_binding_names.size(); i++) { + this->isShapeInferenceIO[this->in_binding_names[i]] = + this->cuda_engine->isShapeInferenceIO(this->in_binding_names[i].c_str()); + } } #ifndef NDEBUG @@ -281,6 +289,14 @@ void TRTEngine::enable_profiling() { exec_ctx->setProfiler(trt_engine_profiler.get()); } +void TRTEngine::set_unowned_output_tensor(bool enable) { + this->unowned_output_tensor = enable; +} + +bool TRTEngine::is_unowned_output_tensor() { + return this->unowned_output_tensor; +} + void TRTEngine::set_profile_format(std::string format) { if (format == "trex") { this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 15d723ce4e..c44869c7c8 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -103,6 +103,9 @@ struct TRTEngine : torch::CustomClassHolder { std::shared_ptr cuda_engine; std::shared_ptr exec_ctx; std::pair num_io; + uint64_t io_size; + std::map isShapeInferenceIO; + bool unowned_output_tensor = false; std::string name; RTDevice device_info; @@ -159,6 +162,8 @@ struct TRTEngine : torch::CustomClassHolder { int64_t get_automatic_device_memory_budget(); std::vector infer_outputs(std::vector> input_shapes); void set_pre_allocated_outputs(bool enable); + void set_unowned_output_tensor(bool enable); + bool is_unowned_output_tensor(); TorchTRTRuntimeStates runtime_states; friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); static const char BINDING_DELIM = '%'; @@ -169,13 +174,14 @@ struct TRTEngine : torch::CustomClassHolder { // CUDAGraph-Related Functionality at::cuda::CUDAGraph cudagraph = {}; - at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream(); - at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream = c10::cuda::getDefaultCUDAStream(); + int64_t current_device_id = at::cuda::current_device(); std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key = "None"; bool use_pre_allocated_outputs = false; std::vector pre_allocated_outputs; + std::vector allocated_outputs; // Output Allocator-Related Functionality bool requires_output_allocator = false; // engine requires output allocator diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..22d3c6340e 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -96,7 +96,8 @@ void setup_input_tensors( std::vector inputs, c10::intrusive_ptr compiled_engine, bool cudagraphs_enabled, - bool need_cudagraphs_record) { + bool need_cudagraphs_record, + bool shape_changed) { // this is a buffer to store shape tensor input addresses throughout the runtime scope std::list> inputShapeTensorValues; std::list formatted_inputs(compiled_engine->num_io.first); @@ -117,7 +118,7 @@ void setup_input_tensors( auto shape = core::util::toVec(dims); LOG_DEBUG("Input Name: " << name << " Shape: " << dims); - if 
(compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { + if (compiled_engine->isShapeInferenceIO[name]) { // Shape tensor inputs are casted to int64 explicitly. // Refer to // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 @@ -145,10 +146,10 @@ void setup_input_tensors( // Create a new persistent input buffer compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone()); } - - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); - + if (shape_changed) { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + } if (cudagraphs_enabled) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); @@ -217,7 +218,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->cudagraph.reset(); } - std::vector outputs(compiled_engine->num_io.second); + std::vector outputs; // Intialize inputs and outputs to be available throughout the succeeding scopes { // Input Setup @@ -226,10 +227,9 @@ std::vector execute_engine(std::vector inputs, c10::intr input_profiler_guard = std::make_unique(compiled_engine->input_profile_path); } - - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); + setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed); // Check if input shapes can be inferred. - int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + int32_t const io_size{compiled_engine->io_size}; std::vector names(io_size); int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); TORCHTRT_CHECK( @@ -240,6 +240,7 @@ std::vector execute_engine(std::vector inputs, c10::intr } { // Output Setup + bool new_outputs = false; std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = @@ -248,26 +249,32 @@ std::vector execute_engine(std::vector inputs, c10::intr if (can_use_pre_allocated_outputs) { outputs = compiled_engine->pre_allocated_outputs; } else { - outputs = create_output_tensors(compiled_engine); + if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->unowned_output_tensor or shape_changed) { + compiled_engine->allocated_outputs = create_output_tensors(compiled_engine); + new_outputs = true; + } + outputs = compiled_engine->allocated_outputs; } - for (auto output_indices : compiled_engine->out_binding_map) { - auto pyt_idx = output_indices.second; - std::string name = compiled_engine->out_binding_names[pyt_idx]; - if (need_cudagraphs_record) { - // If we are recording the cuda graph then we need to update the persistent output buffer - compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); - } + if (new_outputs) { + for (auto output_indices : compiled_engine->out_binding_map) { + auto pyt_idx = output_indices.second; + std::string name = compiled_engine->out_binding_names[pyt_idx]; + if (need_cudagraphs_record) { + // If we are recording the cuda graph then we need to update the persistent output buffer + compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); + } - if (cudagraphs_enabled) { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress( - name.c_str(), 
compiled_engine->output_buffers[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); - } else { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); + if (cudagraphs_enabled) { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress( + name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } else { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } } } } @@ -275,18 +282,12 @@ std::vector execute_engine(std::vector inputs, c10::intr auto current_device_id = -1; if (inputs.size() > 0) { current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else if (outputs.size() > 0) { - current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + if (current_device_id != compiled_engine->current_device_id) { + compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id); + } } { // Engine Execution (execute on engine stream) - c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); std::unique_ptr enqueue_profiler_guard; if (compiled_engine->profile_execution) { @@ -294,18 +295,13 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); - if (!cudagraphs_enabled) { // Direct execution uses the caller buffers directly - compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream); } else { if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; + c10::cuda::CUDAStream recording_stream = compiled_engine->stream; compiled_engine->cudagraph.capture_begin(); compiled_engine->exec_ctx->enqueueV3(recording_stream); compiled_engine->cudagraph.capture_end(); @@ -325,11 +321,6 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); } - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); - if (cudagraphs_enabled) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { @@ -354,7 +345,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, false, false); + setup_input_tensors(inputs, compiled_engine, false, 
false, true); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -378,18 +369,12 @@ std::vector execute_engine(std::vector inputs, c10::intr auto current_device_id = -1; if (inputs.size() > 0) { current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else { - current_device_id = at::cuda::current_device(); - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + if (current_device_id != compiled_engine->current_device_id) { + compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id); + } } { // Engine Execution (execute on engine stream) - c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); std::unique_ptr enqueue_profiler_guard; if (compiled_engine->profile_execution) { @@ -397,21 +382,11 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); - // Direct execution uses the caller buffers directly - compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream); } // End engine exeuction (resets to caller stream) - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); - std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 173ff8c35f..36d85481ff 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -90,6 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) .def("reset_captured_graph", &TRTEngine::reset_captured_graph) + .def("set_unowned_output_tensor", &TRTEngine::set_unowned_output_tensor) + .def("is_unowned_output_tensor", &TRTEngine::is_unowned_output_tensor) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0dc4654db0..145307004d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -42,6 +42,7 @@ ) from torch_tensorrt.dynamo.utils import ( deallocate_module, + get_cpu_memory_usage, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -675,7 +676,7 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, } - + logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) logger.info("Compilation 
Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -689,14 +690,17 @@ def compile( # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) # Move the weights in the state_dict to CPU if offload_module_to_cpu: + deallocate_module(gm, delete_module=False) deallocate_module(exported_program.module(), delete_module=False) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -858,6 +862,12 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. + for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) + trt_module = None for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -978,6 +988,10 @@ def preserve_module_specs( ) as f: f.write(trt_module.get_layer_info()) + # Only set the requires_unique_output flag for the last TRT Module when user has access to the output tensor + if trt_module: + trt_module.set_unowned_output_tensor(True) + # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) @@ -1231,7 +1245,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..2542d652bd 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -50,7 +50,12 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.observer import Observer -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.utils import ( + DYNAMIC_DIM, + deallocate_module, + get_cpu_memory_usage, + to_torch_device, +) from torch_tensorrt.logging import TRT_LOGGER _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - serialized_engine: bytes + engine: trt.ICudaEngine input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - 
self.module.to(torch_device) - sd = self.module.state_dict() + sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} weight_refit_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1} @@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None: torch.cuda.empty_cache() @needs_refit # type: ignore[misc] - def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None: + serialized_engine = engine.serialize() # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - # serialization_config = engine.create_serialization_config() # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) # serialized_engine = engine.serialize_with_config( @@ -733,6 +735,9 @@ def run( return interpreter_result # type: ignore[no-any-return] self._construct_trt_network_def() + _LOGGER.debug( + f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" + ) if not self.compilation_settings.immutable_weights: self._save_weight_mapping() @@ -750,16 +755,19 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( + + cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) - assert serialized_engine + assert cuda_engine + + _LOGGER.debug( + f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" + ) _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - self.ctx.clear_cpu_weights_reference_holder() self._save_timing_cache( @@ -772,14 +780,10 @@ def run( and self.compilation_settings.cache_built_engines and self.engine_cache is not None ): - self._insert_engine_to_cache(hash_val, serialized_engine) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() + self._insert_engine_to_cache(hash_val, cuda_engine) return TRTInterpreterResult( - engine_str, + cuda_engine, self._input_names, self._output_names, self.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 35b6c26617..0f17227c20 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,7 +1,8 @@ from __future__ import annotations +import io import logging -from typing import Any, List, Optional, Sequence +from typing import Any, List, NamedTuple, Optional, Sequence import torch from torch_tensorrt._enums import dtype @@ -9,16 +10,25 @@ from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( - TRTInterpreter, - TRTInterpreterResult, -) +from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import 
get_output_dtypes +from torch_tensorrt.dynamo.utils import ( + get_cpu_memory_usage, + get_output_dtypes, + release_memory, +) logger = logging.getLogger(__name__) +class SerializedInterpreterResult(NamedTuple): + serialized_engine: bytes + input_names: Sequence[str] + output_names: Sequence[str] + weight_name_map: Optional[dict[Any, Any]] + requires_output_allocator: bool + + def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, @@ -29,7 +39,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] + return get_output_dtypes(outputs, truncate_double) # type: ignore def interpret_module_to_result( @@ -39,7 +49,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: +) -> SerializedInterpreterResult: """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -65,7 +75,32 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result + # Delete the frozen parameters from the module to release CPU memory + del interpreter + for attr in dir(module): + if attr.startswith("_frozen_param"): + delattr(module, attr) + release_memory() + logger.debug( + f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" + ) + + serialized_engine = interpreter_result.engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + logger.debug( + f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" + ) + serialized_interpreter_result = SerializedInterpreterResult( + serialized_engine=serialized_engine, + input_names=interpreter_result.input_names, + output_names=interpreter_result.output_names, + weight_name_map=interpreter_result.weight_name_map, + requires_output_allocator=interpreter_result.requires_output_allocator, + ) + + return serialized_interpreter_result def convert_module( diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4dcb525405..164f0c1065 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -888,6 +888,7 @@ def aten_ops_select( @dynamo_tensorrt_converter( torch.ops.aten.index_put.default, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { @@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d( @dynamo_tensorrt_converter( - torch.ops.aten.topk.default, capability_validator=topk_validator + torch.ops.aten.topk.default, + capability_validator=topk_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6f4a812dd8..ff743edf27 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -257,15 +257,17 @@ def index( ) else: dim_tensor_shape_mult_d1 = transpose_tensor_shape[i] - mult_d1 = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_shape_{i}", - trt.ElementWiseOperation.PROD, - mult_d1, - dim_tensor_shape_mult_d1, - ) + 
+ if isinstance(dim_tensor_shape_mult_d1, TRTTensor): + mult_d1 = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_shape_{i}", + trt.ElementWiseOperation.PROD, + mult_d1, + dim_tensor_shape_mult_d1, + ) concat_tensor_layer = ctx.net.add_concatenation( [ @@ -548,6 +550,9 @@ def index_put_converter( accumulate: bool = False, ) -> TRTTensor: # Convert 'input_indices' to TRT tensors (or keep None as is) + input_indices = expand_boolean_indices( + ctx, target, source_ir, name, input_tensor, input_indices + ) indices: List[Optional[Union[TRTTensor, None]]] = [] for i, idx in enumerate(input_indices): if idx is None: @@ -571,13 +576,31 @@ def index_put_converter( K = len(I) # Determine the maximum size 'N' among the index tensors if K > 0: - index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None] + index_shapes = ( + [] + ) # [tensor.shape[0] for tensor in indices if tensor is not None] + for idx_tensor in indices: + if idx_tensor is not None: + if idx_tensor.shape[0] != DYNAMIC_DIM: + index_shapes.append(idx_tensor.shape[0]) + else: + index_shapes.append( + get_shape( + ctx, + target, + source_ir, + name + "idx_shape_dim_0", + idx_tensor, + 0, + ) + ) N = max(index_shapes) if index_shapes else 1 else: N = 1 # Compute shapes and volume for the free dimensions F_shapes = [input_tensor.shape[i] for i in F] + assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported" F_volume = trt.volume(F_shapes) if F_shapes else 1 # Process indexed dimensions (I) @@ -585,8 +608,8 @@ def index_put_converter( for i in I: idx = indices[i] assert idx is not None - idx_reshaped = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1) + idx_reshaped = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1 ) expanded_idx = impl.slice.expand( ctx, @@ -608,46 +631,50 @@ def index_put_converter( ) arange_tensors.append(arange_tensor) - meshgrid_tensors = [] - for i, arange in enumerate(arange_tensors): - reshape_shape = [1] * len(F) - reshape_shape[i] = F_shapes[i] - arange_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_arange_F_{F[i]}", - arange, - tuple(reshape_shape), - ) - expanded_arange = impl.slice.expand( - ctx, - target, - source_ir, - f"{name}_expand_arange_F_{F[i]}", - arange_reshaped, - tuple(F_shapes), - ) - meshgrid_tensors.append(expanded_arange) - - meshgrid_stacked = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_stack_meshgrid", - [ - impl.shuffle.reshape( + if len(arange_tensors) == 1: + # No need to stack + meshgrid_stacked = arange_tensors[0] + else: + meshgrid_tensors = [] + for i, arange in enumerate(arange_tensors): + reshape_shape = [1] * len(F) + reshape_shape[i] = F_shapes[i] + arange_reshaped = impl.shuffle.reshape( ctx, target, source_ir, - f"{name}_reshape_mesh_{i}", - t, - (*F_shapes, 1), + f"{name}_reshape_arange_F_{F[i]}", + arange, + tuple(reshape_shape), ) - for i, t in enumerate(meshgrid_tensors) - ], - dim=-1, - ) + expanded_arange = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_arange_F_{F[i]}", + arange_reshaped, + tuple(F_shapes), + ) + meshgrid_tensors.append(expanded_arange) + + meshgrid_stacked = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_stack_meshgrid", + [ + impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_mesh_{i}", + t, + (*F_shapes, 1), + ) + for i, t in enumerate(meshgrid_tensors) + ], + dim=-1, + ) meshgrid_reshaped = 
impl.shuffle.reshape( ctx, target, @@ -672,21 +699,15 @@ def index_put_converter( # Combine all indexed dimensions (I) if K > 0: - I_combined = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_cat_I", - [ - impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) - ) - for i, t in enumerate(I_tensors) - ], - dim=2, - ) + + I_combined = [ + impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) + ) + for i, t in enumerate(I_tensors) + ] else: - I_combined = None + I_combined = [] # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded ii_list = [] @@ -695,24 +716,12 @@ def index_put_converter( for dim in range(rank): unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}" if dim in I: - start = [0, 0, i_idx] - shape = [N, F_volume, 1] - stride = [1, 1, 1] - idx_tensor = impl.slice.slice( - ctx, - target, - source_ir, - f"{name}_slice_I_dim_{unique_suffix}", - I_combined, - start, - shape, - stride, - ) + idx_tensor = I_combined[i_idx] ii_list.append(idx_tensor) i_idx += 1 else: start = [0, 0, f_idx] - shape = [N, F_volume, 1] + shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1] stride = [1, 1, 1] mesh_tensor = impl.slice.slice( ctx, @@ -731,20 +740,24 @@ def index_put_converter( indices_cat = impl.cat.cat( ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2 ) + + # Flatten the indices_cat to (N * F_volume, rank) indices_cat = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_indices_cat", indices_cat, - (N * F_volume, rank), + (-1, rank), ) if not isinstance(values, TRTTensor): values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0) # Define the expected shape based on (N,) + F_shapes - expected_shape = (N,) + tuple(F_shapes) + expected_shape = ( + (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes) + ) # Broadcast 'values' to match the expected shape if len(values.shape) == 0 or values.shape == (1,): # Scalar case @@ -761,7 +774,12 @@ def index_put_converter( ) else: # Non-scalar case values_shape = list(values.shape) - if K > 0 and N in values_shape: + if ( + K > 0 + and N in values_shape + and (len(F) > 1 and max(F) - min(F) + 1 == len(F)) + ): + # Continuous case n_idx = values_shape.index(N) permute_order = [n_idx] + [ i for i in range(len(values_shape)) if i != n_idx @@ -807,31 +825,27 @@ def index_put_converter( tuple(broadcast_shape), ) else: + # Discontinuous case values_shape_padded = [1] * ( len(expected_shape) - len(values.shape) ) + list(values.shape) broadcast_shape = [] for exp_dim, val_dim in zip(expected_shape, values_shape_padded): - if val_dim == 1 or exp_dim == val_dim: + if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM: + broadcast_shape.append(-1) + elif val_dim == 1 or exp_dim == val_dim: broadcast_shape.append(exp_dim) else: raise ValueError( f"Cannot broadcast {values.shape} to {expected_shape}" ) - values_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_values", - values, - tuple(broadcast_shape), - ) + values_expanded = impl.slice.expand( ctx, target, source_ir, f"{name}_expand_values", - values_reshaped, + values, expected_shape, ) @@ -842,16 +856,51 @@ def index_put_converter( source_ir, f"{name}_flatten_values", values_expanded, - (N * F_volume,), + (-1,), ) - indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32") - # Perform Scatter ND operation - scatter_layer = ctx.net.add_scatter( - input_tensor, - 
indices_cat, - flattened_values, - trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD, - ) - set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) - return scatter_layer.get_output(0) + if accumulate: + zero_tensor = impl.full.full( + ctx, + target, + source_ir, + f"{name}_zero_tensor", + [ + get_shape( + ctx, + target, + source_ir, + name + f"input_tensor_shape_dim_{i}", + input_tensor, + i, + ) + for i in range(len(input_tensor.shape)) + ], + 0.0, + dtype=input_tensor.dtype, + ) + # Perform Scatter ND operation + scatter_layer = ctx.net.add_scatter( + zero_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + + scatter_out = scatter_layer.get_output(0) + result = impl.elementwise.add( + ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor + ) + return result + + else: + scatter_layer = ctx.net.add_scatter( + input_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + scatter_out = scatter_layer.get_output(0) + return scatter_out diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index ec624ffc5a..3e0ae9ee59 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -197,6 +197,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: "class": "logging.FileHandler", "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", + "mode": "w", # This will clear the previous content } config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 5ba84b09b0..9b821df906 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -37,7 +37,9 @@ def constant_fold( # For TRT INetwork construction the constants are moved to CPU in get_attr call. 
for node, constant in cf.node_replacements.items(): replace_node_with_constant( - gm, node, torch.nn.Parameter(constant, requires_grad=False) + gm, + node, + torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False), ) erased_params = [] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py index 2a2c8e9d5e..a9b7c48ec2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -23,7 +23,8 @@ def remove_num_users_is_0_nodes( and len(node.all_input_nodes) > 0 ): gm.graph.erase_node(node) - gm = clean_up_graph_after_modifications(gm) + + gm = clean_up_graph_after_modifications(gm) logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}") diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..d54be6f9e8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -172,8 +172,9 @@ def __init__( self._input_buffers: List[torch.Tensor] = [] self._output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - self._caller_stream: Optional[torch.cuda.Stream] = None - self._engine_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream() + self.output_tensors: Optional[List[torch.Tensor]] = None + self.sync_stream = True # TODO: Make the below a Dictionary {shape: cudagraph} self.shape_key: Optional[str] = None @@ -218,9 +219,30 @@ def __init__( self.requires_output_allocator = requires_output_allocator self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False - + self.device = torch.cuda.current_device() + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + # If the output tensor is not owned by the engine (unowned_output_tensor=True), we need to create a new output tensor in each forward pass + self.unowned_output_tensor = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() + self.is_shape_inference_io = { + input_name: self.engine.is_shape_inference_io(input_name) + for input_name in self.input_names + } + + def set_unowned_output_tensor(self, enabled: bool) -> None: + """ + Set the flag to indicate if the output tensor is unowned by the engine. + If self.unowned_output_tensor=True, the engine will create a new output tensor in each forward pass. + This would be slower but is required when users need to manipulate the output tensor after each forward pass. + Therefore, this should be set to True only for the last module in a graph and leave to False for intermediate modules, + which users don't have access to. + Args: + enabled: bool + Whether to set the flag to True. 
+ + """ + self.unowned_output_tensor = enabled def get_streamable_device_memory_budget(self) -> Any: return self.engine.streamable_weights_size @@ -263,6 +285,9 @@ def setup_engine(self) -> None: assert ( self.target_platform == Platform.current_platform() ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" + # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream + # otherwise, use the caller stream and disable stream synchronization + self._engine_stream = torch.cuda.current_stream() self.initialized = True runtime = trt.Runtime(TRT_LOGGER) @@ -287,10 +312,14 @@ def setup_engine(self) -> None: for output_name in self.output_names ] self.output_shapes = [ - self.engine.get_tensor_shape(output_name) + tuple(self.context.get_tensor_shape(output_name)) for output_name in self.output_names ] + self.shape_key = "".join( + str(tuple(t)).replace(" ", "") for t in self.input_shapes + ) + if self.requires_output_allocator: self.create_output_allocator() @@ -356,6 +385,7 @@ def setup_input_tensors( contiguous_inputs: List[torch.Tensor], cudagraphs_enabled: bool, need_cudagraphs_record: bool, + shape_changed: bool = True, ) -> None: for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: @@ -382,16 +412,17 @@ def setup_input_tensors( # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers # as per TensorRT requirements - if self.engine.is_shape_inference_io(input_name): + if self.is_shape_inference_io[input_name]: # Shape tensor inputs are casted to int64 explicitly # Currently Torch CPU pointers are not working; numpy pointers are used instead # to refer to underlying memory inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data) else: - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) - ) + if shape_changed: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) if cudagraphs_enabled: self._input_buffers[i].copy_(contiguous_inputs[i]) self.context.set_tensor_address( @@ -410,7 +441,7 @@ def create_output_tensors(self) -> List[torch.Tensor]: output = torch.empty( size=self.output_shapes[o], dtype=self.output_dtypes[o], - device=torch.cuda.current_device(), + device=self.device, ) outputs.append(output) return outputs @@ -459,7 +490,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." self.setup_input_tensors( - contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record + contiguous_inputs, + self.cudagraphs_enabled, + need_cudagraphs_record, + shape_changed + or self.output_tensors is None, # First time execution ) if shape_changed: @@ -481,15 +516,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: if can_use_pre_allocated_outputs: outputs = self.pre_allocated_outputs else: - self.output_shapes = [ - tuple(self.context.get_tensor_shape(output_name)) - for output_name in self.output_names - ] + if shape_changed: + self.output_shapes = [ + tuple(self.context.get_tensor_shape(output_name)) + for output_name in self.output_names + ] if DYNAMIC_DIM in self.output_shapes: raise ValueError( "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
) - outputs = self.create_output_tensors() + if ( + self.output_tensors is None + or self.unowned_output_tensor + or shape_changed + ): + self.output_tensors = self.create_output_tensors() + outputs = self.output_tensors for o, output_name in enumerate(self.output_names): if need_cudagraphs_record: @@ -511,41 +553,46 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.profiling_enabled else nullcontext() ): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - self._engine_stream.wait_stream(self._caller_stream) + if self.cudagraphs_enabled: + if need_cudagraphs_record: + self.cudagraph = torch.cuda.CUDAGraph() - with torch.cuda.stream(self._engine_stream): - if self.cudagraphs_enabled: - if need_cudagraphs_record: - self.cudagraph = torch.cuda.CUDAGraph() + if self.profiling_enabled: + self.cudagraph.enable_debug_mode() - if self.profiling_enabled: - self.cudagraph.enable_debug_mode() + with torch.cuda.graph( + self.cudagraph, stream=self._engine_stream + ): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) - with torch.cuda.graph( - self.cudagraph, stream=self._engine_stream - ): - self.context.execute_async_v3( - self._engine_stream.cuda_stream + if self.profiling_enabled: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + self.cudagraph.debug_dump( + f"{tmpdir}/{self.name}_cudagraph.dot" ) if self.profiling_enabled: self.cudagraph.debug_dump( f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot" ) + self.cudagraph.replay() # type: ignore - self.cudagraph.replay() # type: ignore - - else: - self.context.execute_async_v3(self._engine_stream.cuda_stream) + else: + import warnings - self._caller_stream.wait_stream(self._engine_stream) + with warnings.catch_warnings(): + try: + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) + except Warning as e: + breakpoint() + print("warning ignored") if self.use_pre_allocated_outputs: self.pre_allocated_outputs = self.create_output_tensors() @@ -600,22 +647,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.profiling_enabled else nullcontext() ): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - - self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): self.context.execute_async_v3( self._engine_stream.cuda_stream ) # The OutputAllocator is called by execute_async_v3() - self._caller_stream.wait_stream(self._engine_stream) - with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessOutputs" @@ -644,8 +681,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: return outputs - self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() - # Run forward function contiguous_inputs: List[torch.Tensor] = [ (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) @@ -750,13 +785,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - tensor_inputs = [] - for t in inputs: - if not isinstance(t, torch.Tensor): - return True - tensor_inputs.append(t) + if not all(isinstance(t, torch.Tensor) for t 
in inputs): + return True + new_shape_key = "".join( - str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs + str(tuple(t.shape)).replace(" ", "") + for t in inputs + if isinstance(t, torch.Tensor) ) # If the new shape key differs from the existing one, diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..a2e30a8083 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -156,6 +156,11 @@ def _pack_engine_info(self) -> List[str | bytes]: metadata = { "settings": self.settings, "weight_name_map": self.weight_name_map, + "requires_new_output_tensor": ( + False + if self.engine is None + else self.engine.get_requires_new_output_tensor() + ), } target_platform = ( Platform.current_platform() @@ -284,6 +289,8 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] + self.unowned_output_tensor = metadata["unowned_output_tensor"] + self.engine.set_unowned_output_tensor(self.unowned_output_tensor) else: self.engine = None @@ -355,6 +362,12 @@ def enable_profiling( self.engine.enable_profiling() self.engine.set_profile_format(profile_format) + def set_unowned_output_tensor(self, enabled: bool) -> None: + self.engine.set_unowned_output_tensor(enabled) + + def is_unowned_output_tensor(self) -> bool: + return self.engine.is_unowned_output_tensor() + def disable_profiling(self) -> None: """Disable the profiler""" if self.engine is None: diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 564250e5ae..6cfa6394ec 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,13 +1,16 @@ from __future__ import annotations +import ctypes import gc import logging +import platform import warnings from dataclasses import fields, replace from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +import psutil import sympy import tensorrt as trt import torch @@ -858,3 +861,24 @@ def is_thor() -> bool: if torch.cuda.get_device_capability() in [(11, 0)]: return True return False + + +def get_cpu_memory_usage() -> Any: + return psutil.Process().memory_info().rss / 1024 / 1024 + + +def release_memory() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + + if platform.system() == "Linux": + try: + libc = ctypes.CDLL("libc.so.6") + if libc.malloc_trim(0) != 1: + logger.warning("Failed to release CPU memory.") + except Exception: + logger.warning("Failed to release CPU memory.") diff --git a/setup.py b/setup.py index 8aec510eb2..a1426bd6e1 100644 --- a/setup.py +++ b/setup.py @@ -192,10 +192,10 @@ def build_libtorchtrt_cxx11_abi( else: cmd.append("//:libtorchtrt") - if develop: - cmd.append("--compilation_mode=dbg") - else: - cmd.append("--compilation_mode=opt") + # if develop: + # cmd.append("--compilation_mode=dbg") + # else: + cmd.append("--compilation_mode=opt") if use_dist_dir: if IS_AARCH64: cmd.append("--distdir=third_party/dist_dir/aarch64-linux-gnu") diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py index 74e38cd0c5..0f4da97d89 100644 --- 
a/tests/py/dynamo/conversion/test_index_put_aten.py +++ b/tests/py/dynamo/conversion/test_index_put_aten.py @@ -1,4 +1,5 @@ import torch +import torch_tensorrt as torchtrt from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests @@ -194,11 +195,43 @@ class TestIndexPutConverter(DispatchTestCase): dtype=torch.int32, ), ), + # param( + # test_name="4d_indices_none_none_multiple_idx_broadcast_error", + # source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32), + # indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)), + # value_tensor=torch.randn([2, 3, 3], dtype=torch.float32), + # ), + param( + test_name="discontinuous_test", + source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 0, 1], dtype=torch.int64), + None, + torch.tensor([0, 0, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), + ), + param( + test_name="discontinuous_test_two", + source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32), + indices_tensor=( + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), + ), param( - test_name="4d_indices_none_none_multiple_idx_broadcast_error", - source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32), - indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)), - value_tensor=torch.randn([2, 3, 3], dtype=torch.float32), + test_name="continuous_test", + source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32), + indices_tensor=( + None, + None, + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + torch.tensor([0, 0, 1, 1], dtype=torch.int64), + ), + value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), ), # param( # test_name="2d_indices_accumulate_True", @@ -244,6 +277,94 @@ def forward(self, source_tensor, value_tensor): use_dynamo_tracer=True, ) + def test_index_add_dynamic_shape(self): + + class Model(torch.nn.Module): + def forward(self, x, y, z, a, b): + x.index_add_(0, y, z) + x.index_add_(0, a, b) + return x + + dim = 10 + model = Model().cuda() + inputs = [ + torch.ones((12, dim)).half().cuda(), + torch.tensor([0, 1]).cuda(), + torch.randn((2, dim)).half().cuda(), + torch.tensor([2, 9, 11]).cuda(), + torch.randn((3, dim)).half().cuda(), + ] + torch_output = model.cuda().forward(*inputs) + seq_len1 = torch.export.Dim("seq_len1", min=1, max=128) + seq_len2 = torch.export.Dim("seq_len2", min=1, max=128) + seq_len3 = torch.export.Dim("seq_len3", min=1, max=128) + + ep = torch.export.export( + model, + tuple(inputs), + dynamic_shapes=( + {0: seq_len1}, + {0: seq_len2}, + {0: seq_len2}, + {0: seq_len3}, + {0: seq_len3}, + ), + ) + with torchtrt.dynamo.Debugger( + log_level="debug", + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, + ): + trt_mod = torchtrt.dynamo.compile( + ep, + inputs, + enabled_precisions={torch.float16}, + min_block_size=1, + use_explicit_typing=False, + use_fp32_acc=False, + disable_tf32=True, + ) + result = trt_mod(*inputs) + assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4) + + def test_bool_mask_test(self): + + source_tensor = torch.ones([5, 10], dtype=torch.float32).cuda() + indices_tensor = torch.tensor([False, False, True, False, True]) + value_tensor = torch.zeros([2, 10], dtype=torch.float32).cuda() + + dim1 = torch.export.Dim("dim1", min=1, max=5) 
+ dim2 = torch.export.Dim("dim2", min=1, max=5) + + class TestIndexPut(torch.nn.Module): + def forward(self, source_tensor, indices_tensor, value_tensor): + source_tensor[indices_tensor] = value_tensor + return source_tensor + + model = TestIndexPut() + torch_output = model.forward(source_tensor, indices_tensor, value_tensor) + + ep = torch.export.export( + model, + (source_tensor, indices_tensor, value_tensor), + dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}), + ) + with torchtrt.dynamo.Debugger(log_level="debug"): + trt_engine = torchtrt.dynamo.compile( + ep, + inputs=(source_tensor, indices_tensor, value_tensor), + enabled_precisions={torch.float32}, + min_block_size=1, + use_explicit_typing=False, + use_fp32_acc=False, + disable_tf32=True, + use_python_runtime=True, + ) + result = trt_engine(source_tensor, indices_tensor, value_tensor) + + torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index c52b732c42..13ba856d35 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -54,6 +54,52 @@ def test_resnet18(ir): torch._dynamo.reset() +def compile_one(idx: int, ir: str): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((idx + 1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_resnet18_multiprocess(ir): + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + procs = [] + for i in range(3): + p = mp.Process(target=compile_one, args=(i, ir)) + p.start() + procs.append(p) + for p in procs: + p.join() + torch._dynamo.reset() + + @pytest.mark.unit @unittest.skipIf( not importlib.util.find_spec("torchvision"), diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index ab9470cc61..97b6616581 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -71,7 +71,7 @@ def get_model(args): else: model = model.to(torch.float32) - return model + return model.cuda() def compile_torchtrt(model, input_ids, args):
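Illustrative usage sketch (not part of the patch): the snippet below exercises the new runtime APIs this change introduces -- `set_unowned_output_tensor` on the TRT runtime modules and the `get_cpu_memory_usage` / `release_memory` helpers added to `torch_tensorrt.dynamo.utils`. The model, input shapes, and compile arguments are placeholders; only the API names come from this change set, and the exact compile options may differ in practice.

```python
import torch
import torch_tensorrt
from torch_tensorrt.dynamo.utils import get_cpu_memory_usage, release_memory

model = MyModel().eval().cuda()                # hypothetical user model
inputs = [torch.randn(1, 3, 224, 224).cuda()]  # placeholder input

print(f"CPU memory before compile: {get_cpu_memory_usage():.0f} MB")
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(f"CPU memory after compile: {get_cpu_memory_usage():.0f} MB")
release_memory()  # gc.collect() + CUDA cache flush + malloc_trim(0) on Linux

# The compiler already marks the last TRT submodule as "unowned", so its output
# tensors are freshly allocated on every call (the caller may keep references to
# them). Intermediate TRT submodules reuse their output buffers; if their outputs
# are also held by the caller, they can be opted in explicitly:
for _, submodule in trt_gm.named_children():
    if hasattr(submodule, "set_unowned_output_tensor"):
        submodule.set_unowned_output_tensor(True)

out = trt_gm(*inputs)
```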