Skip to content

Commit 96d84f8

Browse files
committed
Fix allocator env handling for torch 2.9+ and future versions
1 parent b6dbba1 commit 96d84f8

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

unsloth_zoo/__init__.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,32 @@
8585
_ORIGINAL_PYTORCH_HIP_ALLOC_CONF = os.environ.get("PYTORCH_HIP_ALLOC_CONF")
8686
_HAS_ORIGINAL_PYTORCH_ALLOC_CONF = "PYTORCH_ALLOC_CONF" in os.environ
8787

88+
# We support Pytorch 2
89+
# Fixes https://github.com/unslothai/unsloth/issues/38
90+
from importlib.metadata import version as importlib_version
91+
torch_version_raw = str(importlib_version("torch"))
92+
torch_version = str(re.match(r"[0-9\.]{3,}", torch_version_raw).group(0)).split(".")
93+
major_torch, minor_torch = torch_version[0], torch_version[1]
94+
major_torch, minor_torch = int(major_torch), int(minor_torch)
95+
IS_TORCH_2_9_OR_NEWER = (major_torch > 2) or (major_torch == 2 and minor_torch >= 9)
96+
IS_TORCH_ROCM_BUILD = "+rocm" in torch_version_raw.lower()
97+
8898
# Reduce VRAM usage by reducing fragmentation
8999
# And optimize pinning of memory
90100
if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0":
91-
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
92-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
93-
"expandable_segments:True,"\
94-
"roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
95-
if "PYTORCH_HIP_ALLOC_CONF" not in os.environ:
96-
# [TODO] Check if AMD works with roundup_power2_divisions
97-
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
98-
if "PYTORCH_ALLOC_CONF" not in os.environ:
99-
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
101+
if IS_TORCH_2_9_OR_NEWER:
102+
if "PYTORCH_ALLOC_CONF" not in os.environ:
103+
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
104+
else:
105+
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
106+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
107+
"expandable_segments:True,"\
108+
"roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
109+
if "PYTORCH_HIP_ALLOC_CONF" not in os.environ:
110+
# [TODO] Check if AMD works with roundup_power2_divisions
111+
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
112+
if "PYTORCH_ALLOC_CONF" not in os.environ:
113+
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
100114
elif os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "1":
101115
for key in ("PYTORCH_CUDA_ALLOC_CONF", "PYTORCH_HIP_ALLOC_CONF", "PYTORCH_ALLOC_CONF",):
102116
if "expandable_segments:True" in os.environ.get(key, ""):
@@ -107,14 +121,6 @@
107121
)
108122
os.environ[key] = re.sub(r"expandable\_segments\:True\,?", "", os.environ[key])
109123

110-
# We support Pytorch 2
111-
# Fixes https://github.com/unslothai/unsloth/issues/38
112-
from importlib.metadata import version as importlib_version
113-
torch_version_raw = str(importlib_version("torch"))
114-
torch_version = str(re.match(r"[0-9\.]{3,}", torch_version_raw).group(0)).split(".")
115-
major_torch, minor_torch = torch_version[0], torch_version[1]
116-
major_torch, minor_torch = int(major_torch), int(minor_torch)
117-
IS_TORCH_ROCM_BUILD = "+rocm" in torch_version_raw.lower()
118124
def delete_key(key):
119125
if key in os.environ: del os.environ[key]
120126

@@ -203,8 +209,8 @@ def filter(self, x): return not (self.text in x.getMessage())
203209
)
204210
IS_HIP_RUNTIME = (DEVICE_TYPE == "hip") or bool(is_hip())
205211

206-
# Torch 2.9 removed PYTORCH_HIP_ALLOC_CONF and PYTORCH_CUDA_ALLOC_CONF
207-
if major_torch == 2 and minor_torch >= 9:
212+
# Torch >= 2.9 uses PYTORCH_ALLOC_CONF and treats legacy per-backend vars as deprecated.
213+
if IS_TORCH_2_9_OR_NEWER:
208214
# Preserve explicit legacy allocator settings when user did not directly set PYTORCH_ALLOC_CONF.
209215
if not _HAS_ORIGINAL_PYTORCH_ALLOC_CONF:
210216
promoted = _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF
@@ -220,7 +226,7 @@ def filter(self, x): return not (self.text in x.getMessage())
220226

221227
# Specify PYTORCH_CUDA_ALLOC_CONF or PYTORCH_HIP_ALLOC_CONF
222228
if IS_HIP_RUNTIME:
223-
if major_torch == 2 and minor_torch >= 9:
229+
if IS_TORCH_2_9_OR_NEWER:
224230
# PyTorch >= 2.9 uses PYTORCH_ALLOC_CONF. expandable_segments is unsupported on HIP.
225231
remove_expandable_segments("PYTORCH_ALLOC_CONF")
226232
delete_key("PYTORCH_CUDA_ALLOC_CONF")
@@ -236,7 +242,7 @@ def filter(self, x): return not (self.text in x.getMessage())
236242
remove_expandable_segments("PYTORCH_HIP_ALLOC_CONF")
237243
remove_expandable_segments("PYTORCH_ALLOC_CONF")
238244
delete_key("PYTORCH_CUDA_ALLOC_CONF")
239-
elif DEVICE_TYPE == "cuda" and not IS_HIP_RUNTIME and not (major_torch == 2 and minor_torch >= 9):
245+
elif DEVICE_TYPE == "cuda" and not IS_HIP_RUNTIME and not IS_TORCH_2_9_OR_NEWER:
240246
delete_key("PYTORCH_HIP_ALLOC_CONF")
241247
delete_key("PYTORCH_ALLOC_CONF")
242248

@@ -247,7 +253,7 @@ def filter(self, x): return not (self.text in x.getMessage())
247253
elif DEVICE_TYPE == "hip":
248254
# CCE also fails in HIP / AMD
249255
os.environ["UNSLOTH_ENABLE_CCE"] = "0"
250-
del remove_expandable_segments, delete_key, IS_HIP_RUNTIME, IS_TORCH_ROCM_BUILD, major_torch, minor_torch, torch_version, torch_version_raw, importlib_version, find_spec
256+
del remove_expandable_segments, delete_key, IS_HIP_RUNTIME, IS_TORCH_2_9_OR_NEWER, IS_TORCH_ROCM_BUILD, major_torch, minor_torch, torch_version, torch_version_raw, importlib_version, find_spec
251257
del clean_expandable_segments_value
252258
del _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF, _ORIGINAL_PYTORCH_HIP_ALLOC_CONF, _HAS_ORIGINAL_PYTORCH_ALLOC_CONF
253259

Comments (0)