Commit dd14fab

[ZENTORCH-INFRA] Changes to be compatible with Torch2.9 (#31)
* [ZENTORCH-INFRA] Changes to be compatible with Torch2.9 - ZENAI-2264
  - Changed range of K to be (2, 10) so that we don't end up with (1,1) tensors for addmm
  - Added graceful exit for LLMs if IPEX version doesn't match torch version

Signed-off-by: Nimisha Gupta <[email protected]>
Change-Id: I135e5025db0dbc90a8944a59d70b4f2d3bb94445
1 parent dcdb8ce commit dd14fab

File tree

3 files changed, +33 -7 lines changed

src/cpu/python/zentorch/llm/_checks.py

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,8 @@ def essential_checks(model, dtype):
     if is_well_supported_model:
         installed_ipex_version = get_installed_ipex_version()
         if installed_ipex_version:
-            # Zentorch will work with IPEX of atleast 2.6
-            min_version = TorchVersion("2.6")
+            # Zentorch will work with IPEX of atleast 2.8
+            min_version = TorchVersion("2.8")
             installed_ipex_version = TorchVersion(installed_ipex_version)

             if installed_ipex_version >= min_version:
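For context, this check only raises the IPEX version floor from 2.6 to 2.8. A minimal sketch of the same minimum-version gate, assuming a plain packaging.version comparison in place of the internal TorchVersion helper (the function name ipex_meets_minimum is hypothetical, for illustration only):

```python
from packaging.version import Version

def ipex_meets_minimum(installed: str, minimum: str = "2.8") -> bool:
    # Mirrors the gate above: only IPEX versions of at least 2.8 pass.
    return Version(installed) >= Version(minimum)

print(ipex_meets_minimum("2.6.0"))      # False -> the IPEX-backed path is skipped
print(ipex_meets_minimum("2.8.0+cpu"))  # True  -> the remaining checks continue
```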

src/cpu/python/zentorch/llm/_optimize.py

Lines changed: 30 additions & 4 deletions
@@ -4,10 +4,12 @@
 # ******************************************************************************

 import torch
-from ._checks import essential_checks
+from ._checks import essential_checks, get_installed_ipex_version
 import zentorch._C
 import zentorch._WOQLinear as WOQLinear
 from .._logging import get_logger
+import sys
+from packaging.version import Version

 # make a logger for this file
 logger = get_logger(__name__)
@@ -50,13 +52,37 @@ def check_for_shared_params(model):

 def optimize(model, dtype=torch.bfloat16):
     if essential_checks(model, dtype):
-        import intel_extension_for_pytorch as ipex
-        from ._model_conversion_functions import model_convert_lowering, customize_model
+        # Create version objects
+        torch_version = Version(torch.__version__)
+        ipex_version = Version(get_installed_ipex_version())
+
+        if torch_version.major != ipex_version.major or \
+                torch_version.minor != ipex_version.minor:
+            logger.error(
+                "Detected Torch version %s is incompatible with IPEX version %s."
+                "We recommend running with Torch 2.8.0+cpu and IPEX 2.8.0. Exiting.",
+                torch.__version__,
+                ipex_version,
+            )
+            sys.exit()
+
+        try:
+            import intel_extension_for_pytorch as ipex
+        except Exception:
+            logger.error(
+                "Error occurred in importing Intel Extension for PyTorch"
+            )
+            sys.exit()
+
+        from ._model_conversion_functions import (
+            model_convert_lowering,
+            customize_model,
+        )

         ipex_t = ipex.transformers

         # For masked multihead attention, the meta registration uses dynamic shape outputs
-        # To ensure the dynamic shapes do not cause a greph break
+        # To ensure the dynamic shapes do not cause a graph break
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         torch._dynamo.config.capture_scalar_outputs = True
         # Runtime over-riding of IPEX model_convert_lowering with ZenTorch
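A minimal sketch of the new compatibility rule this hunk introduces, assuming the same packaging.version.Version parsing used in the diff (the helper versions_are_compatible is hypothetical, for illustration only):

```python
from packaging.version import Version

def versions_are_compatible(torch_version: str, ipex_version: str) -> bool:
    # optimize() now proceeds only when Torch and IPEX agree on major.minor;
    # patch levels and local tags such as "+cpu" are ignored by this rule.
    tv, iv = Version(torch_version), Version(ipex_version)
    return (tv.major, tv.minor) == (iv.major, iv.minor)

print(versions_are_compatible("2.8.0+cpu", "2.8.0"))  # True  -> lowering continues
print(versions_are_compatible("2.9.0+cpu", "2.8.0"))  # False -> error is logged, sys.exit()
```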

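The two torch._dynamo.config flags retained at the end of the hunk are there because the masked multi-head attention meta registration produces data-dependent output shapes. A minimal, self-contained sketch of what enabling them changes, with torch.nonzero standing in for the attention op purely for illustration:

```python
import torch

# With these flags off, ops whose output shape is only known at runtime
# (like torch.nonzero) typically force a graph break under torch.compile.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def nonzero_indices(x):
    return torch.nonzero(x)  # data-dependent output shape

print(nonzero_indices(torch.tensor([0.0, 1.5, 0.0, 2.0])))
```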
test/zentorch_test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def get_max(self):

 B_RANGE = Range(1, 10)
 M_RANGE = Range(1, 10)
-K_RANGE = Range(1, 10)
+K_RANGE = Range(2, 10)
 N_RANGE = Range(1, 10)

 P_RANGE = Range(1, 11)
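As the commit message notes, K now starts at 2 so that randomly drawn addmm shapes cannot degenerate into (1,1) matrix operands. A minimal sketch of the idea, with random.randint standing in for how the Range values are assumed to be sampled:

```python
import random
import torch

# With m, n in [1, 10] and k in [2, 10], neither mat1 (m x k) nor mat2 (k x n)
# can collapse to a 1 x 1 tensor, which is the degenerate addmm case being avoided.
m, k, n = random.randint(1, 10), random.randint(2, 10), random.randint(1, 10)
out = torch.addmm(torch.randn(m, n), torch.randn(m, k), torch.randn(k, n))
print(out.shape)  # torch.Size([m, n])
```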
