Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mx: triton kernel to cast to mx and write in col-major #1932

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

vkuzo
Copy link
Contributor

@vkuzo vkuzo commented Mar 21, 2025

Summary:

Implements a triton kernel for a cast to mxfp8 from a row-major input across dim1, which is 3.5x to 4.5x faster than what compile can generate today. Note that this is a prototype kernel, and I expect to (a) improve it in future PRs and (b) delete it in ~weeks when we have compile support for this.

An integration into MXLinear will follow in a separate PR.

Example of tiling (simplified for small example size):

        Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
        pid_row=0, pid_col=0:

        Input (row-major)

        cols      0  1  2  3  4  5  6  7
        --------------------------------
        rows 0 |  0  1  2  3
             1 |  8  9 10 11
             2 | 16 17 18 19
             3 | 24 25 26 27
             4 |
             5 |
             6 |
             7 |

        Output (row-major of transpose), ids are from input

        cols      0  1  2  3  4  5  6  7
        --------------------------------
        rows 0 |  0  8 16 24
             1 |  1  9 17 25
             2 |  2 10 18 26
             3 |  3 11 19 27
             4 |
             5 |
             6 |
             7 |

        Output (scales), s(0, 8) means the scale used to cast elements 0 and 8

        rows           0          1  ...      4  ...       31
        ------------------------------------------------------
                  s(0, 8)  s(16, 24) ... s(1, 9) ... s(19, 27)

Test Plan:

// tests pass
pytest test/prototype/mx_formats/test_custom_cast.py -s -x -k triton_mxfp8_dim1

// performance compile vs triton: https://www.internalfb.com/phabricator/paste/view/P1762691809
// * 4k by 4k tensor: about a 3.6x speedup
// * 16k by 16k tensor: about a 4.7x speedup

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Mar 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1932

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 7ecd79f with merge base 3fb1665 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4c77bd3df692755ea9659a1a3c9f396778ba3f30
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 21, 2025
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 18073c56689a61e6bece5a9311d2e326f410a1a4
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c2d2aed9482c1058e5899d84bce3e3959a7f7d46
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 91248bc71f4da1fd8d0925541f4a39d5becb8a07
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3a53ae76b2029887c77ec1b9b9faf87aa0902758
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 76fef6c6d8dc26bfd18ea9edea9fb7b2f5a43f7d
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 84869136b5270214c933e7cdd28460a77a30e57f
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c44a9dbf3d63172254ce2a995bca7f44b8e5f2b9
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d5904016d9ac76c8414848fa7316923597d140e3
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26e84f93e12f87df4a1707ebe1d168bcf45db790
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@vkuzo vkuzo changed the title [wip] triton kernel to cast to mx and write in col-major mx: triton kernel to cast to mx and write in col-major Mar 21, 2025
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 95105a696b1200ce2f6785a5d7e626f651cdc3de
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@vkuzo vkuzo added the topic: performance Use this tag if this PR improves the performance of a feature label Mar 21, 2025
@eellison eellison self-requested a review March 21, 2025 16:47
# example transformation (specifics depend on tile sizes):
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should just be doing integer division instead of floor + /

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good ! next step: compile to generate this

).to(tl.int32)

# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the launcher, should we assert divisibility of block sizes, so we hard error for this case ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for now I hackly assert that on L1319:L1322

)

return (
output_col_major.t(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since only the data_ptr of output_col_major is used when you pass it into triton, you could initialize it with the correct strides


return (
output_col_major.t(),
col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing here..

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 24, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d1c1db7c22b2640c6ee43365832a5194abd67e48
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 24, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6c795ce09e43f864d88e1e2e2ef0ae363510a223
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: performance Use this tag if this PR improves the performance of a feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants