
[WIP] Apple Silicon (MPS/Metal) Support #3950

Draft

Wilbatronic wants to merge 723 commits into unslothai:main from Wilbatronic:apple-silicon-support

Conversation

@Wilbatronic

This PR introduces high-performance Apple Silicon support for Unsloth. The goal is to allow Mac users (M1/M2/M3/M4) to fine-tune and run inference on 7B+ models with performance parity to entry-level CUDA hardware, leveraging Apple's Unified Memory and Metal architecture.

@gemini-code-assist

Summary of Changes

Hello @Wilbatronic, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands Unsloth's hardware compatibility by introducing initial support for Apple Silicon (MPS/Metal) devices. The changes enable Mac users with M-series chips to leverage their hardware for fine-tuning and inference of large language models, aiming for performance comparable to entry-level CUDA GPUs. This is achieved through a combination of MPS-specific kernel implementations, intelligent device detection, and conditional module loading to adapt to the unique architecture of Apple Silicon, where libraries like Triton and bitsandbytes are not natively supported.

Highlights

  • Apple Silicon (MPS/Metal) Support: Introduced core support for Apple Silicon (MPS/Metal) devices, enabling Mac users to leverage their M-series chips for high-performance fine-tuning and inference of large language models.
  • PyTorch-Native Kernel Fallbacks: Implemented PyTorch-native fallback kernels for critical operations such as RMS LayerNorm, LayerNorm, RoPE embedding, Cross-Entropy Loss, SwiGLU, GEGLU, and LoRA operations. These replace Triton-based kernels on MPS, ensuring functionality and numerical parity; a minimal fallback sketch follows this list.
  • Intelligent Device Detection and Capabilities: Added robust MPS device detection, including checks for bfloat16 support and unified memory information. The system now conditionally loads modules and functionalities, disabling unsupported libraries like Triton and bitsandbytes for MPS.
  • Quantization Handling and Warnings: Disabled bitsandbytes imports and functionalities for MPS devices, as they are not supported. Graceful fallbacks and user warnings are now in place for attempts to load quantized models on MPS, guiding users towards 16-bit models for optimal performance.
  • Comprehensive Testing Suite: Developed a new suite of unit and integration tests specifically for MPS. These tests verify device detection, numerical parity of MPS fallback kernels, and overall system integration, with conditional skipping on non-MPS hardware.
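A minimal sketch of what such a PyTorch-native fallback can look like, using RMS LayerNorm as the example (the function name is illustrative, not the PR's actual kernel; because only standard PyTorch ops are used, autograd derives the backward pass and no Triton kernel is needed on MPS):

import torch

def rms_layernorm_fallback(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate in float32 for numerical stability, then cast back to the input dtype.
    X_fp32 = X.float()
    variance = X_fp32.pow(2).mean(-1, keepdim=True)
    X_normed = X_fp32 * torch.rsqrt(variance + eps)
    return W * X_normed.to(X.dtype)

# Usage on an MPS-capable Mac (assumed environment):
# x = torch.randn(2, 16, 4096, device="mps", dtype=torch.float16)
# w = torch.ones(4096, device="mps", dtype=torch.float16)
# y = rms_layernorm_fallback(x, w)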

@gemini-code-assist bot left a comment

Code Review

This is an impressive and comprehensive pull request that adds Apple Silicon (MPS) support to Unsloth. The changes are well-structured, with clear separation of MPS-specific code, PyTorch-native fallbacks for Triton kernels, and a clean dispatching mechanism. The addition of extensive tests for numerical parity, integration, and sanity checks is commendable and crucial for ensuring correctness on the new backend.

I've identified a critical bug in the RoPE kernel that would lead to incorrect tensor shapes, a high-severity issue in the device stream handling logic, and some dead code that should be removed for better maintainability. After addressing these points, this PR will be a fantastic addition to the project, significantly expanding its user base to Mac users.

Comment on lines 291 to 293

if DEVICE_TYPE == "mps" and USE_MPS_FALLBACK:
    from .mps.rope_embedding import mps_rope_embedding_qk

critical

There's a shape mismatch bug in the MPS fallback for RoPE. The mps_rope_embedding_qk function returns tensors with shape (batch, n_heads, seq_len, head_dim), but the caller of fast_rope_embedding expects the original shape (batch, seq_len, n_heads, head_dim).

The output tensors need to be transposed back to the expected shape before being returned, just like it's done for the Triton path. Without this, downstream operations will fail due to incorrect tensor shapes.

        q_out, k_out = mps_rope_embedding_qk(Q.transpose(1, 2).contiguous(), K.transpose(1, 2).contiguous(), cos, sin)
        return q_out.transpose(1, 2), k_out.transpose(1, 2)

Comment on lines +260 to 263
    cgemm_4bit_inference_naive_fp16 = None
    cgemm_4bit_inference_naive_bf16 = None
else:
    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16

high

There's a bug in the torch_device_stream definition. torch.cuda.current_stream() is being called, which returns a stream object, whereas the other branches return a function that returns a stream object. This will cause an error on CUDA devices when torch_device_stream is used.

The expression is also quite complex and hard to read. I suggest refactoring it into a simple if/elif/else block for clarity and to fix the bug.

Suggested change
    cgemm_4bit_inference_naive_fp16 = None
    cgemm_4bit_inference_naive_bf16 = None
else:
    cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
if DEVICE_TYPE == "xpu":
    torch_device_stream = torch.xpu.current_stream
elif DEVICE_TYPE == "mps":
    torch_device_stream = lambda: None
else:
    torch_device_stream = torch.cuda.current_stream
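To make the contract concrete, a hedged sketch of how a callable-valued torch_device_stream would be consumed downstream (the consuming code is assumed for illustration, not taken from this PR):

# Every branch stores a zero-argument callable, so the call site defers stream
# lookup until it is actually needed:
stream = torch_device_stream()      # CUDA/XPU: the current stream object; MPS: None
if stream is not None:
    stream.synchronize()            # skipped on MPS, where no stream object exists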

Comment on lines 36 to 65

    return out


class MPSLoRA_MLP(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        X,
        gateW,
        gateW_quant,
        gateA,
        gateB,
        gateS,
        upW,
        upW_quant,
        upA,
        upB,
        upS,
        downW,
        downW_quant,
        downA,
        downB,
        downS,
        _forward_function,
    ):
        # Forward pass using MPS-compatible operations
        e = mps_matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
        g = mps_matmul_lora(X, upW, upW_quant, upA, upB, upS)
        h = _forward_function(e, g)

medium

The MPSLoRA_MLP class appears to be unused. The dispatch logic in unsloth/kernels/mps/dispatch.py for dispatch_lora_mlp_swiglu calls mps_apply_lora_mlp_swiglu, which uses a direct PyTorch-native implementation instead of this torch.autograd.Function.

Since the backward method is not implemented and raises a NotImplementedError, and the class itself is not being used, it would be best to remove it to avoid confusion and dead code in the repository.

Wilbatronic force-pushed the apple-silicon-support branch 2 times, most recently from 4df41f8 to 17f4ddd on February 6, 2026 at 20:47
…ices to avoid 'element 0 of tensors does not require grad' error
…fication

The _backward_function modifies DW, e, and g in-place. Since e and g
come from ctx.saved_tensors, they don't have requires_grad=True.
Modifying them in-place breaks the gradient computation graph.

Fix: Clone tensors before passing to _backward_function so the
original saved_tensors are preserved and gradient tracking works
correctly.
Replace in-place .addmm_() operations with non-in-place .addmm() in
MPSLoRA_MLP.backward() to fix 'element 0 of tensors does not require
grad' error when using gradient checkpointing on MPS.

When gradient checkpointing saves tensors via ctx.save_for_backward(),
those tensors don't have requires_grad=True. The backward function was
performing in-place operations on these saved tensors, breaking the
autograd graph on MPS backend.

Changes made:
- Lines 208-224: Changed LoRA weight gradient computations
- Lines 233-242: Changed dX accumulation operations

All operations now use addmm() instead of addmm_() to preserve gradient
computation graph compatibility with gradient checkpointing.
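The in-place versus out-of-place distinction described above looks roughly like this (tensor names are illustrative, not taken from the diff):

import torch

X  = torch.randn(8, 16)
dY = torch.randn(8, 32)
dA = torch.zeros(16, 32)   # e.g. a buffer restored from ctx.saved_tensors

# In-place accumulation mutates a tensor autograd may still reference, which is
# what breaks the graph under gradient checkpointing on MPS:
# dA.addmm_(X.t(), dY)

# Out-of-place accumulation returns a fresh tensor and leaves the saved one intact:
dA = dA.addmm(X.t(), dY)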
This commit addresses the persistent 'element 0 of tensors does not require
grad and does not have a grad_fn' error when training on Apple Silicon with
gradient checkpointing enabled.

Root Cause:
The MPS custom autograd fallback kernels (MPSLoRA_MLP, MPSLoRA_O, etc.) have
compatibility issues with PyTorch's gradient checkpointing on MPS. When
gradient checkpointing saves tensors via ctx.save_for_backward(), those
tensors don't have requires_grad=True. The custom backward functions in the
MPS kernels try to perform operations on these saved tensors, causing the
autograd graph to break.

Changes Made:
1. unsloth/kernels/mps/__init__.py: Added documentation about the issue
2. unsloth/kernels/mps/fast_lora.py: Replaced in-place .addmm_() operations
   with non-in-place .addmm() in MPSLoRA_MLP backward pass
3. unsloth/kernels/mps/cross_entropy_loss.py: Replaced in-place scatter_add_()
   with scatter_add() to avoid gradient graph issues
4. unsloth/models/llama.py: Disable USE_MPS_FALLBACK when gradient checkpointing
   is enabled on MPS to use standard PyTorch operations instead
5. unsloth/models/loader.py: Same fix as llama.py for fast model loading path

By disabling USE_MPS_FALLBACK when gradient checkpointing is enabled, training
falls back to standard PyTorch operations which correctly handle gradient
checkpointing on MPS.

Fixes: Training failures on Apple Silicon with gradient_checkpointing='unsloth'
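A minimal sketch of the guard described in points 4 and 5, with flag names taken from the commit message (the exact placement inside llama.py and loader.py is not shown here):

# Custom MPS autograd fallbacks conflict with gradient checkpointing,
# so that combination drops back to plain PyTorch operations.
if DEVICE_TYPE == "mps" and use_gradient_checkpointing:
    USE_MPS_FALLBACK = False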
Creates a new path for MPS training that avoids custom autograd functions
which have compatibility issues with PyTorch's gradient checkpointing.

Changes:
- unsloth/kernels/mps/lora_pytorch.py: New pure PyTorch LoRA implementations
  that use standard F.linear operations instead of custom autograd functions.
  Implements SwiGLU, GeGLU (exact and approx), QKV, and O projections.

- unsloth/kernels/mps/dispatch.py: Updated all dispatch functions to use
  pure PyTorch implementations as Priority 3 when on MPS and custom fallback
  is disabled or causes issues.

The pure PyTorch implementations avoid the 'element 0 of tensors does not
require grad' error by not using ctx.save_for_backward() and custom backward
functions, which don't work well with gradient checkpointing on MPS.

This provides a working training path on Apple Silicon with gradient
checkpointing enabled.
Add a comprehensive SwiGLU benchmark and correctness suite for Apple Silicon (benchmark_swiglu.py) along with a sample run output (swiglu_bench_compiled). Change MPS dispatch logic to avoid snapshotting the runtime-mutable USE_MPS_FALLBACK flag by importing the parent mps module and using _use_mps_fallback() to read the flag dynamically. Update all conditional checks to call the helper and import unsloth.kernels.mps as _mps_module. Remove an unused import of USE_MPS_FALLBACK from cross_entropy_loss.py.
The LoRA A and B matrices are stored as [rank, hidden_dim] and [out_dim, rank].
The computation should be X @ A.T @ B.T, which F.linear already handles.
Removed incorrect .t() calls that caused shape mismatch (16x2048 and 16x2048).
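Given those shapes (A: [rank, hidden_dim], B: [out_dim, rank]), the pure-PyTorch LoRA path reduces to chained F.linear calls; this is a hedged sketch rather than the PR's exact code:

import torch
import torch.nn.functional as F

def lora_linear(X, W, A, B, scale: float):
    # X: (..., hidden_dim), W: (out_dim, hidden_dim)
    # A: (rank, hidden_dim), B: (out_dim, rank)
    # F.linear(X, A) == X @ A.T -> (..., rank), so no manual .t() is needed.
    return F.linear(X, W) + scale * F.linear(F.linear(X, A), B)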
Add pure PyTorch RoPE implementation and use it when MPS fallback is
disabled (gradient checkpointing). Previously fell through to
fast_rope_embedding which called dispatch_rope_embedding again.
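For reference, a standard rotate-half RoPE in plain PyTorch (a sketch of the kind of implementation the commit refers to; the PR's function name and its cos/sin layout may differ):

import torch

def rope_pytorch(Q, K, cos, sin):
    # Q, K: (batch, seq_len, n_heads, head_dim); cos, sin: (seq_len, head_dim)
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    cos = cos[:, None, :]   # broadcast over the heads dimension
    sin = sin[:, None, :]
    return Q * cos + rotate_half(Q) * sin, K * cos + rotate_half(K) * sin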
Use F.linear instead of torch.matmul/addmm to avoid potential autograd
issues on MPS. F.linear handles the weight transpose internally and
maintains gradient flow properly.
Always use mps_cross_entropy_loss on MPS regardless of fallback status.
When fallback was disabled, it fell through to fast_cross_entropy_loss
which called dispatch_cross_entropy_loss again causing infinite recursion.
…tibility

When USE_MPS_FALLBACK=False (gradient checkpointing enabled), custom
torch.autograd.Function classes conflict with gradient checkpointing.
This adds Priority 3 pure PyTorch fallbacks for:

- dispatch_cross_entropy_loss: uses F.cross_entropy instead of MPSCrossEntropyLoss
- dispatch_swiglu_fg/backward: uses F.silu + standard ops
- dispatch_geglu_exact_forward/backward: uses F.gelu + standard ops
- dispatch_geglu_approx_forward/backward: uses F.gelu(approximate='tanh') + standard ops

Also adds one-time diagnostic logging to dispatch_lora_mlp_swiglu to
verify USE_MPS_FALLBACK state at runtime.

Fixes RuntimeError: element 0 of tensors does not require grad
Pass use_gradient_checkpointing=False to both from_pretrained and
get_peft_model. The from_pretrained call defaults to 'unsloth' which
automatically disables MPS fallback kernels on Apple Silicon.
…eakage

Tests each component independently:
- Basic autograd on MPS
- Raw model forward (pre-LoRA)
- LoRA forward with module-level hooks
- Patched cross-entropy loss
- Full model+labels forward
- SFTTrainer training step

This will pinpoint exactly where requires_grad is lost.
Introduce half4 (128-bit) SIMD vectorization for Metal SwiGLU kernels (v9): process 4 fp16 elements per thread with a scalar tail path, rename kernels, and adjust grid sizing accordingly. Add mx.compile-based compiled forward/backward wrappers to let MLX fuse operations at graph level. Update benchmarks to report mx.synchronize availability and include a compiled Metal benchmark. Prefer mx.synchronize() when available (MLX >= 0.30) as the MLX synchronization barrier in bridge and context exit for lower overhead.
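The mx.compile wrappers mentioned here can be pictured as compiling a small MLX function so the framework fuses the element-wise ops at graph level; this is an illustrative sketch using public MLX APIs, not the PR's benchmark code:

import mlx.core as mx

@mx.compile
def swiglu_forward(e, g):
    # SwiGLU: silu(e) * g, expressed in MLX so the compiled graph can be fused.
    return (e * mx.sigmoid(e)) * g

e = mx.random.normal((4, 4096), dtype=mx.float16)
g = mx.random.normal((4, 4096), dtype=mx.float16)
out = swiglu_forward(e, g)
mx.eval(out)   # force lazy evaluation; mx.synchronize() serves as a barrier on newer MLX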
…ph

The unsloth_fused_ce_loss function uses torch.func, which causes 'iteration over 0-d tensor' errors on MPS and breaks the gradient graph. This commit adds a check for DEVICE_TYPE == 'mps' in llama.py to fall back to standard torch.nn.functional.cross_entropy, ensuring correct gradient propagation.
…

The dispatch_rms_layernorm function was calling the Metal kernel directly, which bypasses autograd. Changed to use Metal_RMSLayerNorm.apply to ensure gradients propagate correctly. Also removed temporary debug logging from llama.py.
- Updated MockModule to provide real nn.Module instances when called, satisfying PyTorch's child module checks.
- Made MockModule.__getattr__ strictly selective to prevent interference with other libraries (like transformers).
- Implemented __mro_entries__ to return (nn.Module,) for better inheritance support.
- Refined UnslothMockFinder to avoid overwriting existing real modules.
…n-the-fly attribute injection

- Replaced static mocking with a recursive MetaPathFinder (UnslothMockFinder) and PatchedLoader.
- Automatically injects ghost attributes (Linear4bit, Linear) into real peft and bitsandbytes modules upon import.
- Robustly redefined MockModule to inherit from nn.Module and handle complex dynamic attribute resolution.
- This truly 'automatic' approach eliminates the need for manual file-by-file fixes.
- Removed the explicit {} mock for triton.backends which caused AttributeErrors when accessing .compiler.
- Let MockModule handle .backends and .compiler dynamically, while providing __contains__ to satisfy PyTorch membership checks.
- Updated MockModule.__getattr__ to return a dummy version string '3.0.0' for __version__.
- This fixes the AttributeError: __version__ encountered when torch.inductor attempts to check triton's version on Mac.
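The "automatic" mocking approach described in the last few commits can be sketched with Python's import machinery; the class names mirror the commit messages (MockModule, UnslothMockFinder), but the bodies are simplified illustrations rather than the PR's implementation (which additionally injects attributes into real peft/bitsandbytes modules and avoids overwriting modules that import successfully):

import importlib.abc
import importlib.machinery
import sys
import types

class MockModule(types.ModuleType):
    # Permissive stand-in so imports of unavailable libraries (triton, bitsandbytes) succeed.
    def __getattr__(self, name):
        if name == "__version__":
            return "3.0.0"   # keeps version probes (e.g. torch.inductor checking triton) happy
        child = MockModule(f"{self.__name__}.{name}")
        setattr(self, name, child)
        return child

class UnslothMockFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    def __init__(self, names):
        self.names = set(names)
    def find_spec(self, fullname, path=None, target=None):
        if fullname.split(".")[0] in self.names:
            return importlib.machinery.ModuleSpec(fullname, self)
        return None
    def create_module(self, spec):
        return MockModule(spec.name)
    def exec_module(self, module):
        pass

# Install before anything imports the unsupported libraries on MPS.
sys.meta_path.insert(0, UnslothMockFinder({"triton", "bitsandbytes"}))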