@@ -158,12 +158,12 @@ def setup_accelerated_inferencing_device(self):
158
158
"""
159
159
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
160
160
"""
161
- self .log_system_info ()
161
+ system_info = self .get_system_info ()
162
162
self .check_ffmpeg_installed ()
163
163
self .log_onnxruntime_packages ()
164
- self .setup_torch_device ()
164
+ self .setup_torch_device (system_info )
165
165
166
- def log_system_info (self ):
166
+ def get_system_info (self ):
167
167
"""
168
168
This method logs the system information, including the operating system, CPU archutecture and Python version
169
169
"""
@@ -179,6 +179,7 @@ def log_system_info(self):
179
179
180
180
pytorch_version = torch .__version__
181
181
self .logger .info (f"PyTorch Version: { pytorch_version } " )
182
+ return system_info
182
183
183
184
def check_ffmpeg_installed (self ):
184
185
"""
@@ -210,7 +211,7 @@ def log_onnxruntime_packages(self):
210
211
if onnxruntime_cpu_package is not None :
211
212
self .logger .info (f"ONNX Runtime CPU package installed with version: { onnxruntime_cpu_package .version } " )
212
213
213
- def setup_torch_device (self ):
214
+ def setup_torch_device (self , system_info ):
214
215
"""
215
216
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
216
217
"""
@@ -222,7 +223,7 @@ def setup_torch_device(self):
222
223
if torch .cuda .is_available ():
223
224
self .configure_cuda (ort_providers )
224
225
hardware_acceleration_enabled = True
225
- elif hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
226
+ elif hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available () and system_info . processor == "arm" :
226
227
self .configure_mps (ort_providers )
227
228
hardware_acceleration_enabled = True
228
229
@@ -247,7 +248,7 @@ def configure_mps(self, ort_providers):
247
248
"""
248
249
This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
249
250
"""
250
- self .logger .info ("Apple Silicon MPS/CoreML is available in Torch, setting Torch device to MPS" )
251
+ self .logger .info ("Apple Silicon MPS/CoreML is available in Torch and processor is ARM , setting Torch device to MPS" )
251
252
self .torch_device_mps = torch .device ("mps" )
252
253
253
254
self .torch_device = self .torch_device_mps
0 commit comments