From 224d25d176fdde7919d13d1661ce244a304720e7 Mon Sep 17 00:00:00 2001
From: agunapal
Date: Mon, 22 Jan 2024 22:59:27 +0000
Subject: [PATCH 1/6] changed to new api

---
 ts/torch_handler/base_handler.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 02b56ca98c..175a91a4e2 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -256,9 +256,7 @@ def initialize(self, context):
         self.initialized = True

     def _load_torch_export_aot_compile(self, model_so_path):
-        from ts.handler_utils.torch_export.load_model import load_exported_model
-
-        return load_exported_model(model_so_path, self.map_location)
+        return torch._export.aot_load(model_so_path, self.map_location)

     def _load_torchscript_model(self, model_pt_path):
         """Loads the PyTorch model and returns the NN model object.
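[Editor's note] Patch 1 swaps TorchServe's hand-rolled AOTInductor launcher for the upstream loader. A minimal sketch of the new API's contract, assuming a previously compiled resnet18_pt2.so in the working directory (the path, device, and shapes are illustrative, not taken from this series):

    import torch

    # torch._export.aot_load returns a plain Callable that wraps the
    # AOTInductor-compiled shared library
    model = torch._export.aot_load("resnet18_pt2.so", "cpu")  # hypothetical path

    # The Callable is invoked like the original eager module
    output = model(torch.randn(2, 3, 224, 224))
    print(output.shape)  # torch.Size([2, 1000]) for ResNet-18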
From 42b684fe6c86f76db9fc596f40a9f0f37e0794b6 Mon Sep 17 00:00:00 2001
From: agunapal
Date: Wed, 24 Jan 2024 19:44:16 +0000
Subject: [PATCH 2/6] Updated to use new api torch._export.aot_load

---
 test/pytest/test_torch_export.py |  8 ++++----
 ts/torch_handler/base_handler.py | 16 ++++++++++++----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/test/pytest/test_torch_export.py b/test/pytest/test_torch_export.py
index 7031c983eb..29ff517800 100644
--- a/test/pytest/test_torch_export.py
+++ b/test/pytest/test_torch_export.py
@@ -19,9 +19,9 @@

 MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")

-PT_220_AVAILABLE = (
+PT_230_AVAILABLE = (
     True
-    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
+    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
     else False
 )
@@ -47,7 +47,7 @@ def custom_working_directory(tmp_path):
     os.chdir(tmp_path)


-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
@@ -88,7 +88,7 @@ def test_torch_export_aot_compile(custom_working_directory):
     assert labels == EXPECTED_RESULTS


-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 175a91a4e2..9cda1cb746 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -53,10 +53,10 @@
     )
     PT2_AVAILABLE = False

-if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
-    PT220_AVAILABLE = True
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
+    PT230_AVAILABLE = True
 else:
-    PT220_AVAILABLE = False
+    PT230_AVAILABLE = False

 if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
     try:
@@ -187,7 +187,7 @@ def initialize(self, context):
         elif (
             self.model_pt_path.endswith(".so")
             and self._use_torch_export_aot_compile()
-            and PT220_AVAILABLE
+            and PT230_AVAILABLE
         ):
             # Set cuda device to the gpu_id of the backend worker
             # This is needed as the API for loading the exported model doesn't yet have a device id
@@ -256,6 +256,14 @@ def initialize(self, context):
         self.initialized = True

     def _load_torch_export_aot_compile(self, model_so_path):
+        """Loads the exported model .so file and returns a Callable object.
+
+        Args:
+            model_so_path (str): denotes the path of the model .so file.
+
+        Returns:
+            (Callable Object) : The loaded model.
+        """
         return torch._export.aot_load(model_so_path, self.map_location)

     def _load_torchscript_model(self, model_pt_path):

From 13faf6e2ced20700c677b4837b5eff20b7be663a Mon Sep 17 00:00:00 2001
From: agunapal
Date: Wed, 24 Jan 2024 19:45:09 +0000
Subject: [PATCH 3/6] Updated to use new api torch._export.aot_load

---
 ts/handler_utils/torch_export/__init__.py   |  0
 ts/handler_utils/torch_export/load_model.py | 27 ---------------------
 2 files changed, 27 deletions(-)
 delete mode 100644 ts/handler_utils/torch_export/__init__.py
 delete mode 100644 ts/handler_utils/torch_export/load_model.py

diff --git a/ts/handler_utils/torch_export/__init__.py b/ts/handler_utils/torch_export/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/ts/handler_utils/torch_export/load_model.py b/ts/handler_utils/torch_export/load_model.py
deleted file mode 100644
index fe3ca867c5..0000000000
--- a/ts/handler_utils/torch_export/load_model.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import tempfile
-
-import torch.fx._pytree as fx_pytree
-from torch._inductor.utils import aot_inductor_launcher, cache_dir
-from torch.utils import _pytree as pytree
-from torch.utils.cpp_extension import load_inline
-
-
-def load_exported_model(model_so_path, device):
-    module = load_inline(
-        name="aot_inductor",
-        cpp_sources=[aot_inductor_launcher(model_so_path, device)],
-        # use a unique build directory to avoid test interference
-        build_directory=tempfile.mkdtemp(dir=cache_dir()),
-        functions=["run", "get_call_spec"],
-        with_cuda=("cuda" == device),
-    )
-    call_spec = module.get_call_spec()
-    in_spec = pytree.treespec_loads(call_spec[0])
-    out_spec = pytree.treespec_loads(call_spec[1])
-
-    def optimized(*args):
-        flat_inputs = fx_pytree.tree_flatten_spec((args, {}), in_spec)
-        flat_outputs = module.run(flat_inputs)
-        return pytree.tree_unflatten(flat_outputs, out_spec)
-
-    return optimized

From 884964e21dc8ac0cc0a82a53508189200347cba2 Mon Sep 17 00:00:00 2001
From: agunapal
Date: Thu, 25 Jan 2024 01:41:31 +0000
Subject: [PATCH 4/6] update the install script

---
 .../torch_export_aot_compile/install_pytorch_nightlies.sh | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
index f7ee5c39db..0b2ab20ddf 100755
--- a/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
+++ b/examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh
@@ -4,7 +4,11 @@
 pip uninstall torchtext torchdata torch torchvision torchaudio -y

 # Install nightly PyTorch and torchvision from the specified index URL
-pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
+if nvidia-smi > /dev/null 2>&1; then
+    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
+else
+    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu --ignore-installed
+fi

 # Optional: Display the installed PyTorch and torchvision versions
 python -c "import torch; print('PyTorch version:', torch.__version__)"
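[Editor's note] Patches 2-4 gate the .so loading path on torch >= 2.3 and make the nightly install work on CPU-only hosts. The version gate reduces to the pattern below (a sketch; comparing against "2.2.2" rather than "2.3.0" is what lets 2.3 nightlies pass, since packaging treats a dev build such as "2.3.0.dev20240124" as a pre-release that sorts below "2.3.0" but above any 2.2.x release):

    import torch
    from packaging import version

    # True for 2.3.0 nightlies and final releases, False for 2.2.x and older
    PT230_AVAILABLE = version.parse(torch.__version__) > version.parse("2.2.2")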
From 1b285674a803f9a383fb4b9122ad6de72a816aad Mon Sep 17 00:00:00 2001
From: agunapal
Date: Thu, 25 Jan 2024 01:55:02 +0000
Subject: [PATCH 5/6] tested with CPU & batch size = 32

---
 .../torch_export_aot_compile/resnet18_torch_export.py | 10 +++++++++-
 test/pytest/test_torch_export.py                      |  4 ++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
index db31228bac..0d74ea6bce 100644
--- a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
+++ b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -9,7 +9,15 @@
 model.eval()

 with torch.no_grad():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+        # The below config is needed for max batch_size = 16
+        # https://github.com/pytorch/pytorch/pull/116152
+        torch.backends.mkldnn.set_flags(False)
+        torch.backends.nnpack.set_flags(False)
+
     model = model.to(device=device)

     example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
diff --git a/test/pytest/test_torch_export.py b/test/pytest/test_torch_export.py
index 29ff517800..d0b40bc0ed 100644
--- a/test/pytest/test_torch_export.py
+++ b/test/pytest/test_torch_export.py
@@ -123,6 +123,6 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     data["body"] = byte_array_type

     # Send a batch of 16 elements
-    result = handler.handle([data for i in range(15)], ctx)
+    result = handler.handle([data for i in range(32)], ctx)

-    assert len(result) == 15
+    assert len(result) == 32
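[Editor's note] Patch 5 exercises the dynamic batch dimension well past the old CPU limit of 15, which only works because the export script now disables the MKLDNN and NNPACK fast paths on CPU (see https://github.com/pytorch/pytorch/pull/116152). A sketch of the property the updated test relies on, assuming the exported resnet18_pt2.so from this example (path and batch sizes are illustrative):

    import torch

    model = torch._export.aot_load("resnet18_pt2.so", "cpu")  # hypothetical path

    # Any batch size inside the exported dynamic range should be accepted
    for batch_size in (2, 16, 32):
        output = model(torch.randn(batch_size, 3, 224, 224))
        assert output.shape[0] == batch_size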
From cf6b75c56f1d1ac3bc4261f4aacd4615b48a95f2 Mon Sep 17 00:00:00 2001
From: agunapal
Date: Mon, 29 Jan 2024 19:33:03 +0000
Subject: [PATCH 6/6] updated based on review comments

---
 .../torch_export_aot_compile/resnet18_torch_export.py | 9 +++++----
 test/pytest/test_torch_export.py                      | 8 +++++---
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
index 0d74ea6bce..bf35477a94 100644
--- a/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
+++ b/examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -5,6 +5,8 @@

 torch.set_float32_matmul_precision("high")

+MAX_BATCH_SIZE = 32
+
 model = resnet18(weights=ResNet18_Weights.DEFAULT)
 model.eval()

@@ -13,7 +15,8 @@
         device = "cuda"
     else:
         device = "cpu"
-        # The below config is needed for max batch_size = 16
+        # We need to turn off the below optimizations to support batch_size = 16,
+        # which is treated like a special case
         # https://github.com/pytorch/pytorch/pull/116152
         torch.backends.mkldnn.set_flags(False)
         torch.backends.nnpack.set_flags(False)
@@ -21,9 +24,7 @@
     model = model.to(device=device)

     example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
-    # Max value is 15 because of https://github.com/pytorch/pytorch/pull/116152
-    # On a CUDA enabled device, we tested batch_size of 32.
-    batch_dim = torch.export.Dim("batch", min=2, max=15)
+    batch_dim = torch.export.Dim("batch", min=2, max=MAX_BATCH_SIZE)
     so_path = torch._export.aot_compile(
         model,
         example_inputs,
diff --git a/test/pytest/test_torch_export.py b/test/pytest/test_torch_export.py
index d0b40bc0ed..dbd16fd3df 100644
--- a/test/pytest/test_torch_export.py
+++ b/test/pytest/test_torch_export.py
@@ -30,6 +30,8 @@
     ("kitten.jpg", EXPECTED_RESULTS[0]),
 ]

+BATCH_SIZE = 32
+
 import os


@@ -122,7 +124,7 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     byte_array_type = bytearray(image_file)
     data["body"] = byte_array_type

-    # Send a batch of 16 elements
-    result = handler.handle([data for i in range(32)], ctx)
+    # Send a batch of BATCH_SIZE elements
+    result = handler.handle([data for i in range(BATCH_SIZE)], ctx)

-    assert len(result) == 32
+    assert len(result) == BATCH_SIZE
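[Editor's note] Taken together, the export side of the series reduces to the sketch below. It mirrors resnet18_torch_export.py after patch 6, except that the dynamic_shapes and options arguments are assumptions about the portion of the file the diff truncates ("x" is assumed to be the name of ResNet's forward argument):

    import os

    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    torch.set_float32_matmul_precision("high")

    MAX_BATCH_SIZE = 32

    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    model.eval()

    with torch.no_grad():
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
            # On CPU, these fast paths must be off to support batch sizes >= 16
            # https://github.com/pytorch/pytorch/pull/116152
            torch.backends.mkldnn.set_flags(False)
            torch.backends.nnpack.set_flags(False)

        model = model.to(device=device)
        example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

        # Mark dim 0 of the input as dynamic between 2 and MAX_BATCH_SIZE
        batch_dim = torch.export.Dim("batch", min=2, max=MAX_BATCH_SIZE)
        so_path = torch._export.aot_compile(
            model,
            example_inputs,
            dynamic_shapes={"x": {0: batch_dim}},  # assumed argument name
            options={"aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so")},
        )
        print(f"Compiled model saved to {so_path}")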