Conversation

@liqiangxl (Collaborator)

Not For Review


github-actions bot commented Oct 28, 2025

Review updated until commit fcd6940

Description

  • Add TMA support for pointwise operations with multi-wave and warp specialization
    • Implement scheduling heuristics and Python bindings
    • Enable 1D and 2D TMA variants
  • Introduce CLC launch attributes and circular buffer scheduling
    • Support cluster execution and synchronization
  • Fix TMA store synchronization and WAR hazards
    • Update fence handling and validation logic
  • Add comprehensive tests and documentation
    • Include Python scheduling examples and test kernels

Changes walkthrough 📝

Relevant files

Tests (3 files)
  • test_pointwise.cpp: Add TMA tests for pointwise operations with multi-wave and warp specialization (+586/-0)
  • test_circular_buffering.cpp: Rename test to BarSyncTmaPointwise (+1/-1)
  • test_clc.cu: Add CLC test kernel for cluster launch control (+276/-0)

Enhancement (9 files)
  • pointwise.cpp: Implement TMA pointwise scheduling heuristics and multi-wave TMA support (+106/-0)
  • predicate.cpp: Modify predicate handling for TMA pointwise operations (+19/-2)
  • python_bindings.cpp: Add Python bindings for TMA operations and circular buffer types (+23/-0)
  • schedule_bindings.cpp: Add circular_buffer scheduling binding with type parameter (+19/-0)
  • executor.cpp: Add CLC launch attribute support for cluster execution (+9/-0)
  • enum.cpp: Add tma_1d enum binding for Python frontend (+1/-0)
  • pointwise_heuristic.h: Add TMA-related parameters to PointwiseParams (+6/-0)
  • pointwise_tma.py: Add 2D TMA scheduling for pointwise operations (+458/-0)
  • pointwise_1dtma.py: Add 1D TMA scheduling for pointwise operations (+460/-0)

Bug fix (3 files)
  • insert_syncs.cpp: Add TMA store synchronization and WAR hazard handling (+63/-14)
  • tma.cpp: Disable boxDim validation for TMA operations (+9/-9)
  • kernel_ir.cpp: Fix string representation of FenceAsyncProxy and WgMmaFence (+6/-2)

Configuration changes (3 files)
  • options.cpp: Add TmaPointwise enable option (+1/-0)
  • options.h: Declare TmaPointwise enable option (+1/-0)
  • .lintrunner.toml: Exclude Python scheduling docs from linting (+6/-0)

Documentation (3 files)
  • pointwise_tma_ws.py: Add Python scheduling example for TMA pointwise with warp specialization (+572/-0)
  • pointwise_1dtma_ws.py: Add Python scheduling example for 1D TMA pointwise operations (+503/-0)
  • ws_tma.cu: Add warp-specialized TMA kernel example (+71/-0)

Additional files (10 files)
  • __tmp_nvfuser_pointwise_f0_c1_r0_g0.cu (+12767/-0)
  • ldg_clc.cu (+12820/-0)
  • multiwave_tma.cu (+12868/-0)
  • multiwave_tma_revised.cu (+12861/-0)
  • mw_1dtma.cu (+12803/-0)
  • mw_tma_store.cu (+12785/-0)
  • ws_tma_revised.cu (+12847/-0)
  • ldg_clc.cu (+12820/-0)
  • mw_1dtma_clc.cu (+12840/-0)
  • tma1d_clc.cu (+13053/-0)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

TMA Store Sync

The handling of TMA stores in the WAR hazard prevention logic may need validation, particularly the insertion of commit/wait operations after the BlockSync in outer loops and the use of wait<0> for all pending stores.

// Special handling for TMA stores (CpAsyncBulk S2G operations):
// Unlike TMA loads and WgMma ops, TMA stores read from shared memory
// asynchronously and write to global memory. Their inputs (shared
// memory buffers) may not be tracked in async_inputs_in_current_scope_
// because they originate from compute operations, not async loads.
// However, we still need to insert commit/wait to prevent WAR hazards
// where the next loop iteration overwrites shared memory before the TMA
// store completes.
bool is_tma_store = ir_utils::isCpAsyncBulkStore(expr);

// If the input of the async op is not in the current scope, then this
// async op is not related, so nothing to protect.
if (!is_wgmma_epilogue && !is_tma_store &&
    is_async_inputs_not_present) {
  it++;
  continue;
}

int64_t pending_ops = getPendingOpsFor(expr, for_loop);

// For TMA stores, we must wait for ALL pending stores (wait<0>) to
// complete before the next iteration overwrites the shared memory
// buffer. This prevents WAR (Write-After-Read) hazards where shared
// memory is reused before async stores finish reading from it.
if (is_tma_store) {
  pending_ops = 0;
}
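
For orientation, here is a minimal hand-written CUDA sketch of the commit/wait pattern this logic is meant to guarantee. It is not nvFuser-generated code; the kernel shape, buffer size, and the use of raw PTX (cp.async.bulk.commit_group / cp.async.bulk.wait_group.read) are assumptions for illustration. The point is that draining all outstanding bulk stores (wait<0>) before the next iteration rewrites the shared-memory tile is what removes the WAR hazard.

// Hypothetical sm_90 sketch (launch with 256 threads per block); the actual
// TMA store call is elided since it needs a TensorMap descriptor.
__global__ void tile_store_loop(float* gmem, int num_tiles) {
  __shared__ float tile[256];
  for (int t = 0; t < num_tiles; ++t) {
    tile[threadIdx.x] = static_cast<float>(t);  // stand-in for the compute phase
    __syncthreads();                            // BlockSync in the outer loop
    // ... issue the cp.async.bulk (TMA) store from tile[] to gmem here ...
    (void)gmem;
    asm volatile("cp.async.bulk.commit_group;\n" ::: "memory");
    // WAR protection: wait until ALL pending bulk stores have finished
    // reading tile[] before the next iteration overwrites it.
    asm volatile("cp.async.bulk.wait_group.read 0;\n" ::: "memory");
  }
}
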
TMA Heuristics

The TMA pointwise heuristics use hardcoded tiling parameters without runtime tuning, which may not be optimal across different hardware and tensor sizes.

// TODO: tune three tiling sizes for the best performance
if (isOptionEnabled(EnableOption::TmaPointwise)) {
  params->tag = "TMA pointwise heuristics";
  params->use_tma_load = true;
  params->use_tma_store = true;
  // TMA tile
  params->tma_tile_inner = 256;
  params->tma_tile_outer =
      std::max(ceilDiv(n_elems, params->tma_tile_inner), (int64_t)64);

  // thread tile
  params->vectorization_factor =
      (int64_t)kOneHundredTwentyEight / max_dtype_size_bit_for_vectorization;
  params->unroll_factor_outer = 2;

  // block tile
  params->lparams.bind(32, ParallelType::TIDx);
  params->lparams.bind(4, ParallelType::TIDy);
  return params;
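
To make the hardcoded choices concrete, the following standalone sketch evaluates the same formulas for an assumed problem of 2^20 float32 elements (so max_dtype_size_bit_for_vectorization = 32). The names mirror the fields above, but this is plain host code, not the scheduler itself.

// Worked example of the fixed heuristic above for a hypothetical input size.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t n_elems = int64_t(1) << 20;       // 1,048,576 elements
  const int64_t max_dtype_size_bit = 32;          // float32
  const int64_t tma_tile_inner = 256;
  const int64_t tma_tile_outer =
      std::max(ceilDiv(n_elems, tma_tile_inner), int64_t(64));   // 4096
  const int64_t vectorization_factor = 128 / max_dtype_size_bit; // 4 elements
  const int64_t unroll_factor_outer = 2;
  // Block tile: TIDx = 32, TIDy = 4, i.e. 128 threads per CTA.
  std::printf("TMA tile %lld x %lld, vec %lld, outer unroll %lld\n",
              static_cast<long long>(tma_tile_outer),
              static_cast<long long>(tma_tile_inner),
              static_cast<long long>(vectorization_factor),
              static_cast<long long>(unroll_factor_outer));
  return 0;
}

Only the outer TMA tile reacts to n_elems; the inner tile, vectorization width, unroll factor, and block shape are compile-time constants in the scheduler, which is what the TODO about tuning is flagging.
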
Predicate Handling

The conditional predicate generation skips predicates for TMA operations, which may affect correctness in cases where vectorization predicates are needed for shared memory operations.

bool skip_predicate = has_nd_tma_ &&
    expr->predicate()->predicate_type() == PredicateType::Inline;

// Replace expr predicate with bool conditional
auto conditional = skip_predicate
    ? GpuLower::current()->kernel()->trueVal()
    : generateConditional(expr->predicate());

if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
  if (expr->isA<kir::IfThenElse>()) {
    // TODO: This logic doesn't seem to fit well here, for unswitch the
    // logic is in the unroll loop to set the thread predicate to the
    // expr. I didn't have a quick way to do that so placing this here for
    // now.
    auto ite = expr->as<kir::IfThenElse>();

    NVF_ERROR(
        ite->thenBody().size() == 1,
        "Expecting predicated body to only have one vectorized "
        "expression.");
    auto vec_expr = ite->thenBody()[0];
    NVF_ERROR(
        vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>() ||
            vec_expr->isA<TernaryOp>() || vec_expr->isA<IndexSelectOp>(),
        "Vectorize predicate exprs only supported on set operations.");
    NVF_ERROR(
        ir_utils::isTvOp(vec_expr),
        "Vectorize predicate exprs only supported on tensor view "
        "operations.");

    // load from smem still needs predicate unless heuristic ensures
    // divisible by threads count.
    if (false && has_nd_tma_ &&
        vec_expr->outputs()[0]
                ->as<kir::TensorIndex>()
                ->view()
                ->getMemoryType() != MemoryType::Global) {
      conditional = GpuLower::current()->kernel()->trueVal();
    } else if (!vec_expr->inputs()[0]->isConstScalar()) {
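
As a side note on the in-code comment that a load from smem "still needs predicate unless heuristic ensures divisible by threads count", here is a hypothetical pair of vectorized copies (names and shapes assumed, not nvFuser output) showing when the guard can be dropped: only if the schedule guarantees the extent is an exact multiple of the launched thread count times the vector width.

// Hypothetical sketch of a vectorized (16-byte) copy with and without the
// Vectorize-style predicate.
__global__ void copy_guarded(const float4* in, float4* out, int n_vec4) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec4) {  // predicate guards the out-of-bounds tail
    out[i] = in[i];
  }
}

__global__ void copy_unguarded(const float4* in, float4* out) {
  // No predicate: correct only when gridDim.x * blockDim.x float4 accesses
  // exactly cover the buffer, i.e. the extent is divisible by the thread
  // count chosen by the heuristic.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = in[i];
}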
