Update torch.export load with new API #2906
Changed file (torch.export AOT compile tests):

@@ -19,9 +19,9 @@
 MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")

-PT_220_AVAILABLE = (
+PT_230_AVAILABLE = (
     True
-    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
+    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
     else False
 )

Review comment: This would set PT_230_AVAILABLE to True if a 2.2.3 were released. Is that correct?
Review comment: @ankithagunapal Is this check needed due to API changes between PT 2.2.0 and the latest nightlies, which are 2.3.x?
Author reply: Yes, this needs the latest nightlies; the new API was merged only in the last week or so. Based on the current release cadence, 2.3.0 follows 2.2.2 (two patch releases after every major release), but I will keep this in mind in case that changes.
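For context on the thread above: `> 2.2.2` is true for any later version, including a hypothetical 2.2.3 patch release. A minimal sketch of a more explicit gate, assuming the intent is "2.3.0 or newer, including 2.3.0 nightlies"; the MINIMUM_VERSION name is illustrative and not part of the PR:

    # Hypothetical, more explicit version gate (not the PR's code).
    # packaging sorts 2.2.3 < 2.3.0.dev0 < 2.3.0a0 < 2.3.0, so comparing
    # against the earliest 2.3.0 pre-release also admits nightly builds.
    import torch
    from packaging import version

    MINIMUM_VERSION = version.parse("2.3.0.dev0")  # illustrative constant
    PT_230_AVAILABLE = version.parse(torch.__version__) >= MINIMUM_VERSION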
@@ -47,7 +47,7 @@ def custom_working_directory(tmp_path):
     os.chdir(tmp_path)


-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
@@ -88,7 +88,7 @@ def test_torch_export_aot_compile(custom_working_directory):
     assert labels == EXPECTED_RESULTS


-@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
+@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
 def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     # Get the path to the custom working directory
     model_dir = custom_working_directory
@@ -123,6 +123,6 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
     data["body"] = byte_array_type

     # Send a batch of 16 elements
-    result = handler.handle([data for i in range(15)], ctx)
+    result = handler.handle([data for i in range(32)], ctx)

-    assert len(result) == 15
+    assert len(result) == 32

Review comment (on "# Send a batch of 16 elements"): nit: outdated comment
Author reply: Thanks. Updated.
Review comment (on the handler.handle call): How about parameterizing the test and testing both batch sizes?
Author reply: Thanks for catching this. I changed it.
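One way to read the parametrization suggestion, as a self-contained sketch: the real test's handler and context setup are elided, and fake_handle is a stand-in so the example runs on its own.

    # Hypothetical illustration of @pytest.mark.parametrize for the batch
    # sizes discussed above; not the code that landed in the PR.
    import pytest

    def fake_handle(batch):
        # Stand-in for handler.handle(...): one result per input element.
        return [f"result-{i}" for i, _ in enumerate(batch)]

    @pytest.mark.parametrize("batch_size", [15, 32])
    def test_dynamic_batching_batch_sizes(batch_size):
        data = {"body": b"example-bytes"}
        result = fake_handle([data for _ in range(batch_size)])
        assert len(result) == batch_size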
This file was deleted.
Changed file (base handler):

@@ -53,10 +53,10 @@
     )
     PT2_AVAILABLE = False

-if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
-    PT220_AVAILABLE = True
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
+    PT230_AVAILABLE = True
 else:
-    PT220_AVAILABLE = False
+    PT230_AVAILABLE = False

 if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
     try:

Review comment (on the version check): ditto (same question as on the version bound above).
@@ -187,7 +187,7 @@ def initialize(self, context):
             elif (
                 self.model_pt_path.endswith(".so")
                 and self._use_torch_export_aot_compile()
-                and PT220_AVAILABLE
+                and PT230_AVAILABLE
             ):
                 # Set cuda device to the gpu_id of the backend worker
                 # This is needed as the API for loading the exported model doesn't yet have a device id
@@ -256,9 +256,15 @@ def initialize(self, context):
         self.initialized = True

     def _load_torch_export_aot_compile(self, model_so_path):
-        from ts.handler_utils.torch_export.load_model import load_exported_model
-
-        return load_exported_model(model_so_path, self.map_location)
+        """Loads the PyTorch model so and returns a Callable object.
+
+        Args:
+            model_pt_path (str): denotes the path of the model file.
+
+        Returns:
+            (Callable Object) : Loads the model object.
+        """
+        return torch._export.aot_load(model_so_path, self.map_location)

     def _load_torchscript_model(self, model_pt_path):
         """Loads the PyTorch model and returns the NN model object.
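For readers unfamiliar with the new call, a minimal end-to-end sketch of the compile/load round trip, assuming a PyTorch 2.3 nightly with AOTInductor support; the model and shapes are illustrative and not taken from this PR:

    # Hypothetical example: compile a tiny model ahead of time, then load the
    # resulting shared object the same way the handler now does.
    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    example_inputs = (torch.randn(4, 8),)
    # aot_compile returns the path to a compiled .so artifact.
    so_path = torch._export.aot_compile(TinyModel(), example_inputs)

    # Mirrors _load_torch_export_aot_compile: load the artifact on a device
    # and call it like a regular function.
    compiled = torch._export.aot_load(so_path, "cpu")
    print(compiled(torch.randn(4, 8)).shape)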
Review comment: Can we check the batch_size input to reflect the max batch_size limitation?
Author reply: Thanks. Updated the comments. There is no limitation on batch_size.