@whitebox2
Contributor

Summary

On Apple Metal, ModernBertHead::forward may receive a non-contiguous view from upstream pooling (CLS/MEAN). Passing that view directly into Linear causes the Metal matmul path to misinterpret shapes and fail with:

Invalid matmul arguments [17920, 1] [1, 256] (2, 256, 256)

The CPU backend succeeds because it tolerates non-contiguous inputs. This patch materializes the input with contiguous() before applying Linear, which resolves the Metal-only failure and is consistent with prior guidance that GPU matmul expects contiguous tensors. ([GitHub][1])
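To illustrate why a pooled view can be non-contiguous, here is a minimal std-only sketch. It does not use candle's types: the helper names (row_major_strides, cls_view, is_contiguous, materialize) are hypothetical, and the sizes (batch 2, sequence length 70, hidden size 256) are assumptions chosen because 70 × 256 = 17920 matches the first number in the error above. That suggests, but does not confirm, that the first bracket holds the left-hand strides.

```rust
/// Row-major (contiguous) strides, in elements, for a given shape.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Shape and strides of the zero-copy view `x[:, 0, :]` over a contiguous
/// [batch, seq, hidden] buffer, i.e. CLS-style pooling without a copy.
fn cls_view(batch: usize, seq: usize, hidden: usize) -> (Vec<usize>, Vec<usize>) {
    let full = row_major_strides(&[batch, seq, hidden]);
    // The sequence dimension is dropped, but the batch stride still skips
    // over every token of each sequence.
    (vec![batch, hidden], vec![full[0], full[2]])
}

/// A view is contiguous when its strides equal the row-major strides of its
/// own shape. (Simplified: ignores the special case of size-1 dimensions.)
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    strides == row_major_strides(shape).as_slice()
}

/// Copy a 2-D strided view into a fresh row-major buffer. This is the kind
/// of work a contiguous() call has to do for a view like the one above.
fn materialize(data: &[f32], shape: &[usize], strides: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(shape[0] * shape[1]);
    for r in 0..shape[0] {
        for c in 0..shape[1] {
            out.push(data[r * strides[0] + c * strides[1]]);
        }
    }
    out
}
```

With these assumed sizes, cls_view(2, 70, 256) yields shape [2, 256] with strides [17920, 1], whereas a contiguous [2, 256] tensor would have strides [256, 1]. A matmul kernel that assumes the latter layout would read the wrong rows, which is why copying the view first is a reasonable fix.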

Patch

 impl Module for ModernBertHead {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
+        let xs = xs.contiguous()?.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
         Ok(xs)
     }
 }

Issue

#3138 (comment)
