File tree Expand file tree Collapse file tree
transformer_engine/common/cast/mxfp8 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -705,6 +705,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
705705 if constexpr (COLWISE_SCALING ) {
706706 thread_partial_dbias = partial_dbias_colwise;
707707 } else {
708+ ptx::cp_async_bulk_wait_group_read<0 >();
709+ __syncthreads ();
708710 float *partial_dbias_rowwise = reinterpret_cast <float *>(dshmem);
709711
710712 constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1 );
Original file line number Diff line number Diff line change @@ -498,6 +498,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
498498 if constexpr (COLWISE_SCALING ) {
499499 thread_partial_dbias = partial_dbias_colwise;
500500 } else {
501+ ptx::cp_async_bulk_wait_group_read<0 >();
502+ __syncthreads ();
501503 // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
502504 // HEIGHT = THREADS_Y
503505 // WIDTH = THREADS_X * (SCALE_DIM_X + 1)
You can’t perform that action at this time.
0 commit comments