Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,21 +561,28 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
return dynamo_load_cross_compiled_exported_program(file_path)


def load(file_path: str = "") -> Any:
def load(file_path: str = "", extra_files: Optional[dict[str, Any]] = None) -> Any:
"""
Load either a Torchscript model or ExportedProgram.

Loads a TorchScript or ExportedProgram file from disk. File type will be detect the type using try, except.

Arguments:
file_path (str): Path to file on the disk
extra_files (dict[str, Any]): Extra files to load with the model

Example:
# Load with extra files.
extra_files = {"foo.txt": ""} # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])

Raises:
ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
"""
try:
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
ts_module = torch.jit.load(file_path)
ts_module = torch.jit.load(file_path, extra_files=extra_files)
return ts_module
except Exception:
logger.info(
Expand All @@ -586,7 +593,7 @@ def load(file_path: str = "") -> Any:

try:
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
exp_program = torch.export.load(file_path)
exp_program = torch.export.load(file_path, extra_files=extra_files)
return exp_program
except Exception:
logger.info(
Expand All @@ -602,6 +609,7 @@ def save(
module: Any,
file_path: str = "",
*,
extra_files: Optional[dict[str, Any]] = None,
output_format: str = "exported_program",
inputs: Optional[Sequence[torch.Tensor]] = None,
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
Expand All @@ -615,6 +623,8 @@ def save(

Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
file_path (str): Path to file on the disk
extra_files (Optional[Dict[str, Any]]): Map from filename to contents which will be stored as part of saved file.
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
Expand Down Expand Up @@ -670,7 +680,7 @@ def save(
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
torch.jit.save(module, file_path)
torch.jit.save(module, file_path, extra_files=extra_files)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
Expand All @@ -682,7 +692,12 @@ def save(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
if output_format == "exported_program":
torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
torch.export.save(
module,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
Expand All @@ -703,7 +718,7 @@ def save(
module_ts = torch.jit.trace(
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
)
torch.jit.save(module_ts, file_path)
torch.jit.save(module_ts, file_path, extra_files=extra_files)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export
Expand All @@ -715,7 +730,10 @@ def save(
exp_program = export(module)
if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand Down Expand Up @@ -745,7 +763,10 @@ def save(

if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand Down
Loading