Description
It would be fantastic to see Wan2.1 support Metal GPU acceleration on Apple Silicon through PyTorch's MPS backend. This could significantly improve performance for macOS users compared with running on CPU. Supporting it would involve:
- Replacing CUDA device initialization with MPS (e.g., `torch.device("mps" if torch.backends.mps.is_available() else "cpu")`).
- Adjusting precision from float64 to float32 in areas like sinusoidal embeddings and rotary encodings.
- Handling random generation by creating tensors on CPU before moving them to MPS.
- Replacing CUDA-specific calls (e.g., `torch.cuda.empty_cache()`, `torch.cuda.amp.autocast`) with conditional alternatives.
- Adding a compatibility layer during model loading for device differences.
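To make the list above more concrete, here is a rough sketch of what device-agnostic helpers could look like. It is not taken from the Wan2.1 codebase: every function name (`pick_device`, `float_dtype_for`, `sinusoidal_embedding`, `seeded_randn`, `empty_cache`, `autocast_for`) is hypothetical, and the embedding formula is a generic sinusoidal one rather than the project's own. It also assumes a PyTorch version that provides `torch.mps.empty_cache()`, and it conservatively skips autocast on anything other than CUDA.

```python
import contextlib

import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then MPS (Apple Silicon), then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def float_dtype_for(device: torch.device) -> torch.dtype:
    """MPS does not support float64, so use float32 there."""
    return torch.float32 if device.type == "mps" else torch.float64


def sinusoidal_embedding(positions: torch.Tensor, dim: int,
                         device: torch.device) -> torch.Tensor:
    """Generic sinusoidal embedding (illustrative, not Wan2.1's exact code)."""
    dtype = float_dtype_for(device)
    half = dim // 2
    # Cast to a supported dtype before moving the tensor to the device.
    pos = positions.to(dtype).to(device)
    exponents = -torch.arange(half, dtype=dtype, device=device) / half
    freqs = torch.pow(10000.0, exponents)
    angles = torch.outer(pos, freqs)
    return torch.cat([angles.cos(), angles.sin()], dim=1)


def seeded_randn(shape, seed: int, device: torch.device) -> torch.Tensor:
    """Draw noise with a CPU generator, then move it to the target device."""
    gen = torch.Generator(device="cpu").manual_seed(seed)
    noise = torch.randn(shape, generator=gen, dtype=torch.float32, device="cpu")
    return noise.to(device)


def empty_cache(device: torch.device) -> None:
    """Release cached memory on backends that expose such a call."""
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()


def autocast_for(device: torch.device, dtype: torch.dtype = torch.bfloat16):
    """Use autocast on CUDA only; elsewhere return a no-op context (assumption)."""
    if device.type == "cuda":
        return torch.autocast(device_type="cuda", dtype=dtype)
    return contextlib.nullcontext()
```

A call site could then look like `device = pick_device()` followed by `with autocast_for(device): ...`, with `empty_cache(device)` replacing direct `torch.cuda.empty_cache()` calls. For the model-loading point, one commonly used pattern is `torch.load(path, map_location="cpu")` followed by moving the model to the selected device, though whether that is sufficient for Wan2.1's checkpoints is untested here.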
I’m willing to help contribute to this effort and collaborate with anyone interested in making this happen.