 ***************************************************************************************************/


-
 #include <exception>
 #include <iostream>
 #include <memory>
@@ -95,31 +94,29 @@ struct identity_op {
   T operator()(T val) const { return val; }
 };

-
-
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue =
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_epilogue =
   typename cutlass::epilogue::collective::CollectiveBuilder<
     cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
     cute::Shape<cute::_256, cute::_256, cute::_32>,
     cute::Shape<cute::_1, cute::_1, cute::_1>,
     cutlass::epilogue::collective::EpilogueTileAuto,
     float, float,
-    float, cutlass::layout::RowMajor, 4,
-    float, cutlass::layout::RowMajor, 4,
+    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,  // Bias
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,     // Output
     cutlass::epilogue::collective::EpilogueScheduleAuto,
     cutlass::epilogue::fusion::LinearCombination<
+      cutlass::bfloat16_t,
       float,
-      float,
-      float,
+      cutlass::bfloat16_t,
       float
     >
   >::CollectiveOp;
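// Reviewer note (not part of this commit): in CUTLASS 3.x,
// epilogue::fusion::LinearCombination is parameterized as
// <ElementOutput, ElementCompute, ElementSource, ElementScalar>, so the
// instantiation above reads: bf16 output, float compute, bf16 source (the
// bias tensor C), float scalars. With the {1.f, 1.f} thread arguments set
// further down, the epilogue evaluates, per element (a sketch of the
// intended math, not the literal code path):
//
//   D[i][j] = cutlass::bfloat16_t(1.f * acc[i][j] + 1.f * float(C[i][j]));
//
// where acc is the float accumulator produced by the mainloop.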

-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop =
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_mainloop =
   typename cutlass::gemm::collective::CollectiveBuilder<
     cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
-    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
-    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,  // A
+    cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,  // B
     float,
     cute::Shape<cute::_256, cute::_256, cute::_32>,
     cute::Shape<cute::_1, cute::_1, cute::_1>,
@@ -128,34 +125,88 @@ using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop =
   >::CollectiveOp;
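// Reviewer note (not part of this commit): the _nn_ -> _tt_ rename tracks the
// operand layouts chosen in the two builders above: the old kernel took both
// A and B as ColumnMajor ("n"), the new one takes both as RowMajor ("t"),
// matching PyTorch's default contiguous tensors. align8 means the alignment
// arguments are 8 elements (16 bytes for bf16); the builders take alignment
// in elements, not bytes. The "// Gemm operator ...xe11..._tn_align2" comment
// just below is stale codegen and does not match this configuration.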

 // Gemm operator cutlass3x_xe11_tensorop_gemm_bf16_128x256_16x0_tn_align2
-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base = cutlass::gemm::kernel::GemmUniversal<
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_base = cutlass::gemm::kernel::GemmUniversal<
     cute::Shape<int,int,int,int>,
-    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop,
-    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue,
+    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_mainloop,
+    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_epilogue,
     cutlass::gemm::PersistentScheduler>;

 // Define named type
-struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8 :
-  public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base { };
-
+struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8 :
+  public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_base { };

-using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8>;
+using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8>;

 // When workspace_size is not a nullptr, populates requested workspace_size and returns.
 // Otherwise, computes the Gemm kernel using the given workspace ptr.
 extern "C" {
-PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) {
+PT_EXPORT int sycl_tla_gemm_xe20_bf16(const cutlass::bfloat16_t* X, const cutlass::bfloat16_t* W, const cutlass::bfloat16_t* Bias, cutlass::bfloat16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Bias_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) {
   try {
-  using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::ElementAccumulator;
+  using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
   if (hw_info.sm_count == 0) {
     hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
     CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
   }

+  cutlass::DeviceAllocation<cutlass::bfloat16_t> block_A;
+  cutlass::DeviceAllocation<cutlass::bfloat16_t> block_B;
+  cutlass::DeviceAllocation<cutlass::bfloat16_t> block_C;
+  cutlass::DeviceAllocation<cutlass::bfloat16_t> block_D;
+
+  if (!workspace_size) {
+    if (!X || !W) {
+      std::cerr << "Input host pointers null!" << std::endl;
+      return -1;
+    }
+    else {
+      block_A.reset(static_cast<std::size_t>(M) * K * B);
+      block_B.reset(static_cast<std::size_t>(K) * N * B);
+      if (!block_A.get() || !block_B.get()) {
+        std::cerr << "Device allocation of inputs failed!" << std::endl;
+        return -1;
+      }
+      compat::wait();
+      compat::memcpy(block_A.get(), (X + X_offset), (M * K * B) * sizeof(cutlass::bfloat16_t));
+      compat::wait();
+      compat::memcpy(block_B.get(), (W + W_offset), (K * N * B) * sizeof(cutlass::bfloat16_t));
+      compat::wait();
+    }
+
+    if (!Bias) {
+      std::cerr << "Bias host pointer null!" << std::endl;
+      return -1;
+    }
+    else {
+      block_C.reset(static_cast<std::size_t>(M) * N * B);
+      if (!block_C.get()) {
+        std::cerr << "Device allocation of bias failed!" << std::endl;
+        return -1;
+      }
+      compat::wait();
+      compat::memcpy(block_C.get(), (Bias + Bias_offset), (M * N * B) * sizeof(cutlass::bfloat16_t));
+      compat::wait();
+    }
+
+    if (!Y) {
+      std::cerr << "Output host pointer null!" << std::endl;
+      return -1;
+    }
+    else {
+      block_D.reset(static_cast<std::size_t>(M) * N * B);
+      if (!block_D.get()) {
+        std::cerr << "Device allocation of output failed!" << std::endl;
+        return -1;
+      }
+      compat::wait();
+      compat::memset(block_D.get(), 0, (M * N * B) * sizeof(cutlass::bfloat16_t));
+      compat::wait();
+    }
+  }
+
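// Reviewer note (not part of this commit): the staging block above copies
// M*K*B, K*N*B, and M*N*B contiguous elements, which is only correct when the
// leading dimensions describe densely packed tensors (e.g. lda == K for
// row-major A); padded strides would need 2-D copies. The byte counts also
// multiply plain ints before widening, so for very large shapes it would be
// safer to widen first, e.g.:
//
//   std::size_t bytes = static_cast<std::size_t>(M) * K * B * sizeof(cutlass::bfloat16_t);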
   // Initialize GemmUniversal3xInstance arguments using constructor
-  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::Arguments arguments{
+  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type::Arguments arguments{
     cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
     {
       static_cast<coord_t>(M),
@@ -164,28 +215,30 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
       static_cast<coord_t>(B)
     },  // ProblemShape problem_shape
     {
-      (cutlass::bfloat16_t*)(X + X_offset), // ElementA const* ptr_A
-      cute::make_tuple(cute::Int<1>{}, int64_t(lda), int64_t(0)), // StrideA dA (column-major: stride_m=1, stride_n=lda, batch=0)
-      (cutlass::bfloat16_t*)(W + W_offset), // ElementB const* ptr_B
-      cute::make_tuple(int64_t(ldb), cute::Int<1>{}, int64_t(0)), // StrideB dB (column-major: stride_m=ldb, stride_n=1, batch=0)
+      (cutlass::bfloat16_t*)(block_A.get()), // ElementA const* ptr_A
+      {int64_t(lda), cute::Int<1>{}, int64_t(0)},
+      (cutlass::bfloat16_t*)(block_B.get()), // ElementB const* ptr_B
+      {cute::Int<1>{}, int64_t(ldb), int64_t(0)},
     },  // MainloopArguments mainloop
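// Reviewer note (not part of this commit): CUTLASS 3.x stride tuples are
// mode-ordered (M, K, L) for A and (N, K, L) for B, so the new arguments read:
//
//   dA = {lda, _1, 0};  // A is M x K, K contiguous (RowMajor)
//   dB = {_1, ldb, 0};  // B is K x N, N contiguous (RowMajor)
//
// The batch stride of 0 assumes B == 1; a genuinely batched call would have
// to pass real per-batch element strides in the last mode.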

     // see https://tinyurl.com/4rk89z48
     {
-      {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
-      nullptr, // ElementC const* ptr_C
-      cute::make_tuple(int64_t(0), cute::Int<1>{}, int64_t(0)), // StrideC dC (row-major: stride_m, stride_n=1, batch=0)
-      (float*)(Y + Y_offset), // ElementD ptr_D (output is float, not bfloat16)
-      cute::make_tuple(int64_t(ldd), cute::Int<1>{}, int64_t(0)), // StrideD dD (row-major: stride_m=ldd, stride_n=1, batch=0)
+      {1.f, 1.f}, // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
+      (cutlass::bfloat16_t*)(block_C.get()), // ElementC const* ptr_C
+      {cute::Int<1>{}, int64_t(ldc), int64_t(0)},
+      (cutlass::bfloat16_t*)(block_D.get()), // ElementD const* ptr_D
+      {int64_t(ldd), cute::Int<1>{}, int64_t(0)},
     },  // EpilogueArguments epilogue,
     hw_info
   };
+
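// Reviewer note (not part of this commit): {1.f, 1.f} is {alpha, beta}, so
// the epilogue now computes D = acc + C instead of the old beta = 0 path with
// a null C. The C/D stride tuples are mode-ordered (M, N, L):
// dC = {_1, ldc, 0} is column-major, matching the ColumnMajor bias tag in the
// epilogue builder, while dD = {ldd, _1, 0} is row-major output.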
   arguments.scheduler.max_swizzle_size = swizzle;
-  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type gemm_op;
+  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8_device_type gemm_op;
   if (workspace_size) {
     *workspace_size = gemm_op.get_workspace_size(arguments);
     return 0;
   }
+
   // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
 #ifndef CUTLASS_BACKEND_DISABLE_CHECKS
   {
@@ -209,6 +262,10 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
   {
     auto status = gemm_op(stream);
     CUTLASS_CHECK(status);
+
+    compat::wait();
+    compat::memcpy((Y + Y_offset), block_D.get(), (M * N * B) * sizeof(cutlass::bfloat16_t));
+    compat::wait();
   }
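// Reviewer note (not part of this commit): the copy-back writes M*N*B
// contiguous elements into Y + Y_offset, which again assumes a packed output
// (ldd == N). The leading compat::wait() presumably ensures the kernel has
// completed before the device-to-host copy begins, assuming compat::wait()
// synchronizes the same queue the kernel was launched on.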
   }
   catch (std::exception& e) {
@@ -222,4 +279,4 @@ PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint
 }
 }

-// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8
+// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_tt_align8
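// Reviewer note (not part of this commit): a hypothetical caller sketch for
// the exported entry point, illustrating the two-phase protocol described in
// the comments above (query the workspace size with null data pointers, then
// launch for real). Shapes, leading dimensions, swizzle value, and the
// host-side workspace allocation are illustrative assumptions only.
//
//   sycl::queue q;
//   size_t ws_bytes = 0;
//   // Phase 1: only *workspace_size is populated; data pointers may be null.
//   sycl_tla_gemm_xe20_bf16(nullptr, nullptr, nullptr, nullptr,
//                           M, N, K, /*B=*/1,
//                           /*lda=*/K, /*ldb=*/N, /*ldc=*/M, /*ldd=*/N,
//                           0, 0, 0, 0, /*swizzle=*/1,
//                           &ws_bytes, nullptr, &q);
//   std::vector<uint8_t> workspace(ws_bytes);
//   // Phase 2: pass packed host buffers; returns 0 on success, -1 on error.
//   int rc = sycl_tla_gemm_xe20_bf16(X.data(), W.data(), Bias.data(), Y.data(),
//                                    M, N, K, 1, K, N, M, N,
//                                    0, 0, 0, 0, 1,
//                                    nullptr, workspace.data(), &q);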