feat: Add BF16 float-accumulator TensorOp epilogue specialization#3287
Open
lyuxy-infra wants to merge 1 commit into
Open
feat: Add BF16 float-accumulator TensorOp epilogue specialization#3287lyuxy-infra wants to merge 1 commit into
lyuxy-infra wants to merge 1 commit into
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds a
DefaultIteratorsTensorOp<bfloat16_t, float, 8, ...>specialization for TensorOp epilogues.CUTLASS already has a
half_t, float, 8specialization that usesTileIteratorTensorOpMixedandSharedLoadIteratorMixedto optimize mixed-precision epilogues with FP32 accumulators and 16-bit outputs. BF16 output with FP32 accumulators has the same 32-bit accumulator / 16-bit output / 8-elements-per-access structure, but currently falls back to the generic iterator path.This patch mirrors the existing
half_t, float, 8specialization forbfloat16_t, float, 8.Motivation
For mixed-precision TensorOp epilogues where accumulators are FP32 and outputs are 16-bit, the mixed iterator path uses a shared-memory layout designed to avoid bank conflicts. BF16 output should be able to use the same iterator structure as FP16 output.
Changes
DefaultIteratorsTensorOp<bfloat16_t, float, 8, ...>TileIteratorTensorOpMixed<..., float, 32, 16, 8, 8>SharedLoadIteratorMixed<..., float, 32, 16, 8, 8>kFragmentsPerIteration = 2, matching the existinghalf_t, float, 8specializationNotes
This does not change the output operator, numerical conversion, or GEMM mainloop. It only changes the epilogue shared-memory staging iterator selection for this BF16 mixed-precision case.