Manually schedule a pointwise fusion using multi-wave TMA #5444
Conversation
!test

!test
}
}
// Inline most tensors
inlineMost();
Suggested change:
-inlineMost();
+// inlineMost(&fusion); // Removed undefined function call

[Suggested by AI] The change simply removes the call to the (apparently undefined) function inlineMost(), leaving it commented out, to eliminate the resulting build/undefined-reference error. No other logic was modified.
| "along each of the tensorRank dimensions, must be non-zero and less " | ||
| "than or equal to 256. box_dim_val = 512"))); | ||
| } | ||
| } // namespace nvfuser | 
Hi @rdspring1 and @naoyam, what do you think about this comment?
This test demonstrates that we cannot use nD TMA to load more than 256 elements along a dimension of a 1D tensor. The dimensions in nD TMA correspond to the logical domain, and we lack the flexibility to merge or split logical domains to form loop domains that could be parallelized with TMA.
Using reshape to alter the logical domains also doesn't seem like a good idea. We might be better off reverting to 1D TMA, or keeping both 1D and nD TMA and checking the performance difference.
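For context, the 256 limit is imposed by the CUDA driver when the TMA descriptor is encoded, not by nvFuser: every boxDim entry passed to cuTensorMapEncodeTiled must be non-zero and at most 256, which is the rule mirrored by the error message the test above expects. A minimal standalone sketch of hitting that limit; the sizes and the printf-style checking are illustrative, not from the PR:

#include <cuda.h>
#include <cstdio>

int main() {
  cuInit(0);
  CUdevice dev;
  CUcontext ctx;
  cuDeviceGet(&dev, 0);
  cuCtxCreate(&ctx, 0, dev);

  // 1D global tensor of 1024 floats.
  CUdeviceptr gmem;
  cuMemAlloc(&gmem, 1024 * sizeof(float));

  CUtensorMap map;
  cuuint64_t global_dim[1] = {1024};
  cuuint32_t box_dim[1] = {512};  // violates the "non-zero and <= 256" rule
  cuuint32_t elem_stride[1] = {1};

  // Fails with CUDA_ERROR_INVALID_VALUE because box_dim[0] > 256.
  CUresult res = cuTensorMapEncodeTiled(
      &map, CU_TENSOR_MAP_DATA_TYPE_FLOAT32, /*tensorRank=*/1,
      reinterpret_cast<void*>(gmem), global_dim,
      /*globalStrides=*/nullptr,  // rank - 1 == 0 stride entries for 1D
      box_dim, elem_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  std::printf("box_dim = 512: %s\n", res == CUDA_SUCCESS ? "ok" : "rejected");

  box_dim[0] = 256;  // the largest box allowed along any single dimension
  res = cuTensorMapEncodeTiled(
      &map, CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 1,
      reinterpret_cast<void*>(gmem), global_dim, nullptr, box_dim,
      elem_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  std::printf("box_dim = 256: %s\n", res == CUDA_SUCCESS ? "ok" : "rejected");

  cuMemFree(gmem);
  cuCtxDestroy(ctx);
  return 0;
}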
Each individual multi-dimensional TMA load has a 256 limit on each box dimension. However, you can issue multiple TMA loads for a given mbarrier.
We have to do this for matmuls, loading a (256, 256) tile with 4x (256, 64) TMA loads to avoid bank conflicts:
if warp == load_warp:
    mbarrier::wait(load_mbarrier)        # wait until the buffer is free
    mbarrier::arriveExpectTx(load_mbarrier, tile_bytes)  # bytes expected from all 4 loads
    for i in range(4):
        tma_load(box=(256, 64))          # 4 loads fill the (256, 256) tile
elif warp == compute_warp:
    mbarrier::wait(load_mbarrier)        # wait until all 4 loads have landed
    # compute on the (256, 256) tile
    mbarrier::arrive(load_mbarrier)      # mark the buffer free for the next load
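If I'm reading the expect-tx accounting right (assuming an fp16 tile), each (256, 64) load delivers 256 * 64 * 2 = 32768 bytes, so the single arriveExpectTx for the whole (256, 256) tile would expect 4 * 32768 = 131072 bytes; the barrier phase only flips once all four loads have committed their bytes.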
Demo of how we want to schedule a pointwise fusion using multi-wave TMA. Will be added to the auto scheduler. See the design doc for details. A sketch of what such a manual schedule looks like follows below.
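As a rough illustration using public nvFuser scheduling APIs (a sketch under assumptions: the tensor names, the 128x256 tile sizes, and the pointwise mul are mine; the multi-wave grid sizing and the TMA store from the design doc are omitted; this is not the exact code in the PR):

Fusion fusion;
FusionGuard fg(&fusion);

// Pointwise mul of two contiguous 2D inputs.
TensorView* tv0 = makeContigTensor(2);
TensorView* tv1 = makeContigTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
TensorView* tv2 = mul(tv0, tv1);
fusion.addOutput(tv2);

// Stage the inputs in shared memory via TMA (cp.async.bulk.tensor).
TensorView* tv0s = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv1s = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv0s->setMemoryType(MemoryType::Shared);
tv1s->setMemoryType(MemoryType::Shared);

// Tile the output: [I0, I1] -> [I0/128, 128, I1/256, 256], then
// reorder to [I0/128, I1/256, 128, 256] so blocks own whole tiles.
tv2->split(1, 256);
tv2->split(0, 128);
tv2->reorder({{1, 2}, {2, 1}});
tv2->axis(0)->parallelize(ParallelType::BIDy);
tv2->axis(1)->parallelize(ParallelType::BIDx);
tv2->axis(2)->parallelize(ParallelType::TIDy);
tv2->axis(3)->parallelize(ParallelType::TIDx);

// Propagate the tiling and parallelization to the rest of the fusion.
TransformPropagator propagator(tv2);
MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv2);

// The TMA box is described by the Bulk-parallelized dimensions of the
// shared-memory tensors (each box dimension must be <= 256, per the
// test discussed above).
for (TensorView* tv : {tv0s, tv1s}) {
  tv->axis(-1)->parallelize(ParallelType::Bulk);
  tv->axis(-2)->parallelize(ParallelType::Bulk);
}

// Inline most tensors
inlineMost();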
IR and Kernel for PointwiseMultiWaveTMATest.PointwiseMulMultiWaveTMA/WithTMAStore_WithUnroll