CachedMultipleNegativesRankingLoss + MPS is broken #3564

@morrisalp


Sentence-transformers version: 5.1.1
Python 3.11.14
macOS Sequoia 15.6
MacBook Pro with M4 Pro chip

CachedMultipleNegativesRankingLoss fails on MPS because get_device_states() in torch/utils/checkpoint.py tries to call torch.mps.device(), which does not exist:

```
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/trainer.py", line 431, in compute_loss
    loss = loss_fn(features, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/src/train.py", line 813, in forward
    loss_value = super().forward(sentence_features, labels)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 286, in forward
    for reps_mb, random_state in self.embed_minibatch_iter(
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 215, in embed_minibatch_iter
    reps, random_state = self.embed_minibatch(
                         ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 191, in embed_minibatch
    random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py", line 29, in __init__
    self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malper/repos/morphograph/.venv/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 178, in get_device_states
    with device_module.device(device_id):
         ^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.mps' has no attribute 'device'
```
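You can check whether an installed PyTorch build is affected by probing for that attribute directly. This is a sketch that assumes checkpoint.py resolves the per-backend module as `getattr(torch, device_type)`, which is what the `module 'torch.mps'` in the error suggests. `backend_module_has_device_ctx` is a hypothetical helper name, not a PyTorch API.

```python
import torch


def backend_module_has_device_ctx(device_type: str) -> bool:
    """Return True if torch.<device_type> exists and exposes a `device` attribute.

    Assumption: this mirrors how torch.utils.checkpoint looks up the backend
    module before calling `device_module.device(device_id)` in get_device_states().
    """
    module = getattr(torch, device_type, None)
    return module is not None and hasattr(module, "device")


if __name__ == "__main__":
    for device_type in ("cuda", "mps"):
        print(device_type, backend_module_has_device_ctx(device_type))
```

On the PyTorch version in the traceback above, the `mps` check would report False, matching the AttributeError. `torch.cuda.device` is a context manager, which is why the same code path works on CUDA.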

I got it running by monkey-patching RandContext.__init__ to skip per-device RNG state tracking on MPS:

```python
import torch
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext

# Keep the original constructor so non-MPS devices behave exactly as before.
original_init = RandContext.__init__

def patched_init(self, *tensors):
    if tensors and hasattr(tensors[0], 'device'):
        device = tensors[0].device
        if device.type == 'mps':
            # get_device_states() would call torch.mps.device(), which doesn't exist,
            # so record only the CPU RNG state and track no per-device states.
            self.fwd_cpu_state = torch.get_rng_state()
            self.fwd_gpu_devices = []
            self.fwd_gpu_states = []
            return
    original_init(self, *tensors)

RandContext.__init__ = patched_init
```

A guard like this could be incorporated into sentence-transformers (e.g. in RandContext) to fix CachedMultipleNegativesRankingLoss on the MPS backend.
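One possible caveat with the patch, assuming dropout on MPS tensors draws from the MPS generator exposed by `torch.mps.get_rng_state()` / `torch.mps.set_rng_state()`: it records only the CPU RNG state. The gradient pass may therefore not replay the same dropout masks as the cached no-grad pass. Below is a sketch of a hypothetical `MPSAwareRandContext` that also snapshots and replays the MPS state. It mimics RandContext's capture / replay / restore behaviour rather than copying its code, and it handles CPU and MPS only (not CUDA).

```python
import torch


class MPSAwareRandContext:
    """Hypothetical RandContext stand-in for CPU/MPS inputs (sketch only).

    On construction, snapshot the CPU RNG state and, if any input tensor is on
    MPS, the MPS RNG state. Entering the context replays those snapshots;
    exiting restores the states that were current just before entering.
    CUDA RNG state is deliberately not handled here.
    """

    def __init__(self, *tensors):
        self.uses_mps = any(
            isinstance(t, torch.Tensor) and t.device.type == "mps" for t in tensors
        )
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_mps_state = torch.mps.get_rng_state() if self.uses_mps else None

    def __enter__(self):
        # Remember the current states so __exit__ can put them back.
        self._saved_cpu_state = torch.get_rng_state()
        self._saved_mps_state = torch.mps.get_rng_state() if self.uses_mps else None
        # Replay the states captured at construction time.
        torch.set_rng_state(self.fwd_cpu_state)
        if self.uses_mps:
            torch.mps.set_rng_state(self.fwd_mps_state)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_rng_state(self._saved_cpu_state)
        if self.uses_mps:
            torch.mps.set_rng_state(self._saved_mps_state)
        return False  # never suppress exceptions


if __name__ == "__main__":
    x = torch.zeros(4)  # CPU tensor, so only the CPU path is exercised here
    ctx = MPSAwareRandContext(x)
    first = torch.rand(3)
    with ctx:
        replayed = torch.rand(3)
    print(torch.equal(first, replayed))  # True: the captured CPU state was replayed
```

To try it, one option is to rebind the module-level `RandContext` name, since `embed_minibatch` looks it up by name at call time (see the traceback). This is untested and an assumption. Fetch the module with `importlib.import_module("sentence_transformers.losses.CachedMultipleNegativesRankingLoss")` rather than `import ... as`, because the `losses` package appears to re-export the loss class under the same name as its module. Because this replacement drops CUDA handling, it only makes sense for MPS runs.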
