
Error when trying to implement native MLX-accelerated scVI #3268

@c0nleyinnnn

Description:

Before the recent update to PyTorch's MPS support, accelerating scVI with PyTorch MPS produced NaN values in the returned matrix. For details, see the earlier issue: Error when training model on M3 Max MPS.

I attempted to port scVI to the MLX framework by following the simple-scvi guide. I rewrote _mlxvae.py, _mlxscvi.py, and _mlxmixin.py so that the backend calls MLX for Metal acceleration. This decision was motivated by findings that MLX can dispatch Metal GPU computations at a higher frequency than PyTorch MPS. For more information, see: phi2-llm-on-MLX-vs-Pytorch-MPS.
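For orientation, here is a minimal sketch of the kind of encoder _mlxvae.py implements. The names and layer sizes are illustrative rather than my actual code; it simply mirrors scVI's encoder (counts in, latent mean/log-variance out, reparameterized sample):

import mlx.core as mx
import mlx.nn as nn

class Encoder(nn.Module):
    """Maps expression counts to a sample from the latent distribution."""

    def __init__(self, n_input: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        self.hidden = nn.Linear(n_input, n_hidden)
        self.mean = nn.Linear(n_hidden, n_latent)
        self.log_var = nn.Linear(n_hidden, n_latent)

    def __call__(self, x: mx.array):
        h = nn.relu(self.hidden(mx.log1p(x)))  # log1p stabilizes raw counts
        mu = self.mean(h)
        log_var = self.log_var(h)
        eps = mx.random.normal(mu.shape)       # reparameterization trick
        z = mu + mx.exp(0.5 * log_var) * eps
        return z, mu, log_var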

With the help of various development tools and AI assistance, the code now runs in Python using MLX. However, the returned latent matrix exhibits the same issue as the earlier PyTorch MPS problem:

latent
array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]])
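Since MLX evaluates lazily, one way to start localizing the problem is to force evaluation of each intermediate array and report the first non-finite one. This is only a sketch; the set of tensors to probe (hidden activations, mu, log_var, z, likelihood terms) has to be adapted to whatever _mlxvae.py actually computes:

import mlx.core as mx

def report_nonfinite(name: str, arr: mx.array) -> bool:
    """Force evaluation and flag NaN/Inf in an intermediate array."""
    mx.eval(arr)  # MLX is lazy; eval() materializes the computation
    bad = bool(mx.any(mx.logical_or(mx.isnan(arr), mx.isinf(arr))).item())
    if bad:
        print(f"non-finite values first appear in: {name}")
    return bad

Checking intermediates in forward order pinpoints the first operation that produces NaNs, which is more informative than inspecting only the final latent matrix.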

Due to my limited mathematical and programming skills, I am unable to resolve this issue or even identify where to start. A previous issue mentioned problems with lgamma, but I couldn't find the corresponding part in my code.
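For context, the lgamma terms live in scVI's negative-binomial reconstruction likelihood. The following NumPy reference is a sketch of that likelihood (not scvi-tools' exact code); evaluating it on a batch alongside the MLX version can show whether these terms are where the NaNs originate:

import numpy as np
from scipy.special import gammaln  # gammaln(x) equals lgamma(x) for x > 0

def nb_log_prob(x, mu, theta, eps=1e-8):
    """log NB(x; mu, theta) with mean mu and inverse dispersion theta."""
    log_theta_mu = np.log(theta + mu + eps)
    return (gammaln(x + theta) - gammaln(theta) - gammaln(x + 1.0)
            + theta * (np.log(theta + eps) - log_theta_mu)
            + x * (np.log(mu + eps) - log_theta_mu))

Note the eps guards: if mu or theta reaches exactly zero (or theta goes negative through an unconstrained parameterization), the log terms produce -inf or NaN, which then propagates through training.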

Additionally, part of the reason I attempted this port is that scVI's batch-effect removal under PyTorch MPS acceleration differs from the results I obtain with CUDA acceleration on my Windows host. On the same test data, the current MLX port runs in almost the same time as PyTorch MPS but significantly faster than JaxSCVI (a rough timing harness is sketched below). This may reflect my limited development expertise, as I cannot guarantee proper compilation or an efficient implementation.
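A rough harness for the PyTorch-MPS vs. JaxSCVI timing comparison might look like this, assuming a recent scvi-tools; the MLX model from my repo would slot in the same way:

import time
import scvi
from scvi.data import synthetic_iid  # small synthetic AnnData for testing

adata = synthetic_iid()
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

def timed(model, **train_kwargs):
    """Train for a fixed number of epochs and return wall-clock seconds."""
    start = time.perf_counter()
    model.train(max_epochs=20, **train_kwargs)
    return time.perf_counter() - start

print("PyTorch MPS:", timed(scvi.model.SCVI(adata), accelerator="mps"))

scvi.model.JaxSCVI.setup_anndata(adata, batch_key="batch")
print("JaxSCVI:", timed(scvi.model.JaxSCVI(adata)))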

If anyone can provide insights, assistance, or even take over this project, that would be fantastic. I am willing to contribute the existing code for free, and I hope someone can help develop more efficient and stable single-cell omics analysis tools for Apple Silicon.

Additional Context:

  • A previous discussion mentioned problems with lgamma, but I couldn't locate the corresponding part of my code.
  • The MLX port's performance is comparable to PyTorch MPS and faster than JaxSCVI, though this may be due to suboptimal development practices on my part.

Any help or collaboration would be greatly appreciated!

Current work: https://github.com/c0nleyinnnn/mlxSCVI
