mx: triton kernel to cast to mx and write in col-major #1932
base: main
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1932
CI status as of commit 7ecd79f (merge base 3fb1665): 1 new failure.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 4c77bd3df692755ea9659a1a3c9f396778ba3f30 ghstack-comment-id: 2743450537 Pull Request resolved: #1932
# example transformation (specifics depend on tile sizes):
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
    tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col
I think we should just be doing integer division instead of floor + division.
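A standalone NumPy stand-in for the Triton snippet above, showing that integer division reproduces the floor-plus-float-division result exactly. The values `BLOCKS_PER_ROW_TILE = 2` and `jump_vals_per_col = 2` are illustrative choices that reproduce the example transformation in the comment, not necessarily the kernel's actual tile parameters:

```python
import numpy as np

# illustrative tile parameters (assumptions, chosen to match the
# [0..7] -> [0, 1, 4, 5, 8, 9, 12, 13] example in the kernel comment)
BLOCKS_PER_ROW_TILE = 2
jump_vals_per_col = 2

idx = np.arange(8)

# current approach: float division, floor, then cast back to int
via_floor = idx + np.floor(idx / BLOCKS_PER_ROW_TILE).astype(np.int64) * jump_vals_per_col

# suggested approach: integer division, no float round trip
via_intdiv = idx + (idx // BLOCKS_PER_ROW_TILE) * jump_vals_per_col

print(via_intdiv.tolist())  # [0, 1, 4, 5, 8, 9, 12, 13]
```

Both produce the same indices here; integer division simply avoids the int-to-float-to-int round trip.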
Looks good! Next step: get compile to generate this.
).to(tl.int32)

# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
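One possible shape for the masked store the TODO describes (a sketch only: `n_col_scales` is a hypothetical name for the number of valid scale elements, and this fragment would live inside the existing `@triton.jit` kernel, so it is not runnable standalone):

```python
# hypothetical bound name; guards the tail when indices run past the
# number of valid scale elements
mask = col_scale_indices < n_col_scales
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0, mask=mask)
```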
In the launcher, should we assert divisibility of block sizes, so we hard-error for this case?
Yes, for now I hackily assert that on L1319:L1322.
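A hypothetical launcher-side guard along the lines the review suggests (the function and parameter names are illustrative, not the PR's actual ones): hard-error on non-divisible shapes up front rather than relying on an assert buried in the kernel file.

```python
def check_tile_divisibility(n_rows: int, n_cols: int,
                            row_tile_size: int, col_tile_size: int) -> None:
    """Raise early if the input shape does not tile evenly.

    Hypothetical helper: without this check, a non-divisible shape would
    silently produce out-of-bounds index math in the kernel.
    """
    if n_rows % row_tile_size != 0:
        raise ValueError(
            f"n_rows ({n_rows}) must be divisible by row_tile_size ({row_tile_size})"
        )
    if n_cols % col_tile_size != 0:
        raise ValueError(
            f"n_cols ({n_cols}) must be divisible by col_tile_size ({col_tile_size})"
        )

# divisible shapes pass silently
check_tile_divisibility(128, 256, 32, 64)
```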
)

return (
    output_col_major.t(),
Since only the data_ptr of output_col_major is used when you pass it into triton, you could initialize it with the correct strides.
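A NumPy stand-in for this suggestion (the PR itself uses torch, where `torch.empty_strided((rows, cols), (1, rows))` would be the analogous call; the shape here is illustrative): allocate the buffer with column-major strides up front, so the kernel's writes land in the desired layout and no transpose view is needed afterward.

```python
import numpy as np

# illustrative shape
rows, cols = 4, 8

# order="F" gives column-major strides: stepping down a column moves by
# one element, stepping across a row moves by `rows` elements
out = np.empty((rows, cols), dtype=np.float32, order="F")

# strides are reported in bytes; float32 is 4 bytes per element
print(out.strides)  # (4, 16)
```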
return (
    output_col_major.t(),
    col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
Same thing here.
Summary:
Implements a Triton kernel that casts a row-major input to mxfp8 across dim1; it is 3.5x to 4.5x faster than what torch.compile can generate today. Note that this is a prototype kernel: I expect to (a) improve it in future PRs and (b) delete it in a few weeks once we have compile support for this.
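For context on what the cast computes, here is a conceptual sketch of per-block scale selection in an mxfp8 cast, following the OCP MX convention of one shared power-of-two (e8m0) scale per block of 32 elements. This is a pure-Python illustration of the format, not this kernel's actual code, and the function name is hypothetical:

```python
import math

def e8m0_scale(block_amax: float) -> float:
    """Pick a power-of-two scale for one block, given its max magnitude.

    Conceptual sketch (OCP MX convention): the shared exponent is
    floor(log2(amax)) minus the element format's emax, so that dividing
    by the scale brings the block's max into float8_e4m3 range.
    """
    # emax of a float8_e4m3 element is 8 (max normal value 448 = 1.75 * 2**8)
    EMAX_E4M3 = 8
    if block_amax == 0.0:
        # all-zero block: smallest representable e8m0 scale
        return 2.0 ** -127
    exp = math.floor(math.log2(block_amax)) - EMAX_E4M3
    return 2.0 ** exp

print(e8m0_scale(448.0))  # 1.0
print(e8m0_scale(100.0))  # 0.25  (100 / 0.25 = 400, within e4m3 range)
```

After dividing each element by its block's scale, the elementwise cast to float8_e4m3 stays in range (saturating at the format's max for amax values just above a power-of-two boundary).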
An integration into MXLinear will follow in a separate PR.

Example of tiling (simplified for a small example size):
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: