diff --git a/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
index f7ee5c39db..0b2ab20ddf 100755
--- a/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
+++ b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -4,7 +4,11 @@
 pip uninstall torchtext torchdata torch torchvision torchaudio -y
 
 # Install nightly PyTorch and torchvision from the specified index URL
-pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
+if nvidia-smi > /dev/null 2>&1; then
+    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
+else
+    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu --ignore-installed
+fi
 
 # Optional: Display the installed PyTorch and torchvision versions
 python -c "import torch; print('PyTorch version:', torch.__version__)"
diff --git a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
index db31228bac..bf35477a94 100644
--- a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
+++ b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -5,17 +5,26 @@
 
 torch.set_float32_matmul_precision("high")
 
+MAX_BATCH_SIZE = 32
+
 model = resnet18(weights=ResNet18_Weights.DEFAULT)
 model.eval()
 
 with torch.no_grad():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+        # We need to turn off the below optimizations to support batch_size = 16,
+        # which is treated as a special case
+        # https://github.com/pytorch/pytorch/pull/116152
+        torch.backends.mkldnn.set_flags(False)
+        torch.backends.nnpack.set_flags(False)
+
     model = model.to(device=device)
     example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
 
-    # Max value is 15 because of https://github.com/pytorch/pytorch/pull/116152
-    # On a CUDA enabled device, we tested batch_size of 32.
-    batch_dim = torch.export.Dim("batch", min=2, max=15)
+    batch_dim = torch.export.Dim("batch", min=2, max=MAX_BATCH_SIZE)
     so_path = torch._export.aot_compile(
         model,
         example_inputs,
diff --git a/test/pytest/test_torch_export.py b/test/pytest/test_torch_export.py
index 7031c983eb..dbd16fd3df 100644
--- a/test/pytest/test_torch_export.py
+++ b/test/pytest/test_torch_export.py
@@ -19,9 +19,9 @@
 
 MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
 
-PT_220_AVAILABLE = (
+PT_230_AVAILABLE = (
     True
-    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
+    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
     else False
 )
 
@@ -30,6 +30,8 @@
     ("kitten.jpg", EXPECTED_RESULTS[0]),
 ]
 
+BATCH_SIZE = 32
+
 import os
 
 
@@ -47,7 +49,7 @@ def custom_working_directory(tmp_path):
     os.chdir(tmp_path)
 
 
-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
@@ -88,7 +90,7 @@ def test_torch_export_aot_compile(custom_working_directory):
     assert labels == EXPECTED_RESULTS
 
 
-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
@@ -122,7 +124,7 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     byte_array_type = bytearray(image_file)
     data["body"] = byte_array_type
 
-    # Send a batch of 16 elements
-    result = handler.handle([data for i in range(15)], ctx)
+    # Send a batch of BATCH_SIZE elements
+    result = handler.handle([data for i in range(BATCH_SIZE)], ctx)
 
-    assert len(result) == 15
+    assert len(result) == BATCH_SIZE
diff --git a/ts/handler_utils/torch_export/__init__.py b/ts/handler_utils/torch_export/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/ts/handler_utils/torch_export/load_model.py b/ts/handler_utils/torch_export/load_model.py
deleted file mode 100644
index fe3ca867c5..0000000000
--- a/ts/handler_utils/torch_export/load_model.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import tempfile
-
-import torch.fx._pytree as fx_pytree
-from torch._inductor.utils import aot_inductor_launcher, cache_dir
-from torch.utils import _pytree as pytree
-from torch.utils.cpp_extension import load_inline
-
-
-def load_exported_model(model_so_path, device):
-    module = load_inline(
-        name="aot_inductor",
-        cpp_sources=[aot_inductor_launcher(model_so_path, device)],
-        # use a unique build directory to avoid test interference
-        build_directory=tempfile.mkdtemp(dir=cache_dir()),
-        functions=["run", "get_call_spec"],
-        with_cuda=("cuda" == device),
-    )
-    call_spec = module.get_call_spec()
-    in_spec = pytree.treespec_loads(call_spec[0])
-    out_spec = pytree.treespec_loads(call_spec[1])
-
-    def optimized(*args):
-        flat_inputs = fx_pytree.tree_flatten_spec((args, {}), in_spec)
-        flat_outputs = module.run(flat_inputs)
-        return pytree.tree_unflatten(flat_outputs, out_spec)
-
-    return optimized
diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 5862195b07..68c330c854 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -53,10 +53,10 @@
     )
     PT2_AVAILABLE = False
 
-if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
-    PT220_AVAILABLE = True
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
+    PT230_AVAILABLE = True
 else:
-    PT220_AVAILABLE = False
+    PT230_AVAILABLE = False
 
 if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
     try:
@@ -187,7 +187,7 @@ def initialize(self, context):
         elif (
             self.model_pt_path.endswith(".so")
             and self._use_torch_export_aot_compile()
-            and PT220_AVAILABLE
+            and PT230_AVAILABLE
         ):
             # Set cuda device to the gpu_id of the backend worker
             # This is needed as the API for loading the exported model doesn't yet have a device id
@@ -256,9 +256,15 @@ def initialize(self, context):
         self.initialized = True
 
     def _load_torch_export_aot_compile(self, model_so_path):
-        from ts.handler_utils.torch_export.load_model import load_exported_model
+        """Loads the exported PyTorch model (.so) and returns a callable object.
 
-        return load_exported_model(model_so_path, self.map_location)
+        Args:
+            model_so_path (str): denotes the path of the exported model (.so) file.
+
+        Returns:
+            (Callable Object): A callable that runs the loaded model.
+        """
+        return torch._export.aot_load(model_so_path, self.map_location)
 
     def _load_torchscript_model(self, model_pt_path):
         """Loads the PyTorch model and returns the NN model object.
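
Note (not part of the patch itself): the handler change above replaces the hand-rolled `load_exported_model` helper (deleted from `ts/handler_utils/torch_export/load_model.py`) with the built-in `torch._export.aot_load` API. Below is a minimal sketch of the export/load flow this patch relies on, assuming a PyTorch build newer than 2.2.2 (per the `PT230_AVAILABLE` check); the `dynamic_shapes` and `options` keyword arguments are inferred from the `torch._export.aot_compile` API rather than copied from the truncated hunk above.

```python
import os

import torch
from torchvision.models import ResNet18_Weights, resnet18

MAX_BATCH_SIZE = 32

model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device=device)

with torch.no_grad():
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
    # Mark the batch dimension as dynamic so the compiled artifact accepts
    # batch sizes from 2 up to MAX_BATCH_SIZE at inference time.
    batch_dim = torch.export.Dim("batch", min=2, max=MAX_BATCH_SIZE)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Assumed kwargs: the first dim of input "x" is dynamic, and the
        # shared object is written to ./resnet18_pt2.so.
        dynamic_shapes={"x": {0: batch_dim}},
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so")},
    )

# Loading the shared object is now a single call; no inline C++ launcher is needed.
optimized = torch._export.aot_load(so_path, device)
with torch.no_grad():
    out = optimized(torch.randn(8, 3, 224, 224, device=device))
print(out.shape)  # expected: torch.Size([8, 1000])
```

This is also why `base_handler.py` can simply return `torch._export.aot_load(model_so_path, self.map_location)`: the loaded object is directly callable, so the handler no longer needs the pytree flatten/unflatten wrapper that the deleted `load_model.py` provided.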