fix(torchx): better mps support #1652
Conversation
@josevalim, thoughts on the f64 issue mentioned in the PR description?
torchx/lib/torchx/backend.ex (Outdated)
    @impl true
    def optional(function_name, args, default_impl) do
      # For MPS device, some linear algebra operations are not supported
      # Delegate to default implementation which will fall back to BinaryBackend
Binary backend? Or do you mean the default implementation?
Yeah, this was the LLM getting confused. This falls back to elementary Nx operations, not the binary backend.
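For illustration, the delegation could look roughly like the fragment below. This is a sketch only: the op list, the `on_mps?/1` helper, and the dispatch details are assumptions, not the code in this PR.

```elixir
# Sketch: on MPS, skip native kernels the device does not provide and let
# Nx's default implementation compose the op from elementary operations.
@mps_unsupported_ops [:solve, :logsumexp]

@impl true
def optional(op, args, default_impl) do
  cond do
    op in @mps_unsupported_ops and on_mps?(args) ->
      # Falls back to elementary Nx operations, not the BinaryBackend.
      apply(default_impl, args)

    function_exported?(__MODULE__, op, length(args)) ->
      # Use this backend's own implementation when one exists.
      apply(__MODULE__, op, args)

    true ->
      apply(default_impl, args)
  end
end

# Hypothetical helper: checks whether any tensor argument lives on MPS.
defp on_mps?(args) do
  Enum.any?(args, fn
    %Nx.Tensor{data: %Torchx.Backend{ref: {device, _ref}}} -> device == :mps
    _other -> false
  end)
end
```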
torchx/lib/torchx.ex (Outdated)
    target_device_struct = torch_device!(user_device, index)

    tensor_to_move =
      if user_device == :mps do
Agreed, we should just let it raise on f64.
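A sketch of the "let it raise" alternative; `tensor_type/1` is a hypothetical helper standing in for however the Nx type is read at this point, and the surrounding names mirror the diff above:

```elixir
target_device_struct = torch_device!(user_device, index)

# Refuse to transfer types MPS cannot represent instead of silently
# downgrading them. `tensor_type/1` is a placeholder for the real lookup.
if user_device == :mps and tensor_type(tensor) in [{:f, 64}, {:c, 128}] do
  raise ArgumentError,
        "MPS does not support 64-bit floats or 128-bit complex numbers; " <>
          "cast the tensor to {:f, 32} or {:c, 64} before transferring"
end
```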
torchx/lib/torchx/backend.ex (Outdated)
    device = device_option(backend_options)
    torch_type = to_torch_type(type, device)

    # Handle type downgrading for MPS - need to convert binary data format
Just let it raise here too. Convert the tests to f32 if necessary (I think we changed the overall defaults to f32 a long time ago).
f32 was always the default. The vast majority of the tests we had were actually for-generated, so it was very easy to get them all green.
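For reference, the raise could live directly in the type mapping; a minimal sketch, assuming the two-arity to_torch_type(type, device) shown in the diff above and an existing one-arity clause for the general case:

```elixir
# Sketch: raise instead of downgrading unsupported types on MPS.
defp to_torch_type({:f, 64}, :mps) do
  raise ArgumentError, "f64 is not supported on the MPS device; use {:f, 32} instead"
end

defp to_torch_type({:c, 128}, :mps) do
  raise ArgumentError, "c128 is not supported on the MPS device; use {:c, 64} instead"
end

# Everything else maps exactly as before (assumed one-arity clause).
defp to_torch_type(type, _device), do: to_torch_type(type)
```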
torchx/test/torchx/nx_test.exs (Outdated)
    # TODO: MPS uses different rounding rules (half-to-even vs half-away-from-zero)
    # Need to investigate if this can be fixed or if tests need to account for it
    @tag :skip_on_mps
We shouldn't call it :skip_on_mps, but rather :round_up, :requires_f64, etc., and then exclude those when the device is :mps.
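To illustrate the suggestion (a sketch; the tag name and test body are just examples, and the exclusion side is sketched after the test_helper thread below):

```elixir
# The tag describes *why* the test cannot run on MPS, not the device itself.
@tag :requires_f64
test "keeps double precision through a dot product" do
  t = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: {:f, 64})
  assert Nx.type(Nx.dot(t, t)) == {:f, 64}
end
```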
torchx/test/test_helper.exs (Outdated)
    # Tests must run synchronously to avoid GPU framework crashes
    mps_opts =
      if device_is_mps do
        [max_cases: 1]
@josevalim do you think we should try to add some device lock mechanism to torchx?
Ideally we would need to understand why it happens. Are the failures due to mutation or just the lack of queueing in the device itself? What does PyTorch do?
The errors I was getting were signaling something about "device in use", so I don't think it's mutation.
I'll see if I can find out what PyTorch does.
I'm not getting the errors anymore, so maybe we can kick this can down the road.
I did find that there are a few mechanisms we could use, but generally MPS assumes a single command queue per OS process.
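For context, the synchronous-test workaround from the diff above, combined with the tag exclusion suggested earlier, would boil down to something like this in test_helper.exs (a sketch; how the device is detected, e.g. via a TORCHX_DEVICE environment variable, is an assumption):

```elixir
# Sketch of test_helper.exs: run cases one at a time on MPS (single command
# queue per OS process) and exclude tests the device cannot satisfy.
device_is_mps = System.get_env("TORCHX_DEVICE") == "mps"

opts =
  if device_is_mps do
    [max_cases: 1, exclude: [:requires_f64, :round_up]]
  else
    []
  end

ExUnit.start(opts)
```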
closes #679
closes #1608
This PR currently masks f64 as f32; however, I think it would be better if we instead just raised whenever f64/c128 shows up.