diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index f9df21856c..623bad7b9f 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -606,7 +606,7 @@ def save( inputs: Optional[Sequence[torch.Tensor]] = None, arg_inputs: Optional[Sequence[torch.Tensor]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, - retrace: bool = False, + retrace: bool = True, pickle_protocol: int = 2, **kwargs: Any, ) -> None: @@ -661,7 +661,7 @@ def save( "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram." ) elif module_type == _ModuleType.ts: - if not all([output_format == f for f in ["exported_program", "aot_inductor"]]): + if not all(output_format == f for f in ["exported_program", "aot_inductor"]): raise ValueError( "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported" ) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 9aae901f87..2dcc75bcb7 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -27,6 +27,7 @@ ) from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs from torch_tensorrt.dynamo.lowering import ( + clean_up_graph_after_modifications, get_decompositions, post_lowering, pre_export_lowering, @@ -94,6 +95,8 @@ def construct_refit_mapping_from_weight_name_map( engine_weight_map = {} for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): # Add more constant folding converters here + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]: # Batch Norm Layer params = {} @@ -106,12 +109,12 @@ def construct_refit_mapping_from_weight_name_map( engine_weight_map[engine_weight_name] = eval( engine_weight_name.split(" ")[-1].lower() ) + elif sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue else: - trt_dtype = dtype._from(np_weight_type).to(trt.DataType) - torch_dtype = dtype._from(np_weight_type).to(torch.dtype) + engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( to_torch_device(settings.device) ) @@ -272,12 +275,66 @@ def refit_module_weights( compiled_submodules_map[name] = submodule else: + # Handle torch modules + compiled_submodules_map = {} + guard_fn_modules = [] for name, submodule in compiled_module.named_children(): - if not isinstance( - submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + if ( + not isinstance( + submodule, + ( + PythonTorchTensorRTModule, + TorchTensorRTModule, + torch.nn.modules.module.Module, + ), + ) + or "_run_on_gpu" in name ): continue - settings = submodule.settings + + # When we re-export the graph module, torch.export._unlift.GuardsFn modules are being added as submodules. + if isinstance(submodule, torch.export._unlift.GuardsFn): + guard_fn_modules.append(name) + continue + # Obtain the settings + + compiled_submodules = [ + (name.replace("_engine", ""), engine) + for name, engine in submodule.__dict__.items() + if "engine" in name + ] + + settings = None + try: + # If the gm is not inlined or transformed by retracing, the settings is stored in the submodule + settings = submodule.settings + except AttributeError: + + encoded_metadata = [ + engine for name, engine in compiled_submodules if name == "engine" + ][0].__getstate__()[0][SERIALIZED_METADATA_IDX] + assert ( + encoded_metadata != "" + ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version" + settings = TorchTensorRTModule.decode_metadata(encoded_metadata)[ + "settings" + ] + + compiled_submodules_map[name] = submodule + + # Delete the guard fn modules to avoid the guard fn modules being refitted + # First, remove nodes in the graph that reference the guard function modules + for node in list(compiled_module.graph.nodes): + if node.op == "call_module" and node.target in guard_fn_modules: + compiled_module.graph.erase_node(node) + + # Now delete the submodules themselves + for guard_fn_module_name in guard_fn_modules: + # delattr(compiled_module, guard_fn_module_name) + compiled_module.delete_submodule(guard_fn_module_name) + + # Clean up the graph + clean_up_graph_after_modifications(compiled_module) assert settings is not None @@ -411,11 +468,29 @@ def refit_module_weights( ) else: compiled_submodule = getattr(compiled_module, name) + if "_run_on_acc" not in name: + compiled_submodule.load_state_dict(new_submodule.state_dict()) + continue + weight_name_map = None if use_weight_map_cache: try: weight_name_map = compiled_submodule.weight_name_map except AttributeError: + if isinstance(compiled_submodule, torch.nn.Module): + # Torch retrace module + assert ( + not settings.use_python_runtime + ), "Refitting a torch retraced module is only supported with use_python_runtime=False" + encoded_metadata = [ + engine + for name, engine in compiled_submodules + if name == "engine" + ][0].__getstate__()[0][SERIALIZED_METADATA_IDX] + weight_name_map = TorchTensorRTModule.decode_metadata( + encoded_metadata + )["weight_name_map"] + if not isinstance( compiled_submodule, torch.fx.graph_module.GraphModule ): @@ -427,21 +502,16 @@ def refit_module_weights( logger.warning( "This engine does not have a weight map cache. Rebuilding the weight map" ) - if isinstance(compiled_submodule, PythonTorchTensorRTModule): + + # Rexporting the TRT compiled graph module and loading it back doesn't preserve the instance type and registers + # the compiled submodule as torch.nn.Module. So we use settings.use_python_runtime to determine the instance type. + if settings.use_python_runtime: engine = compiled_submodule.engine - elif isinstance(compiled_submodule, TorchTensorRTModule): + else: engine_info = compiled_submodule.engine.__getstate__()[0] engine = get_engine_from_encoded_engine( engine_info[ENGINE_IDX], runtime ) - elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): - # This is graph break resulted by unsupported ops - compiled_submodule.load_state_dict(new_submodule.state_dict()) - continue - else: - raise AssertionError( - "The type of graph module is not supported for refitting." - ) except AttributeError: raise AssertionError( "The type of graph module is not supported for refitting or two compiled modules do not match." @@ -500,7 +570,12 @@ def refit_module_weights( new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) - + elif isinstance(compiled_submodule, torch.nn.Module): + # Torch retrace module + new_engine_info = list(engine_info) + new_engine_info[ENGINE_IDX] = bytes(serialized_engine) + refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) + compiled_submodule.engine = refitted_engine del engine gc.collect() torch.cuda.empty_cache() diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py index c0e2803e60..c980224869 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -1,3 +1,4 @@ from ._aten_lowering_pass import * +from .pass_utils import clean_up_graph_after_modifications from .remove_sym_nodes import remove_sym_nodes from .repair_input_aliasing import repair_input_aliasing diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py index 70a0fde12f..dabbad3cc8 100644 --- a/tests/py/dynamo/models/test_export_kwargs_serde.py +++ b/tests/py/dynamo/models/test_export_kwargs_serde.py @@ -76,7 +76,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, retrace=False) # Clean up model env torch._dynamo.reset() @@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, retrace=False) # Clean up model env torch._dynamo.reset() @@ -209,7 +209,7 @@ def forward(self, x, b=5, c=None, d=None): # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, retrace=False) # Clean up model env torch._dynamo.reset() @@ -299,7 +299,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]): ) # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, retrace=False) # Clean up model env torch._dynamo.reset() @@ -389,7 +389,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]): ) # Save the module trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, retrace=False) # Clean up model env torch._dynamo.reset() diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index d9c2ca3b0b..c5b007e34b 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -56,7 +56,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -111,7 +111,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -170,7 +170,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs @@ -232,7 +232,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -279,7 +279,7 @@ def test_resnet18(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir): msg="Model should be offloaded to CPU", ) model.cuda() - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) # TODO: Enable this serialization issues are fixed # deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = deser_trt_module(input) outputs_trt = trt_module(input) @@ -463,7 +463,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -525,7 +525,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) model.cuda() - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) @@ -584,7 +584,7 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path) + torchtrt.save(trt_module, trt_ep_path, retrace=False) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 222221e089..e6b7f6e2a4 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -540,8 +540,8 @@ def test_refit_one_engine_inline_runtime_with_weightmap(): min_block_size = 1 use_python_runtime = False - exp_program = torch.export.export(model, tuple(inputs)) - exp_program2 = torch.export.export(model2, tuple(inputs)) + exp_program = torch.export.export(model, tuple(inputs), strict=False) + exp_program2 = torch.export.export(model2, tuple(inputs), strict=False) trt_gm = torchtrt.dynamo.compile( exp_program, @@ -551,8 +551,9 @@ def test_refit_one_engine_inline_runtime_with_weightmap(): min_block_size=min_block_size, immutable_weights=False, ) - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs, retrace=True) trt_gm = torch.export.load(trt_ep_path) + new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, @@ -565,6 +566,7 @@ def test_refit_one_engine_inline_runtime_with_weightmap(): expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): assertions.assertTrue( torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), @@ -906,7 +908,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): min_block_size=min_block_size, immutable_weights=False, ) - torchtrt.save(trt_gm, trt_ep_path) + torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs) trt_gm = torch.export.load(trt_ep_path) new_trt_gm = refit_module_weights( compiled_module=trt_gm, diff --git a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py index ca82797090..539c11a303 100644 --- a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py +++ b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py @@ -314,7 +314,9 @@ def test_lazy_engine_init_cpp_serialization(self): trt_mod = torchtrt.compile(model, **compile_spec) with tempfile.TemporaryDirectory() as tmpdir: - torch_tensorrt.save(trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep")) + torch_tensorrt.save( + trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), arg_inputs=(input,) + ) new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep")) loaded_trt_mod = new_trt_mod.module()