8585_ORIGINAL_PYTORCH_HIP_ALLOC_CONF = os .environ .get ("PYTORCH_HIP_ALLOC_CONF" )
8686_HAS_ORIGINAL_PYTORCH_ALLOC_CONF = "PYTORCH_ALLOC_CONF" in os .environ
8787
# We support PyTorch 2
# Fixes https://github.com/unslothai/unsloth/issues/38
from importlib.metadata import version as importlib_version
# Raw version string may carry a local build tag, e.g. "2.9.0+rocm6.2".
torch_version_raw = str(importlib_version("torch"))
# Keep only the leading numeric "major.minor.patch" prefix, then split into parts.
torch_version = str(re.match(r"[0-9\.]{3,}", torch_version_raw).group(0)).split(".")
major_torch, minor_torch = int(torch_version[0]), int(torch_version[1])
# Torch >= 2.9 switched allocator configuration to PYTORCH_ALLOC_CONF.
IS_TORCH_2_9_OR_NEWER = (major_torch > 2) or (major_torch == 2 and minor_torch >= 9)
# ROCm wheels embed "+rocm" in the version's local build tag.
IS_TORCH_ROCM_BUILD = "+rocm" in torch_version_raw.lower()
97+
# Reduce VRAM usage by reducing fragmentation
# And optimize pinning of memory
# Only applied when vLLM standby mode is off; user-provided values always win
# (each branch checks "not in os.environ" before writing).
if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0":
    if IS_TORCH_2_9_OR_NEWER:
        # Torch >= 2.9 reads the unified PYTORCH_ALLOC_CONF variable.
        if "PYTORCH_ALLOC_CONF" not in os.environ:
            os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
    else:
        # Older torch uses per-backend variables; set all three defensively.
        if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
                "expandable_segments:True,"
                "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
            )
        if "PYTORCH_HIP_ALLOC_CONF" not in os.environ:
            # [TODO] Check if AMD works with roundup_power2_divisions
            os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
        if "PYTORCH_ALLOC_CONF" not in os.environ:
            os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
100114elif os .environ .get ("UNSLOTH_VLLM_STANDBY" , "0" ) == "1" :
101115 for key in ("PYTORCH_CUDA_ALLOC_CONF" , "PYTORCH_HIP_ALLOC_CONF" , "PYTORCH_ALLOC_CONF" ,):
102116 if "expandable_segments:True" in os .environ .get (key , "" ):
107121 )
108122 os .environ [key ] = re .sub (r"expandable\_segments\:True\,?" , "" , os .environ [key ])
109123
110- # We support Pytorch 2
111- # Fixes https://github.com/unslothai/unsloth/issues/38
112- from importlib .metadata import version as importlib_version
113- torch_version_raw = str (importlib_version ("torch" ))
114- torch_version = str (re .match (r"[0-9\.]{3,}" , torch_version_raw ).group (0 )).split ("." )
115- major_torch , minor_torch = torch_version [0 ], torch_version [1 ]
116- major_torch , minor_torch = int (major_torch ), int (minor_torch )
117- IS_TORCH_ROCM_BUILD = "+rocm" in torch_version_raw .lower ()
def delete_key(key):
    """Remove *key* from os.environ if present; silently no-op otherwise."""
    # pop with a default collapses the check-then-delete into one call
    # and cannot raise KeyError.
    os.environ.pop(key, None)
120126
@@ -203,8 +209,8 @@ def filter(self, x): return not (self.text in x.getMessage())
203209)
204210IS_HIP_RUNTIME = (DEVICE_TYPE == "hip" ) or bool (is_hip ())
205211
206- # Torch 2.9 removed PYTORCH_HIP_ALLOC_CONF and PYTORCH_CUDA_ALLOC_CONF
207- if major_torch == 2 and minor_torch >= 9 :
212+ # Torch >= 2.9 uses PYTORCH_ALLOC_CONF and treats legacy per-backend vars as deprecated.
213+ if IS_TORCH_2_9_OR_NEWER :
208214 # Preserve explicit legacy allocator settings when user did not directly set PYTORCH_ALLOC_CONF.
209215 if not _HAS_ORIGINAL_PYTORCH_ALLOC_CONF :
210216 promoted = _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF
@@ -220,7 +226,7 @@ def filter(self, x): return not (self.text in x.getMessage())
220226
221227# Specify PYTORCH_CUDA_ALLOC_CONF or PYTORCH_HIP_ALLOC_CONF
222228if IS_HIP_RUNTIME :
223- if major_torch == 2 and minor_torch >= 9 :
229+ if IS_TORCH_2_9_OR_NEWER :
224230 # PyTorch >= 2.9 uses PYTORCH_ALLOC_CONF. expandable_segments is unsupported on HIP.
225231 remove_expandable_segments ("PYTORCH_ALLOC_CONF" )
226232 delete_key ("PYTORCH_CUDA_ALLOC_CONF" )
@@ -236,7 +242,7 @@ def filter(self, x): return not (self.text in x.getMessage())
236242 remove_expandable_segments ("PYTORCH_HIP_ALLOC_CONF" )
237243 remove_expandable_segments ("PYTORCH_ALLOC_CONF" )
238244 delete_key ("PYTORCH_CUDA_ALLOC_CONF" )
239- elif DEVICE_TYPE == "cuda" and not IS_HIP_RUNTIME and not ( major_torch == 2 and minor_torch >= 9 ) :
245+ elif DEVICE_TYPE == "cuda" and not IS_HIP_RUNTIME and not IS_TORCH_2_9_OR_NEWER :
240246 delete_key ("PYTORCH_HIP_ALLOC_CONF" )
241247 delete_key ("PYTORCH_ALLOC_CONF" )
242248
@@ -247,7 +253,7 @@ def filter(self, x): return not (self.text in x.getMessage())
247253elif DEVICE_TYPE == "hip" :
248254 # CCE also fails in HIP / AMD
249255 os .environ ["UNSLOTH_ENABLE_CCE" ] = "0"
250- del remove_expandable_segments , delete_key , IS_HIP_RUNTIME , IS_TORCH_ROCM_BUILD , major_torch , minor_torch , torch_version , torch_version_raw , importlib_version , find_spec
256+ del remove_expandable_segments , delete_key , IS_HIP_RUNTIME , IS_TORCH_2_9_OR_NEWER , IS_TORCH_ROCM_BUILD , major_torch , minor_torch , torch_version , torch_version_raw , importlib_version , find_spec
251257del clean_expandable_segments_value
252258del _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF , _ORIGINAL_PYTORCH_HIP_ALLOC_CONF , _HAS_ORIGINAL_PYTORCH_ALLOC_CONF
253259
0 commit comments