Update torch.export load with new api #2906

Merged (10 commits) on Jan 29, 2024
@@ -4,7 +4,11 @@
pip uninstall torchtext torchdata torch torchvision torchaudio -y

# Install nightly PyTorch and torchvision from the specified index URL
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
if nvidia-smi > /dev/null 2>&1; then
    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
else
    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu --ignore-installed
fi
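The install script branches on whether `nvidia-smi` exits successfully, which is a proxy for a visible NVIDIA driver. The same detection pattern in isolation (a sketch that echoes the wheel index suffix instead of installing; not part of the PR):

```shell
# Pick the nightly wheel index based on GPU visibility
if nvidia-smi > /dev/null 2>&1; then
    echo "cu121"   # NVIDIA driver responds: use CUDA 12.1 wheels
else
    echo "cpu"     # no usable GPU: fall back to CPU-only wheels
fi
```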

# Optional: Display the installed PyTorch and torchvision versions
python -c "import torch; print('PyTorch version:', torch.__version__)"
10 changes: 9 additions & 1 deletion examples/pt2/torch_export_aot_compile/resnet18_torch_export.py
@@ -9,7 +9,15 @@
model.eval()

with torch.no_grad():
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# The below config is needed for max batch_size = 16
# https://github.com/pytorch/pytorch/pull/116152
Collaborator: Can we check the batch_size input to reflect the max batch_size limitation?

Collaborator Author: Thanks. Updated the comments. There is no limitation on batch_size.

torch.backends.mkldnn.set_flags(False)
torch.backends.nnpack.set_flags(False)

model = model.to(device=device)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

12 changes: 6 additions & 6 deletions test/pytest/test_torch_export.py
@@ -19,9 +19,9 @@
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


PT_220_AVAILABLE = (
PT_230_AVAILABLE = (
True
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
Collaborator: This would set PT_230_AVAILABLE to true if a 2.2.3 is released. Is that correct?

Contributor: @ankithagunapal Is this check needed due to API changes between PT 2.2.0 and the latest nightlies, which are 2.3.x?

Collaborator Author: Yes, this needs the latest nightlies. The new API was merged only in the last week or so. Based on the current release cadence, 2.3.0 follows 2.2.2 (two patch releases after every major release), but I will keep this in mind in case that changes.

else False
)
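The `packaging.version` comparison above governs which builds pass the gate: pre-release builds of 2.3.0 sort above 2.2.2, which is why nightlies qualify, and why (as the reviewer notes) a hypothetical 2.2.3 patch release would also pass. A minimal sketch of the check (version strings are illustrative):

```python
from packaging import version

def pt_230_available(torch_version: str) -> bool:
    # Anything newer than the last expected 2.2.x patch release passes the gate
    return version.parse(torch_version) > version.parse("2.2.2")

print(pt_230_available("2.3.0.dev20240126"))  # True: 2.3.0 pre-releases sort above 2.2.2
print(pt_230_available("2.2.2"))              # False
print(pt_230_available("2.2.3"))              # True: the edge case raised in review
```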

@@ -47,7 +47,7 @@ def custom_working_directory(tmp_path):
os.chdir(tmp_path)


@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
def test_torch_export_aot_compile(custom_working_directory):
# Get the path to the custom working directory
model_dir = custom_working_directory
@@ -88,7 +88,7 @@ def test_torch_export_aot_compile(custom_working_directory):
assert labels == EXPECTED_RESULTS


@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
# Get the path to the custom working directory
model_dir = custom_working_directory
@@ -123,6 +123,6 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
data["body"] = byte_array_type

# Send a batch of 16 elements
Collaborator: nit: outdated comment

Collaborator Author: Thanks. Updated.

result = handler.handle([data for i in range(15)], ctx)
result = handler.handle([data for i in range(32)], ctx)
Collaborator: How about parameterizing the test and testing both?

Collaborator Author: Thanks for catching this. I changed it.


assert len(result) == 15
assert len(result) == 32
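The reviewer's suggestion can be sketched with `pytest.mark.parametrize`, running the batching test once per size instead of hard-coding one count. Here `handle` is a hypothetical stand-in for the TorchServe handler, which returns one result per request in the batch:

```python
import pytest

# Hypothetical stand-in for handler.handle: one result per incoming request
def handle(batch):
    return [f"result-{i}" for i in range(len(batch))]

@pytest.mark.parametrize("batch_size", [15, 32])
def test_handle_batch_sizes(batch_size):
    # Build a batch of identical requests and check one result comes back per request
    result = handle([{"body": b"img-bytes"} for _ in range(batch_size)])
    assert len(result) == batch_size
```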
Empty file.
27 changes: 0 additions & 27 deletions ts/handler_utils/torch_export/load_model.py

This file was deleted.

18 changes: 12 additions & 6 deletions ts/torch_handler/base_handler.py
@@ -53,10 +53,10 @@
)
PT2_AVAILABLE = False

if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
PT220_AVAILABLE = True
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
Collaborator: ditto

PT230_AVAILABLE = True
else:
PT220_AVAILABLE = False
PT230_AVAILABLE = False

if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
try:
@@ -187,7 +187,7 @@ def initialize(self, context):
elif (
self.model_pt_path.endswith(".so")
and self._use_torch_export_aot_compile()
and PT220_AVAILABLE
and PT230_AVAILABLE
):
# Set cuda device to the gpu_id of the backend worker
# This is needed as the API for loading the exported model doesn't yet have a device id
@@ -256,9 +256,15 @@ def initialize(self, context):
self.initialized = True

def _load_torch_export_aot_compile(self, model_so_path):
from ts.handler_utils.torch_export.load_model import load_exported_model
"""Loads the compiled PyTorch model (.so) and returns a Callable object.

return load_exported_model(model_so_path, self.map_location)
Args:
model_so_path (str): denotes the path of the compiled model .so file.

Returns:
(Callable): The loaded model object.
"""
return torch._export.aot_load(model_so_path, self.map_location)

def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.