
Conversation

@liqiangxl (Collaborator) commented Oct 28, 2025

Demo of how we want to schedule a pointwise fusion using multi-wave TMA.
This will be added to the auto scheduler; see the design doc for details.
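
As context for the dumps below, here is a rough sketch of the kind of schedule this demo builds, written against nvFuser's existing scheduling primitives (cacheAfter with LoadStoreOpType::CpAsyncBulkTensorTile, split/reorder/parallelize, ParallelType::Bulk, inlineMost). The tile sizes are read off the IR dump (a 64x128 shared-memory box per CTA, a 128x64 grid, TIDy=8, TIDx=16, 8-wide vectorization); the exact split order is illustrative and is not the test code itself.

// Hedged sketch only. Assume tv0/tv1 are the bf16 fusion inputs and tv5 the
// output, matching T0/T1/T5 in the IR dump below.
TensorView* tv0_smem = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv0_smem->setMemoryType(MemoryType::Shared);
TensorView* tv0_reg = tv0_smem->cacheAfter();  // smem -> registers

// Tile the 8192x8192 iteration space into 64x128 TMA boxes.
tv0_smem->split(1, 128);              // [I0, I1/128, 128]
tv0_smem->split(0, 64);               // [I0/64, 64, I1/128, 128]
tv0_smem->reorder({{1, 2}, {2, 1}});  // [I0/64, I1/128, 64, 128]
tv0_smem->axis(0)->parallelize(ParallelType::BIDy);
tv0_smem->axis(1)->parallelize(ParallelType::BIDx);
tv0_smem->axis(2)->parallelize(ParallelType::Bulk);  // TMA box dimensions
tv0_smem->axis(3)->parallelize(ParallelType::Bulk);

// The register/compute tensors get the same 64x128 tiling, and the tile is
// then distributed over TIDy=8 and TIDx=16 with 8-wide vectorized accesses
// plus a small serial (or unrolled) remainder, producing loop domains like
// [BIDy{128}, BIDx{64}, 4, 1, TIDy{8}, TIDx{16}, 2, V{8}] as shown below.
// inlineMost() then sets the compute-at positions seen in the IR.
inlineMost();
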
IR and Kernel for PointwiseMultiWaveTMATest.PointwiseMulMultiWaveTMA/WithTMAStore_WithUnroll

Inputs:
  T0_g___bfloat[iS0{8192}, iS1{8192}]
  T1_g___bfloat[iS2{8192}, iS3{8192}]
Outputs:
  T5_g___bfloat[iblockIdx.y102{128}, iblockIdx.x104{64}, iS110{4}, iS112{1}, ithreadIdx.y111{8}, ithreadIdx.x113{16}, iS107{2}, iV109{8}] ca_pos( 7 ) produce_pos( 7 )

%kernel {
T6_s___bfloat[iblockIdx.y22{128}, iblockIdx.x24{64}, iB23{64}, iB25{128}] ca_pos( 2 )
   = CpAsyncBulkTensorTile( T0_g___bfloat[iS0{8192}, iS1{8192}] )
T8_l___bfloat[iblockIdx.y30{128}, iblockIdx.x32{64}, iS38{4}, iS40{1}, ithreadIdx.y39{8}, ithreadIdx.x41{16}, iS35{2}, iV37{8}] ca_pos( 7 ) produce_pos( 2 )
   = Set( T6_s___bfloat[iblockIdx.y22{128}, iblockIdx.x24{64}, iB23{64}, iB25{128}] ca_pos( 2 ), cache_op=Streaming )
T2_l_float[iblockIdx.y54{128}, iblockIdx.x56{64}, iS62{4}, iS64{1}, ithreadIdx.y63{8}, ithreadIdx.x65{16}, iS59{2}, iS61{8}] ca_pos( 8 ) produce_pos( 7 )
   = __bfloat2float(T8_l___bfloat[iblockIdx.y30{128}, iblockIdx.x32{64}, iS38{4}, iS40{1}, ithreadIdx.y39{8}, ithreadIdx.x41{16}, iS35{2}, iV37{8}] ca_pos( 7 ) produce_pos( 2 ));
T7_s___bfloat[iblockIdx.y26{128}, iblockIdx.x28{64}, iB27{64}, iB29{128}] ca_pos( 2 )
   = CpAsyncBulkTensorTile( T1_g___bfloat[iS2{8192}, iS3{8192}] )
T9_l___bfloat[iblockIdx.y42{128}, iblockIdx.x44{64}, iS50{4}, iS52{1}, ithreadIdx.y51{8}, ithreadIdx.x53{16}, iS47{2}, iV49{8}] ca_pos( 7 ) produce_pos( 2 )
   = Set( T7_s___bfloat[iblockIdx.y26{128}, iblockIdx.x28{64}, iB27{64}, iB29{128}] ca_pos( 2 ), cache_op=Streaming )
T3_l_float[iblockIdx.y66{128}, iblockIdx.x68{64}, iS74{4}, iS76{1}, ithreadIdx.y75{8}, ithreadIdx.x77{16}, iS71{2}, iS73{8}] ca_pos( 8 ) produce_pos( 7 )
   = __bfloat2float(T9_l___bfloat[iblockIdx.y42{128}, iblockIdx.x44{64}, iS50{4}, iS52{1}, ithreadIdx.y51{8}, ithreadIdx.x53{16}, iS47{2}, iV49{8}] ca_pos( 7 ) produce_pos( 2 ));
T4_l_float[iblockIdx.y78{128}, iblockIdx.x80{64}, iS86{4}, iS88{1}, ithreadIdx.y87{8}, ithreadIdx.x89{16}, iS83{2}, iS85{8}] ca_pos( 8 ) produce_pos( 8 )
   = T2_l_float[iblockIdx.y54{128}, iblockIdx.x56{64}, iS62{4}, iS64{1}, ithreadIdx.y63{8}, ithreadIdx.x65{16}, iS59{2}, iS61{8}] ca_pos( 8 ) produce_pos( 7 )
   * T3_l_float[iblockIdx.y66{128}, iblockIdx.x68{64}, iS74{4}, iS76{1}, ithreadIdx.y75{8}, ithreadIdx.x77{16}, iS71{2}, iS73{8}] ca_pos( 8 ) produce_pos( 7 );
T10_l___bfloat[iblockIdx.y90{128}, iblockIdx.x92{64}, iS98{4}, iS100{1}, ithreadIdx.y99{8}, ithreadIdx.x101{16}, iS95{2}, iS97{8}] ca_pos( 7 ) produce_pos( 8 )
   = __float2bfloat(T4_l_float[iblockIdx.y78{128}, iblockIdx.x80{64}, iS86{4}, iS88{1}, ithreadIdx.y87{8}, ithreadIdx.x89{16}, iS83{2}, iS85{8}] ca_pos( 8 ) produce_pos( 8 ));
T5_g___bfloat[iblockIdx.y102{128}, iblockIdx.x104{64}, iS110{4}, iS112{1}, ithreadIdx.y111{8}, ithreadIdx.x113{16}, iS107{2}, iV109{8}] ca_pos( 7 ) produce_pos( 7 )
   = Set( T10_l___bfloat[iblockIdx.y90{128}, iblockIdx.x92{64}, iS98{4}, iS100{1}, ithreadIdx.y99{8}, ithreadIdx.x101{16}, iS95{2}, iS97{8}] ca_pos( 7 ) produce_pos( 8 ), cache_op=Streaming )

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__bfloat, 2, 2> T0, Tensor<__bfloat, 2, 2> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, Tensor<__bfloat, 2, 2> T5) {
  alignas(128) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  const TensorMap* ptr2;
  ptr2 = &var0;
  nvfuser_index_t i3;
  i3 = 128 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i4;
  i4 = 64 * ((nvfuser_index_t)blockIdx.y);
  Array<int, 2, 1> a5;
  a5 = Array<int, 2, 1>{__to_int32(i3), __to_int32(i4)};
  const TensorMap* ptr6;
  ptr6 = &var1;
  nvfuser_index_t i7;
  i7 = 8 * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i8;
  i8 = i7 + (256 * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i9;
  i9 = ((i7 + (16384 * ((nvfuser_index_t)threadIdx.y))) + i3) + (524288 * ((nvfuser_index_t)blockIdx.y));
  bool b10;
  b10 = (((nvfuser_index_t)threadIdx.x) == 0ULL) && (((nvfuser_index_t)threadIdx.y) == 0ULL);
  bool b11;
  b11 = ((7 + i7) + i3) < 8192;
  nvfuser_index_t i12;
  i12 = (-8192 + (2 * ((nvfuser_index_t)threadIdx.y))) + i4;
  __bfloat* T7 = reinterpret_cast<__bfloat*>(array + smem_offset + 0);
  __bfloat* T6 = reinterpret_cast<__bfloat*>(array + smem_offset + 16512);
  uint64_t* T11 = reinterpret_cast<uint64_t*>(array + smem_offset + 16512);
  mbarrier::init(toSmem(T11), 1U);
  __syncthreads();
  if (b10) {
    uint64_t i13;
    i13 = mbarrier::arriveExpectTX(toSmem(T11), 16384U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr2, a5, toSmem(T11) }), toSmem(T7));
    mbarrier::wait(toSmem(T11), i13);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T11));
  uint64_t* T12 = reinterpret_cast<uint64_t*>(array + smem_offset + 16384);
  mbarrier::init(toSmem(T12), 1U);
  __syncthreads();
  if (b10) {
    uint64_t i14;
    i14 = mbarrier::arriveExpectTX(toSmem(T12), 16384U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr6, a5, toSmem(T12) }), toSmem(T6));
    mbarrier::wait(toSmem(T12), i14);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T12));
  #pragma unroll
  for(nvfuser_index_t i15 = 0; i15 < 4; ++i15) {
    nvfuser_index_t i16;
    i16 = i8 + (2048 * i15);
    nvfuser_index_t i17;
    i17 = i9 + (131072 * i15);
    nvfuser_index_t i18;
    i18 = -(16 * i15);
    #pragma unroll
    for(nvfuser_index_t i19 = 0; i19 < 2; ++i19) {
      nvfuser_index_t i20;
      i20 = i16 + (128 * i19);
      Array<__bfloat, 8, 8> T9;
      loadGeneric<__bfloat, 8>( &T9[0],  &T7[i20]);
      Array<__bfloat, 8, 8> T8;
      loadGeneric<__bfloat, 8>( &T8[0],  &T6[i20]);
      // Alias Allocation - register
      auto& T10 = T8;
      #pragma unroll
      for(nvfuser_index_t i21 = 0; i21 < 8; ++i21) {
        Array<float, 1, 1> T3;
        T3[0]
           = __bfloat2float(T9[i21]);
        Array<float, 1, 1> T2;
        T2[0]
           = __bfloat2float(T8[i21]);
        Array<float, 1, 1> T4;
        T4[0]
          = T2[0]
          * T3[0];
        T10[i21]
           = __float2bfloat(T4[0]);
      }
      if ((b11 && (i12 < (i18 - i19)))) {
        loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T5[(i17 + (8192 * i19))], &T10[0]);
      }
    }
  }
}
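
The load side of this kernel follows the standard Hopper TMA protocol: one elected thread programs an mbarrier with the expected transaction size, issues the bulk tensor copy into shared memory, and waits on the barrier, while the surrounding __syncthreads() publish the tile to the rest of the CTA. The sketch below isolates that protocol; the helpers (mbarrier::*, Hopper::cpAsyncBulkTensorTileG2S, toSmem, TensorMap) are the nvFuser runtime utilities visible in the generated code, but the wrapper function itself is hypothetical.

// Sketch of the per-input TMA-load sequence used above.
__device__ void tmaLoadTile2D(
    const TensorMap* desc,         // TMA descriptor of the source tensor
    Array<int, 2, 1> tile_coords,  // origin of the 64x128 box, e.g. a5
    __bfloat* smem_tile,           // destination tile in shared memory
    uint64_t* smem_bar) {          // mbarrier slot in shared memory
  mbarrier::init(toSmem(smem_bar), 1U);  // one arriving thread
  __syncthreads();
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    // 64 x 128 bf16 elements = 16384 bytes expected for this transaction.
    uint64_t state = mbarrier::arriveExpectTX(toSmem(smem_bar), 16384U);
    Hopper::cpAsyncBulkTensorTileG2S(
        Hopper::CpAsyncBulkTensorTileG2SIndex<2>{
            desc, tile_coords, toSmem(smem_bar)},
        toSmem(smem_tile));
    mbarrier::wait(toSmem(smem_bar), state);  // block until all bytes land
  }
  __syncthreads();  // tile is now visible to every thread in the CTA
  mbarrier::inval(toSmem(smem_bar));
}
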

IR and Kernel for PointwiseMultiWaveTMATest.PointwiseMulMultiWaveTMA/WithTMAStore_WithUnroll

Inputs:
  T0_g___bfloat[iS0{8192}, iS1{8192}]
  T1_g___bfloat[iS2{8192}, iS3{8192}]
Outputs:
  T5_g___bfloat[iblockIdx.y32{128}, iblockIdx.x34{64}, iB33{64}, iB35{128}] ca_pos( 2 ) produce_pos( 2 )

%kernel {
T6_s___bfloat[iblockIdx.y24{128}, iblockIdx.x26{64}, iB25{64}, iB27{128}] ca_pos( 2 )
   = CpAsyncBulkTensorTile( T0_g___bfloat[iS0{8192}, iS1{8192}] )
T8_l___bfloat[iblockIdx.y36{128}, iblockIdx.x38{64}, iS44{4}, iS46{1}, ithreadIdx.y45{8}, ithreadIdx.x47{16}, iUR41{2}, iV43{8}] ca_pos( 6 ) produce_pos( 2 )
   = Set( T6_s___bfloat[iblockIdx.y24{128}, iblockIdx.x26{64}, iB25{64}, iB27{128}] ca_pos( 2 ), cache_op=Streaming )
T2_l_float[iblockIdx.y60{128}, iblockIdx.x62{64}, iS68{4}, iS70{1}, ithreadIdx.y69{8}, ithreadIdx.x71{16}, iUR65{2}, iS67{8}] ca_pos( 6 ) produce_pos( 6 )
   = __bfloat2float(T8_l___bfloat[iblockIdx.y36{128}, iblockIdx.x38{64}, iS44{4}, iS46{1}, ithreadIdx.y45{8}, ithreadIdx.x47{16}, iUR41{2}, iV43{8}] ca_pos( 6 ) produce_pos( 2 ));
T7_s___bfloat[iblockIdx.y28{128}, iblockIdx.x30{64}, iB29{64}, iB31{128}] ca_pos( 2 )
   = CpAsyncBulkTensorTile( T1_g___bfloat[iS2{8192}, iS3{8192}] )
T9_l___bfloat[iblockIdx.y48{128}, iblockIdx.x50{64}, iS56{4}, iS58{1}, ithreadIdx.y57{8}, ithreadIdx.x59{16}, iUR53{2}, iV55{8}] ca_pos( 6 ) produce_pos( 2 )
   = Set( T7_s___bfloat[iblockIdx.y28{128}, iblockIdx.x30{64}, iB29{64}, iB31{128}] ca_pos( 2 ), cache_op=Streaming )
T3_l_float[iblockIdx.y72{128}, iblockIdx.x74{64}, iS80{4}, iS82{1}, ithreadIdx.y81{8}, ithreadIdx.x83{16}, iUR77{2}, iS79{8}] ca_pos( 6 ) produce_pos( 6 )
   = __bfloat2float(T9_l___bfloat[iblockIdx.y48{128}, iblockIdx.x50{64}, iS56{4}, iS58{1}, ithreadIdx.y57{8}, ithreadIdx.x59{16}, iUR53{2}, iV55{8}] ca_pos( 6 ) produce_pos( 2 ));
T4_l_float[iblockIdx.y84{128}, iblockIdx.x86{64}, iS92{4}, iS94{1}, ithreadIdx.y93{8}, ithreadIdx.x95{16}, iUR89{2}, iS91{8}] ca_pos( 6 ) produce_pos( 6 )
   = T2_l_float[iblockIdx.y60{128}, iblockIdx.x62{64}, iS68{4}, iS70{1}, ithreadIdx.y69{8}, ithreadIdx.x71{16}, iUR65{2}, iS67{8}] ca_pos( 6 ) produce_pos( 6 )
   * T3_l_float[iblockIdx.y72{128}, iblockIdx.x74{64}, iS80{4}, iS82{1}, ithreadIdx.y81{8}, ithreadIdx.x83{16}, iUR77{2}, iS79{8}] ca_pos( 6 ) produce_pos( 6 );
T11_l___bfloat[iblockIdx.y96{128}, iblockIdx.x98{64}, iS104{4}, iS106{1}, ithreadIdx.y105{8}, ithreadIdx.x107{16}, iUR101{2}, iS103{8}] ca_pos( 6 ) produce_pos( 6 )
   = __float2bfloat(T4_l_float[iblockIdx.y84{128}, iblockIdx.x86{64}, iS92{4}, iS94{1}, ithreadIdx.y93{8}, ithreadIdx.x95{16}, iUR89{2}, iS91{8}] ca_pos( 6 ) produce_pos( 6 ));
T10_s___bfloat[iblockIdx.y108{128}, iblockIdx.x110{64}, iS116{4}, iS118{1}, ithreadIdx.y117{8}, ithreadIdx.x119{16}, iUR113{2}, iV115{8}] ca_pos( 2 ) produce_pos( 6 )
   = Set( T11_l___bfloat[iblockIdx.y96{128}, iblockIdx.x98{64}, iS104{4}, iS106{1}, ithreadIdx.y105{8}, ithreadIdx.x107{16}, iUR101{2}, iS103{8}] ca_pos( 6 ) produce_pos( 6 ), cache_op=Streaming )
T5_g___bfloat[iblockIdx.y32{128}, iblockIdx.x34{64}, iB33{64}, iB35{128}] ca_pos( 2 ) produce_pos( 2 )
   = CpAsyncBulkTensorTile( T10_s___bfloat[iblockIdx.y108{128}, iblockIdx.x110{64}, iS116{4}, iS118{1}, ithreadIdx.y117{8}, ithreadIdx.x119{16}, iUR113{2}, iV115{8}] ca_pos( 2 ) produce_pos( 6 ) )
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__bfloat, 2, 2> T0, Tensor<__bfloat, 2, 2> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, const __grid_constant__ TensorMap var2, Tensor<__bfloat, 2, 2> T5) {
  alignas(128) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  const TensorMap* ptr3;
  ptr3 = &var0;
  nvfuser_index_t i4;
  i4 = 128 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i5;
  i5 = 64 * ((nvfuser_index_t)blockIdx.y);
  Array<int, 2, 1> a6;
  a6 = Array<int, 2, 1>{__to_int32(i4), __to_int32(i5)};
  const TensorMap* ptr7;
  ptr7 = &var1;
  nvfuser_index_t i8;
  i8 = 8 * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i9;
  i9 = i8 + (256 * ((nvfuser_index_t)threadIdx.y));
  bool b10;
  b10 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
  bool b11;
  b11 = (((nvfuser_index_t)threadIdx.x) == 0ULL) && b10;
  bool b12;
  b12 = ((7 + i8) + i4) < 8192;
  nvfuser_index_t i13;
  i13 = 2 * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i14;
  i14 = (-8191 + i13) + i5;
  nvfuser_index_t i15;
  i15 = (-8192 + i13) + i5;
  __bfloat* T7 = reinterpret_cast<__bfloat*>(array + smem_offset + 0);
  __bfloat* T6 = reinterpret_cast<__bfloat*>(array + smem_offset + 32896);
  __bfloat* T10 = reinterpret_cast<__bfloat*>(array + smem_offset + 16512);
  uint64_t* T12 = reinterpret_cast<uint64_t*>(array + smem_offset + 16512);
  mbarrier::init(toSmem(T12), 1U);
  __syncthreads();
  if (b11) {
    uint64_t i16;
    i16 = mbarrier::arriveExpectTX(toSmem(T12), 16384U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr3, a6, toSmem(T12) }), toSmem(T7));
    mbarrier::wait(toSmem(T12), i16);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T12));
  uint64_t* T13 = reinterpret_cast<uint64_t*>(array + smem_offset + 16384);
  mbarrier::init(toSmem(T13), 1U);
  __syncthreads();
  if (b11) {
    uint64_t i17;
    i17 = mbarrier::arriveExpectTX(toSmem(T13), 16384U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, a6, toSmem(T13) }), toSmem(T6));
    mbarrier::wait(toSmem(T13), i17);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T13));
  #pragma unroll
  for(nvfuser_index_t i18 = 0; i18 < 4; ++i18) {
    nvfuser_index_t i19;
    i19 = i9 + (2048 * i18);
    #pragma unroll
    for(nvfuser_index_t i20 = 0; i20 < 2; ++i20) {
      arraySet<__bfloat, 8>(&T10[(i19 + (128 * i20))], (__bfloat)0);
    }
  }
  #pragma unroll
  for(nvfuser_index_t i18 = 0; i18 < 4; ++i18) {
    nvfuser_index_t i21;
    i21 = i9 + (2048 * i18);
    nvfuser_index_t i22;
    i22 = -(16 * i18);
    bool b23;
    b23 = b12 && (i14 < i22);
    Array<__bfloat, 16, 8> T8;
    #pragma unroll
    for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
      T8.set(__bfloat(0));
    }
    if (b23) {
      #pragma unroll
      for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
        loadGeneric<__bfloat, 8>( &T8[(8 * i24)],  &T6[(i21 + (128 * i24))]);
      }
    } else {
      #pragma unroll
      for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
        if ((b12 && (i15 < (i22 - i24)))) {
          loadGeneric<__bfloat, 8>( &T8[(8 * i24)],  &T6[(i21 + (128 * i24))]);
        }
      }
    }
    Array<float, 16, 1> T2;
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
      nvfuser_index_t i26;
      i26 = 8 * i25;
      #pragma unroll
      for(nvfuser_index_t i27 = 0; i27 < 8; ++i27) {
        nvfuser_index_t i28;
        i28 = i26 + i27;
        T2[i28]
           = __bfloat2float(T8[i28]);
      }
    }
    Array<__bfloat, 16, 8> T9;
    #pragma unroll
    for(nvfuser_index_t i29 = 0; i29 < 2; ++i29) {
      T9.set(__bfloat(0));
    }
    if (b23) {
      #pragma unroll
      for(nvfuser_index_t i29 = 0; i29 < 2; ++i29) {
        loadGeneric<__bfloat, 8>( &T9[(8 * i29)],  &T7[(i21 + (128 * i29))]);
      }
    } else {
      #pragma unroll
      for(nvfuser_index_t i29 = 0; i29 < 2; ++i29) {
        if ((b12 && (i15 < (i22 - i29)))) {
          loadGeneric<__bfloat, 8>( &T9[(8 * i29)],  &T7[(i21 + (128 * i29))]);
        }
      }
    }
    Array<float, 16, 1> T3;
    #pragma unroll
    for(nvfuser_index_t i30 = 0; i30 < 2; ++i30) {
      nvfuser_index_t i31;
      i31 = 8 * i30;
      #pragma unroll
      for(nvfuser_index_t i32 = 0; i32 < 8; ++i32) {
        nvfuser_index_t i33;
        i33 = i31 + i32;
        T3[i33]
           = __bfloat2float(T9[i33]);
      }
    }
    Array<float, 16, 1> T4;
    #pragma unroll
    for(nvfuser_index_t i34 = 0; i34 < 2; ++i34) {
      nvfuser_index_t i35;
      i35 = 8 * i34;
      #pragma unroll
      for(nvfuser_index_t i36 = 0; i36 < 8; ++i36) {
        nvfuser_index_t i37;
        i37 = i35 + i36;
        T4[i37]
          = T2[i37]
          * T3[i37];
      }
    }
    // Alias Allocation - register
    auto& T11 = T9;
    #pragma unroll
    for(nvfuser_index_t i38 = 0; i38 < 2; ++i38) {
      nvfuser_index_t i39;
      i39 = 8 * i38;
      #pragma unroll
      for(nvfuser_index_t i40 = 0; i40 < 8; ++i40) {
        nvfuser_index_t i41;
        i41 = i39 + i40;
        T11[i41]
           = __float2bfloat(T4[i41]);
      }
    }
    if (b23) {
      #pragma unroll
      for(nvfuser_index_t i20 = 0; i20 < 2; ++i20) {
        loadGeneric<__bfloat, 8>( &T10[(i21 + (128 * i20))],  &T11[(8 * i20)]);
      }
    } else {
      #pragma unroll
      for(nvfuser_index_t i20 = 0; i20 < 2; ++i20) {
        if ((b12 && (i15 < (i22 - i20)))) {
          loadGeneric<__bfloat, 8>( &T10[(i21 + (128 * i20))],  &T11[(8 * i20)]);
        }
      }
    }
  }
  __syncthreads();
  fenceAsyncProxy();
  if (b10) {
    Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ (&var2), a6 }), toSmem(T10));
  }
  cpAsyncBulkCommitGroup();
  cpAsyncBulkWaitGroup<0LL>();
}
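
Compared with the first variant, this kernel stages the result in a shared-memory tile (T10) and writes it back with a single bulk copy per CTA instead of per-thread vectorized global stores. The epilogue sequence is isolated below; as before, the helpers (fenceAsyncProxy, Hopper::cpAsyncBulkTensorTileS2G, cpAsyncBulkCommitGroup, cpAsyncBulkWaitGroup) come from the nvFuser runtime as used in the generated code, while the wrapper function is hypothetical.

// Sketch of the TMA-store epilogue used above.
__device__ void tmaStoreTile2D(
    const TensorMap* desc,         // TMA descriptor of the output tensor
    Array<int, 2, 1> tile_coords,  // origin of the tile in the output
    __bfloat* smem_tile) {         // source tile in shared memory
  __syncthreads();    // all threads have finished writing the smem tile
  fenceAsyncProxy();  // make the smem writes visible to the async proxy
  // The generated kernel predicates on threadIdx.y == 0; a single elected
  // thread is shown here for simplicity.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    Hopper::cpAsyncBulkTensorTileS2G(
        Hopper::CpAsyncBulkTensorTileS2GIndex<2>{desc, tile_coords},
        toSmem(smem_tile));
  }
  cpAsyncBulkCommitGroup();     // commit the bulk store
  cpAsyncBulkWaitGroup<0LL>();  // wait before the smem tile can be reused
}
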

github-actions bot commented Oct 28, 2025

Review updated until commit 0caba7c

Description

  • Add parameterized test for pointwise fusion with multi-wave TMA

  • Implement test for invalid 2D TMA load on 1D tensor

  • Enhance scheduling with TMA load/store and unroll support

  • Improve error handling for TMA dimension constraints


Changes walkthrough 📝

Relevant files

Bug fix: tests/cpp/test_memory.cpp (Add test for 1D tensor TMA load error, +34/-0)

  • Add new test NdTmaLoad1dTensor to validate 1D tensor loading via TMA
  • Include ir/base_nodes.h for required IR node definitions
  • Test expects failure when using 2D TMA on 1D tensor with dim=512
  • Guard test to run only on CUDA arch 9.0+

Enhancement: tests/cpp/test_pointwise.cpp (Add multi-wave TMA pointwise fusion test, +145/-0)

  • Add parameterized test PointwiseMultiWaveTMATest for TMA pointwise fusion
  • Support combinations of TMA store and explicit unroll
  • Schedule TMA loads to shared memory and compute in registers
  • Use inlineMost() to inline compute tensors for optimization

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The vectorization condition in the test includes tv3 when not using TMA store, which may lead to incorrect vectorization of the output tensor in the non-TMA-store path, where it is not intended to be vectorized.

bool vectorize_condition =
    (tv == tv0_reg || tv == tv1_reg || (use_tma_store && tv == tv3_smem) ||
     (!use_tma_store && tv == tv3));
if (vectorize_condition) {
  tv->axis(7)->parallelize(ParallelType::Vectorize);
}

Performance Concern

The use of inlineMost() after scheduling may interfere with TMA optimizations or memory-hierarchy assumptions, potentially affecting multi-wave TMA performance; the inlining strategy should be validated in the context of TMA.

inlineMost();

@liqiangxl (Collaborator, Author) commented:

!test

@liqiangxl (Collaborator, Author) commented:

!test

Review thread on tests/cpp/test_pointwise.cpp (quoted code):

    }
  }
  // Inline most tensors
  inlineMost();
Collaborator commented:

Suggested change:
-  inlineMost();
+  // inlineMost(&fusion); // Removed undefined function call

[Suggested by AI] The change simply removes the call to the (apparently undefined) function inlineMost(), now commented out, to eliminate the resulting build/undefined-reference error. No other logic was modified.

    "along each of the tensorRank dimensions, must be non-zero and less "
    "than or equal to 256. box_dim_val = 512")));
    }
    } // namespace nvfuser
@liqiangxl (Collaborator, Author) commented:

Hi @rdspring1 and @naoyam, what do you think about this comment?

This test demonstrates that we cannot use nd-TMA to load more than 256 elements from a 1D tensor. The dimension in nd-TMA corresponds to the logical domain, and we lack the flexibility to merge or split logical domains to form loop domains that could be parallelized with TMA.

It also doesn't seem like a good idea to use reshape to alter the logical domains. We might be better off reverting to a 1D TMA, or keeping both 1D and nD TMA and checking the performance difference.

@rdspring1 (Collaborator) commented Oct 30, 2025

Each individual multi-dimensional TMA load has a 256-element limit for each box dimension. However, you can issue multiple TMA loads for a given mbarrier.

We have to do this for matmuls to load a (256, 256) tile with 4x (256, 64) TMA loads to avoid bank conflicts.

if warp == load_warp:
    mbarrier::wait(load-mbarrier)
    for i in range(4):
        tma-load([256])
    arriveExpectTx(load-mbarrier, 256 * 2)
elif warp == compute_warp:
    mbarrier::wait(load-mbarrier)
    # compute 256 * 2 tile
    mbarrier::arrive(load-mbarrier)
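
A hedged CUDA-style elaboration of that pattern, using the runtime helpers from the kernels above: a single arriveExpectTX is sized for the whole (256, 256) bf16 tile, four (256, 64) box copies target the same mbarrier, and the wait completes once all of them have landed. The wrapper, its parameter names, and which coordinate advances by 64 are illustrative, not nvFuser's actual matmul code.

__device__ void tmaLoadFourBoxes(
    const TensorMap* desc,  // TMA descriptor (hypothetical parameter)
    int tile_row,
    int tile_col,
    __bfloat* smem_tile,    // (256, 256) bf16 staging area in shared memory
    uint64_t* bar) {        // shared mbarrier, initialized elsewhere
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    // One expect-tx for the whole tile: 256 * 256 bf16 = 131072 bytes.
    uint64_t state = mbarrier::arriveExpectTX(toSmem(bar), 256U * 256U * 2U);
    for (int i = 0; i < 4; ++i) {
      // Four (256, 64) boxes sharing the same mbarrier.
      Array<int, 2, 1> coords;
      coords = Array<int, 2, 1>{
          __to_int32(tile_col + 64 * i), __to_int32(tile_row)};
      Hopper::cpAsyncBulkTensorTileG2S(
          Hopper::CpAsyncBulkTensorTileG2SIndex<2>{desc, coords, toSmem(bar)},
          toSmem(smem_tile + 256 * 64 * i));
    }
    mbarrier::wait(toSmem(bar), state);  // fires after all four copies finish
  }
  __syncthreads();
}
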
