Conversation

@polvalente
Contributor

closes #679
closes #1608

This PR currently masks f64 as f32; however, I think it would be better to just raise whenever f64/c128 show up.
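For illustration, a minimal sketch of the two options, assuming a `to_torch_type/2`-style helper like the one quoted further down; the clause shapes and error messages here are illustrative, not the PR's:

```elixir
# What the PR currently does, conceptually: mask f64 as f32 on MPS.
#   defp to_torch_type({:f, 64}, :mps), do: to_torch_type({:f, 32}, :mps)

# What the description proposes instead: raise whenever f64/c128 show up.
defp to_torch_type({:f, 64}, :mps) do
  raise ArgumentError, "f64 is not supported on the MPS device"
end

defp to_torch_type({:c, 128}, :mps) do
  raise ArgumentError, "c128 is not supported on the MPS device"
end
```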

@polvalente polvalente self-assigned this Dec 28, 2025
@polvalente
Contributor Author

@josevalim thoughts on the f64 issue mentioned in the PR description?

@impl true
def optional(function_name, args, default_impl) do
  # For MPS device, some linear algebra operations are not supported
  # Delegate to default implementation which will fall back to BinaryBackend
Collaborator

Binary backend? Or do you mean the default implementation?

Contributor Author

Yeah, this was the LLM getting confused. This falls back to elementary Nx operations, not the binary backend.
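For reference, a rough sketch of that delegation with the comment corrected; the argument handling is simplified and assumed, not copied from the PR:

```elixir
@impl true
def optional(_function_name, args, default_impl) do
  # Some linear algebra operations are not supported on the MPS device.
  # Delegate to the default implementation, which is expressed in terms of
  # elementary Nx operations rather than the BinaryBackend.
  apply(default_impl, args)
end
```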

target_device_struct = torch_device!(user_device, index)

tensor_to_move =
  if user_device == :mps do
Collaborator

Agreed we should just let it raise on f64.
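A usage-level sketch of the agreed behavior; the error message is illustrative, not the actual one Torchx raises:

```elixir
# Assumes the MPS device is selected via Torchx backend options.
Nx.default_backend({Torchx.Backend, device: :mps})

Nx.tensor([1.0, 2.0], type: :f64)
#=> ** (ArgumentError) f64 is not supported on the MPS device
```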

device = device_option(backend_options)
torch_type = to_torch_type(type, device)

# Handle type downgrading for MPS - need to convert binary data format
Collaborator

Just let it raise here too. Convert the tests to f32 if necessary (I think we changed the overall defaults to f32 a long time ago).
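A sketch of what "let it raise here too" could look like around the quoted lines; only `device_option/1`, `backend_options`, and `to_torch_type/2` come from the hunk, while the guard and message are assumptions:

```elixir
device = device_option(backend_options)

# Instead of downgrading the binary data for MPS, reject unsupported
# Nx types up front (illustrative check and message).
if device == :mps and type in [{:f, 64}, {:c, 128}] do
  raise ArgumentError,
        "#{inspect(type)} is not supported on the MPS device, use f32/c64 instead"
end

torch_type = to_torch_type(type, device)
```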

Contributor Author

f32 was always the default. The vast majority of our tests were for-generated, so it was easy to get them all green.
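For readers unfamiliar with the pattern, this is roughly what a for-generated test looks like; the op and values are made up, only the shape of the pattern matters:

```elixir
# Each iteration of the comprehension defines one test, so changing the
# list of types (e.g. dropping f64) updates every generated case at once.
for type <- [{:f, 32}, {:s, 64}, {:u, 8}] do
  @tested_type type

  test "add/2 with #{inspect(type)}" do
    t = Nx.tensor([1, 2, 3], type: @tested_type, backend: Torchx.Backend)
    assert Nx.to_flat_list(Nx.add(t, t)) == [2, 4, 6]
  end
end
```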


# TODO: MPS uses different rounding rules (half-to-even vs half-away-from-zero)
# Need to investigate if this can be fixed or if tests need to account for it
@tag :skip_on_mps
Collaborator

We shouldn't call it :skip_on_mps, but rather :round_up, :requires_f64, etc., and then exclude those tags when the device is MPS.
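A sketch of the suggested scheme; the tag names come from the comment above, and `device_is_mps` stands in for however the suite detects the device:

```elixir
# In the test files, tag by required capability instead of by device:
#   @tag :requires_f64
#   @tag :round_up

# In test_helper.exs, exclude those capabilities when running on MPS:
exclude = if device_is_mps, do: [:requires_f64, :round_up], else: []

ExUnit.start(exclude: exclude)
```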

@polvalente polvalente marked this pull request as ready for review December 28, 2025 20:45
# Tests must run synchronously to avoid GPU framework crashes
mps_opts =
  if device_is_mps do
    [max_cases: 1]
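A guess at how the truncated hunk above continues and gets used; only `device_is_mps`, the comment, and `max_cases: 1` come from the diff, while the else branch and the `ExUnit.start/1` call are assumptions (they could also be merged with the exclude list sketched earlier):

```elixir
# Limit ExUnit to one concurrent test case on MPS so the GPU framework
# is never hit from multiple test processes at once.
mps_opts =
  if device_is_mps do
    [max_cases: 1]
  else
    []
  end

ExUnit.start(mps_opts)
```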
Contributor Author

@josevalim do you think we should try to add a device lock mechanism to Torchx?

Collaborator

Ideally we would need to understand why it happens. Are the failures due to mutation or just the lack of queueing in the device itself? What does PyTorch do?

Contributor Author

The errors I was getting mentioned something about "device in use", so I don't think it's mutation.

I'll see if I can find out what PyTorch does.

Contributor Author

I'm not getting the errors anymore, so maybe we can kick this can down the road.
I did find that there are a few mechanisms we could use, but generally MPS assumes a single command queue per OS process.
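If a lock ever becomes necessary, one of the simpler mechanisms would be serializing MPS submissions through a single process. This is purely a hypothetical sketch; no such module exists in Torchx:

```elixir
defmodule Torchx.MPSLock do
  @moduledoc false
  # Hypothetical: funnel all MPS work through one GenServer, mirroring the
  # "single command queue per OS process" assumption mentioned above.
  use GenServer

  def start_link(_opts), do: GenServer.start_link(__MODULE__, :ok, name: __MODULE__)

  def with_lock(fun) when is_function(fun, 0),
    do: GenServer.call(__MODULE__, {:run, fun}, :infinity)

  @impl true
  def init(:ok), do: {:ok, :ok}

  @impl true
  def handle_call({:run, fun}, _from, state), do: {:reply, fun.(), state}
end
```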

@polvalente polvalente requested a review from josevalim December 28, 2025 21:15
@polvalente polvalente merged commit 1781659 into main Dec 28, 2025
9 checks passed
@polvalente polvalente deleted the pv-fix/mps-support branch December 28, 2025 22:02

Development

Successfully merging this pull request may close these issues.

Nx.Random.uniform(key) error when using mps backend in Torchx
Allocate intermediate tensors in the same device as the input tensor
