Description
Sentence-transformers version: 5.1.1
Python 3.11.14
macOS Sequoia 15.6
MacBook Pro with M4 Pro chip
`CachedMultipleNegativesRankingLoss` fails on the MPS backend because `torch.utils.checkpoint.get_device_states` tries to call `torch.mps.device()`, which does not exist:
```
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/trainer.py", line 431, in compute_loss
    loss = loss_fn(features, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/src/train.py", line 813, in forward
    loss_value = super().forward(sentence_features, labels)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 286, in forward
    for reps_mb, random_state in self.embed_minibatch_iter(
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 215, in embed_minibatch_iter
    reps, random_state = self.embed_minibatch(
                         ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 191, in embed_minibatch
    random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 29, in __init__
    self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 178, in get_device_states
    with device_module.device(device_id):
         ^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.mps' has no attribute 'device'
```
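The last frame shows the root cause: for each device id it collected, `get_device_states` enters `device_module.device(device_id)` before reading that device's RNG state, and `torch.mps` has no `device` attribute. The plain-Python sketch below mimics that loop with hypothetical stand-in backend modules (`FakeCuda`, `FakeMps` and `collect_rng_states` are illustrative names, not torch APIs) and shows one possible guard: fall back to a no-op context when the backend lacks a per-device context manager. Whether this would be an acceptable upstream fix is an assumption.

```python
import contextlib


class FakeCuda:
    """Hypothetical stand-in for a backend module that has a per-device context manager."""

    def __init__(self):
        self.entered = []  # records which device ids were entered

    @contextlib.contextmanager
    def device(self, device_id):
        self.entered.append(device_id)
        yield

    def get_rng_state(self):
        return "cuda-rng-state"


class FakeMps:
    """Hypothetical stand-in for torch.mps as seen in the traceback: no device() attribute."""

    def get_rng_state(self):
        return "mps-rng-state"


def collect_rng_states(device_module, device_ids):
    """Hypothetical guarded version of the loop in get_device_states."""
    states = []
    for device_id in device_ids:
        # Use the backend's device context if it has one, otherwise a no-op context.
        if hasattr(device_module, "device"):
            ctx = device_module.device(device_id)
        else:
            ctx = contextlib.nullcontext()
        with ctx:
            states.append(device_module.get_rng_state())
    return list(device_ids), states


cuda = FakeCuda()
print(collect_rng_states(cuda, [0, 1]))   # ([0, 1], ['cuda-rng-state', 'cuda-rng-state'])
print(cuda.entered)                       # [0, 1]
print(collect_rng_states(FakeMps(), [0])) # ([0], ['mps-rng-state'])
```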
I got it running by monkey-patching `RandContext.__init__` to skip device-state tracking when the tensors live on MPS:
```python
import torch
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext

original_init = RandContext.__init__


def patched_init(self, *tensors):
    if tensors and hasattr(tensors[0], 'device'):
        device = tensors[0].device
        if device.type == 'mps':
            # Skip device state tracking for MPS (torch.mps has no device() context manager).
            # Only the CPU RNG state is saved; no per-device states are recorded.
            self.fwd_cpu_state = torch.get_rng_state()
            self.fwd_gpu_devices = []
            self.fwd_gpu_states = []
            return
    # Any other device: keep the original behaviour.
    original_init(self, *tensors)


RandContext.__init__ = patched_init
```
A similar check could be incorporated into sentence-transformers to fix `CachedMultipleNegativesRankingLoss` on the MPS backend.
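One trade-off of the workaround: `RandContext` is there so that the second, gradient-tracking forward pass over each mini-batch sees the same random numbers (e.g. dropout masks) as the first, no-grad pass. With `fwd_gpu_devices` and `fwd_gpu_states` left empty, only the CPU generator state is restored, so random ops that run on the MPS device are probably not replayed identically, and the cached gradients may then be slightly inconsistent. The plain-Python sketch below illustrates the capture-and-replay idea with `random.Random` and a hypothetical `dropout_mask` helper; it is not the library's code. A fuller MPS fix would presumably also save `torch.mps.get_rng_state()` and restore it with `torch.mps.set_rng_state()`.

```python
import random


def dropout_mask(rng: random.Random, size: int, p: float = 0.1) -> list[int]:
    """Hypothetical helper: 1 keeps a unit, 0 drops it, with drop probability p."""
    return [0 if rng.random() < p else 1 for _ in range(size)]


rng = random.Random(0)

# Capture the generator state before the first (no-grad) pass, as RandContext does.
saved_state = rng.getstate()
first_pass = dropout_mask(rng, 256)

# Replay: restore the saved state before the second (grad) pass.
rng.setstate(saved_state)
replayed_pass = dropout_mask(rng, 256)

# No replay: the generator has simply advanced, so a fresh mask is drawn.
unreplayed_pass = dropout_mask(rng, 256)

print(first_pass == replayed_pass)  # True
```

Restoring the state is what makes the two passes agree; skipping it for a device leaves that device's randomness unsynchronized between passes.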