Skip to content

Commit c62ca74

Browse files
committed
Fixed setup_torch_device to only enable MPS for ARM based macs as pytorch doesn't really support MPS on non-apple silicon
1 parent f2b78ef commit c62ca74

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

audio_separator/separator/separator.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ def setup_accelerated_inferencing_device(self):
158158
"""
159159
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
160160
"""
161-
self.log_system_info()
161+
system_info = self.get_system_info()
162162
self.check_ffmpeg_installed()
163163
self.log_onnxruntime_packages()
164-
self.setup_torch_device()
164+
self.setup_torch_device(system_info)
165165

166-
def log_system_info(self):
166+
def get_system_info(self):
167167
"""
168168
This method logs the system information, including the operating system, CPU archutecture and Python version
169169
"""
@@ -179,6 +179,7 @@ def log_system_info(self):
179179

180180
pytorch_version = torch.__version__
181181
self.logger.info(f"PyTorch Version: {pytorch_version}")
182+
return system_info
182183

183184
def check_ffmpeg_installed(self):
184185
"""
@@ -210,7 +211,7 @@ def log_onnxruntime_packages(self):
210211
if onnxruntime_cpu_package is not None:
211212
self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
212213

213-
def setup_torch_device(self):
214+
def setup_torch_device(self, system_info):
214215
"""
215216
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
216217
"""
@@ -222,7 +223,7 @@ def setup_torch_device(self):
222223
if torch.cuda.is_available():
223224
self.configure_cuda(ort_providers)
224225
hardware_acceleration_enabled = True
225-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
226+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
226227
self.configure_mps(ort_providers)
227228
hardware_acceleration_enabled = True
228229

@@ -247,7 +248,7 @@ def configure_mps(self, ort_providers):
247248
"""
248249
This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
249250
"""
250-
self.logger.info("Apple Silicon MPS/CoreML is available in Torch, setting Torch device to MPS")
251+
self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
251252
self.torch_device_mps = torch.device("mps")
252253

253254
self.torch_device = self.torch_device_mps

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "audio-separator"
7-
version = "0.17.5"
7+
version = "0.17.6"
88
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
99
authors = ["Andrew Beveridge <[email protected]>"]
1010
license = "MIT"

0 commit comments

Comments
 (0)