1+ /* **************************************************************************************************
2+ * Copyright (c) 2025 Intel Corporation. All rights reserved.
3+ * SPDX-License-Identifier: BSD-3-Clause
4+ *
5+ * Redistribution and use in source and binary forms, with or without
6+ * modification, are permitted provided that the following conditions are met:
7+ *
8+ * 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
10+ *
11+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12+ * this list of conditions and the following disclaimer in the documentation
13+ * and/or other materials provided with the distribution.
14+ *
15+ * 3. Neither the name of the copyright holder nor the names of its
16+ * contributors may be used to endorse or promote products derived from
17+ * this software without specific prior written permission.
18+ *
19+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
30+ *
31+ **************************************************************************************************/
/*! \file
    \brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM
*/
36+ #include < torch/all.h>
37+ #include " utils.h"
38+
39+ #include < cute/tensor.hpp>
40+ #include < random>
41+
42+ #include < cute/util/compat.hpp>
43+ #include < sycl/ext/intel/experimental/grf_size_properties.hpp>
44+ #include < sycl/sycl.hpp>
45+
46+ #include < cute/tensor.hpp>
47+
48+ #include " cutlass/kernel_hardware_info.h"
49+ #include " cutlass/platform/platform.h"
50+ #include " cutlass/tensor_ref.h"
51+ #include " cutlass/util/GPU_Clock.hpp"
52+ #include " cutlass/util/device_memory.h"
53+ #include " cutlass/util/initialize_block.hpp"
54+ #include " cutlass/util/reference/device/gemm_complex.h"
55+ #include " cutlass/util/reference/device/tensor_compare.h"
56+ #include " cutlass/util/reference/host/tensor_fill.h"
57+ #include " cutlass/util/sycl_event_manager.hpp"
58+
59+ #include " xe_gemm_policy.hpp"
60+ #include " xe_grouped_gemm.hpp"
61+
62+ #pragma clang diagnostic ignored "-Wpass-failed"
63+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
64+
65+ namespace MoE {
66+ using namespace cute ;
67+
68+ // type tag to define a unique sycl kernel name
69+ template <typename , typename , typename , typename , char , char , class >
70+ class GemmCuteName ;
71+
72+ template <
73+ char layoutA,
74+ char layoutB,
75+ class policy ,
76+ typename ElementA,
77+ typename ElementB,
78+ typename ElementS,
79+ typename ElementBI,
80+ typename ElementD>
81+ void MoEGEMMLauncher (
82+ sycl::queue& stream,
83+ const ElementA* activations,
84+ const ElementB* weights,
85+ const ElementS* scales,
86+ const ElementBI* bias,
87+ ElementD* outputs,
88+ const int gemm_n,
89+ const int gemm_k,
90+ const int * num_rows_per_expert_device,
91+ const int num_experts,
92+ const int group_size,
93+ int32_t * atomic_buffer) {
94+ using ElementA_non_CV = cutlass::platform::remove_cv_t <ElementA>;
95+ auto op = XE_DPAS_TT<8 , float , ElementA_non_CV>{};
96+
97+ using WGTile = typename policy::WGTile;
98+ using SGLayout = typename policy::SGLayout;
99+ using MMA = typename TiledMMAHelper<
100+ MMA_Atom<decltype (op)>,
101+ Layout<WGTile>,
102+ SGLayout>::TiledMMA;
103+ auto mma = MMA{};
104+
105+ int sm_count =
106+ cutlass::KernelHardwareInfo::query_device_multiprocessor_count (0 );
107+ auto MaxThreadsPerWorkgroup = size (mma);
108+
109+ static constexpr int MaxThreadsPerSM = 512 ;
110+
111+ TORCH_CHECK (
112+ MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0 ,
113+ " MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup" )
114+
115+ sycl::range<3 > local (1 , 1 , MaxThreadsPerWorkgroup);
116+ sycl::range<3 > global (
117+ 1 , sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1 );
118+
119+ namespace syclex = sycl::ext::oneapi::experimental;
120+ namespace intelex = sycl::ext::intel::experimental;
121+
122+ syclex::properties kernel_props{
123+ syclex::sub_group_size<16 >, intelex::grf_size<256 >};
124+
125+ using GmemTiledCopyA = typename policy::GmemTiledCopyA;
126+ using GmemTiledCopyB = typename policy::GmemTiledCopyB;
127+ using GmemTiledCopyD = typename policy::GmemTiledCopyD;
128+
129+ auto event = stream.submit ([&](sycl::handler& cgh) {
130+ sycl::local_accessor<int32_t , 1 > local_mem (sycl::range<1 >(1 ), cgh);
131+ cgh.parallel_for <GemmCuteName<
132+ ElementA,
133+ ElementB,
134+ ElementS,
135+ ElementD,
136+ layoutA,
137+ layoutB,
138+ policy>>(
139+ sycl::nd_range<3 >{global * local, local}, kernel_props, [=](auto ) {
140+ MoE::MoEGEMM<
141+ GmemTiledCopyA,
142+ GmemTiledCopyB,
143+ GmemTiledCopyD,
144+ layoutA,
145+ layoutB,
146+ ' R' >(
147+ activations,
148+ weights,
149+ scales,
150+ bias,
151+ outputs,
152+ mma,
153+ num_rows_per_expert_device,
154+ num_experts,
155+ group_size,
156+ gemm_n,
157+ gemm_k,
158+ atomic_buffer,
159+ local_mem);
160+ });
161+ });
162+ EventManager::getInstance ().addEvent (event);
163+ }
164+
165+ at::Tensor cutlass_xe_grouped_gemm (
166+ at::Tensor& ptr_A,
167+ at::Tensor& ptr_B,
168+ const c10::optional<at::Tensor>& ptr_scales,
169+ const c10::optional<at::Tensor>& ptr_bias,
170+ at::Tensor& ptr_D,
171+ at::Tensor& num_rows_per_expert_device,
172+ int64_t N,
173+ int64_t K,
174+ int64_t num_experts,
175+ bool is_B_int4,
176+ bool is_B_mxfp4) {
177+ auto & dpcpp_queue =
178+ at::xpu::getCurrentXPUStream (ptr_A.device ().index ()).queue ();
179+ auto A_dtype = ptr_A.dtype ();
180+ auto B_dtype = ptr_B.dtype ();
181+ bool is_weight_fp8 =
182+ ((B_dtype == at::kFloat8_e4m3fn ) || (B_dtype == at::kFloat8_e5m2 ));
183+
184+ TORCH_CHECK (N % 32 == 0 , " N must be divisible by 32" );
185+
186+ TORCH_CHECK (ptr_A.dim () == 2 , " ptr_A must be 2D [Total_M, K]" );
187+ TORCH_CHECK (ptr_B.dim () == 3 , " ptr_B must be 3D [num_experts, K, N]" );
188+ TORCH_CHECK (ptr_D.dim () == 2 , " ptr_D must be 2D [Total_M, N]" );
189+ if (ptr_bias.has_value ()) {
190+ TORCH_CHECK (ptr_bias->dim () == 2 , " ptr_bias must be 2D [num_experts, N]" );
191+ }
192+
193+ TORCH_CHECK (ptr_A.is_contiguous (), " ptr_A must be contiguous" );
194+ TORCH_CHECK (ptr_B.is_contiguous (), " ptr_B must be contiguous" );
195+ TORCH_CHECK (ptr_D.is_contiguous (), " ptr_D must be contiguous" );
196+ if (ptr_bias.has_value ()) {
197+ TORCH_CHECK (ptr_bias->is_contiguous (), " ptr_bias must be contiguous" );
198+ }
199+
200+ int A_total_M = ptr_A.size (0 );
201+ int A_K = ptr_A.size (1 );
202+
203+ int B_E = ptr_B.size (0 );
204+ int B_K = ptr_B.size (1 );
205+ int B_N = ptr_B.size (2 );
206+ if (is_B_int4 || is_B_mxfp4) {
207+ B_K = ptr_B.size (2 ) * 2 ;
208+ B_N = ptr_B.size (1 );
209+ }
210+
211+ int D_total_M = ptr_D.size (0 );
212+ int D_N = ptr_D.size (1 );
213+ int group_size = -1 ;
214+ int A_avg_M = A_total_M / num_experts;
215+
216+ TORCH_CHECK (B_E == num_experts, " ptr_B.size(0) must match num_experts" );
217+ TORCH_CHECK (A_total_M == D_total_M, " ptr_A.size(0) must match ptr_D.size(0)" );
218+ TORCH_CHECK (A_K == B_K && B_K == K, " ptr_A.size(1) must match ptr_B.size(1)" );
219+ TORCH_CHECK (B_N == D_N && D_N == N, " ptr_B.size(2) must match ptr_D.size(1)" );
220+ if (ptr_bias.has_value ()) {
221+ TORCH_CHECK (
222+ ptr_bias->size (0 ) == num_experts,
223+ " ptr_bias.size(0) must match num_experts" );
224+ TORCH_CHECK (ptr_bias->size (1 ) == N, " ptr_bias.size(1) must match N" );
225+ }
226+
227+ at::Tensor atomic_buffer =
228+ at::empty ({static_cast <long >(1 )}, ptr_A.options ().dtype (at::kInt ));
229+
230+ #define MoEGEMMLauncherCallER ( \
231+ LayoutA, LayoutB, Policy, ElementA, ElementB, ElementS) \
232+ MoEGEMMLauncher<LayoutA, LayoutB, Policy>( \
233+ dpcpp_queue, \
234+ reinterpret_cast <ElementA*>(ptr_A.data_ptr ()), \
235+ reinterpret_cast <ElementB*>(ptr_B.data_ptr ()), \
236+ ptr_scales.has_value () \
237+ ? reinterpret_cast <ElementS*>(ptr_scales->data_ptr ()) \
238+ : static_cast <ElementS*>(nullptr ), \
239+ ptr_bias.has_value () ? reinterpret_cast <ElementA*>(ptr_bias->data_ptr ()) \
240+ : static_cast <ElementA*>(nullptr ), \
241+ reinterpret_cast <ElementA*>(ptr_D.data_ptr ()), \
242+ N, \
243+ K, \
244+ reinterpret_cast <int *>(num_rows_per_expert_device.data_ptr ()), \
245+ num_experts, \
246+ group_size, \
247+ static_cast <int *>(atomic_buffer.data_ptr ()));
248+
249+ if (is_B_int4 || is_B_mxfp4) {
250+ TORCH_CHECK (ptr_scales.has_value (), " w8a16 grouped gemm must have scales" );
251+ TORCH_CHECK (ptr_scales->is_contiguous (), " ptr_scales must be contiguous" );
252+ TORCH_CHECK (
253+ ptr_scales->dim () == 3 ,
254+ " ptr_scales of int4 must be 3D [num_experts, group_num, N]" );
255+ TORCH_CHECK (
256+ ptr_scales->size (0 ) == num_experts,
257+ " ptr_scales.size(0) of int4 must match num_experts" );
258+ TORCH_CHECK (
259+ K % ptr_scales->size (2 ) == 0 ,
260+ " ptr_scales.size(2) of int4 must be divisible by K" );
261+ TORCH_CHECK (
262+ ptr_scales->size (1 ) == N, " ptr_scales.size(1) of int4 must match N" );
263+ int group_num = ptr_scales->size (2 );
264+ group_size = K / group_num;
265+
266+ TORCH_CHECK (
267+ group_size == 32 || group_size == 64 || group_size == 128 ||
268+ group_size == 256 ,
269+ " group_size must be 32, 64, 128 or 256" );
270+
271+ #define W4A16LauncherCallER (policy ) \
272+ if (is_B_int4) { \
273+ if (A_dtype == at::kBFloat16 ) { \
274+ using scalar_t = bfloat16_t ; \
275+ MoEGEMMLauncherCallER (' R' , ' C' , policy, scalar_t , uint8_t , scalar_t ); \
276+ } else if (A_dtype == at::kHalf ) { \
277+ using scalar_t = half_t ; \
278+ MoEGEMMLauncherCallER (' R' , ' C' , policy, scalar_t , uint8_t , scalar_t ); \
279+ } \
280+ } else if (is_B_mxfp4) { \
281+ if (A_dtype == at::kBFloat16 ) { \
282+ using scalar_t = bfloat16_t ; \
283+ MoEGEMMLauncherCallER (' R' , ' C' , policy, scalar_t , uint8_t , uint8_t ); \
284+ } else if (A_dtype == at::kHalf ) { \
285+ using scalar_t = half_t ; \
286+ MoEGEMMLauncherCallER (' R' , ' C' , policy, scalar_t , uint8_t , uint8_t ); \
287+ } \
288+ }
289+
290+ if (A_avg_M <= 32 ) {
291+ using policy = w4a16_policy_m_16;
292+ W4A16LauncherCallER (policy);
293+ } else if (A_avg_M <= 128 ) {
294+ using policy = w4a16_policy_m_32;
295+ W4A16LauncherCallER (policy);
296+ } else {
297+ using policy = w4a16_policy;
298+ W4A16LauncherCallER (policy);
299+ }
300+ #undef W4A16LauncherCallER
301+ } else if (is_weight_fp8) {
302+ TORCH_CHECK (ptr_scales.has_value (), " w8a16 grouped gemm must have scales" );
303+ TORCH_CHECK (ptr_scales->is_contiguous (), " ptr_scales must be contiguous" );
304+ TORCH_CHECK (
305+ ptr_scales->dim () == 1 , " ptr_scales of fp8 must be 1D [num_experts]" );
306+ TORCH_CHECK (
307+ ptr_scales->size (0 ) == num_experts,
308+ " ptr_scales.size(0) of fp8 must match num_experts" );
309+
310+ #define W8A16LauncherCallER (policy ) \
311+ if (B_dtype == at::kFloat8_e4m3fn && A_dtype == at::kHalf ) { \
312+ using scalar_t = half_t ; \
313+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , float_e4m3_t , scalar_t ); \
314+ } else if (B_dtype == at::kFloat8_e5m2 && A_dtype == at::kHalf ) { \
315+ using scalar_t = half_t ; \
316+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , float_e5m2_t , scalar_t ); \
317+ } else if (B_dtype == at::kFloat8_e4m3fn && A_dtype == at::kBFloat16 ) { \
318+ using scalar_t = bfloat16_t ; \
319+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , float_e4m3_t , scalar_t ); \
320+ } else if (B_dtype == at::kFloat8_e5m2 && A_dtype == at::kBFloat16 ) { \
321+ using scalar_t = bfloat16_t ; \
322+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , float_e5m2_t , scalar_t ); \
323+ }
324+
325+ if (A_avg_M <= 32 ) {
326+ using policy = w8a16_policy_m_16;
327+ W8A16LauncherCallER (policy);
328+ } else if (A_avg_M <= 128 ) {
329+ using policy = w8a16_policy_m_32;
330+ W8A16LauncherCallER (policy);
331+ } else {
332+ using policy = w8a16_policy;
333+ W8A16LauncherCallER (policy);
334+ }
335+ #undef W8A16LauncherCallER
336+ } else {
337+ TORCH_CHECK (
338+ !ptr_scales.has_value (), " w16a16 grouped gemm must not have scales" );
339+
340+ #define W16A16LauncherCallER (policy ) \
341+ if (A_dtype == at::kBFloat16 ) { \
342+ using scalar_t = bfloat16_t ; \
343+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , scalar_t , scalar_t ); \
344+ } else if (A_dtype == at::kHalf ) { \
345+ using scalar_t = half_t ; \
346+ MoEGEMMLauncherCallER (' R' , ' R' , policy, scalar_t , scalar_t , scalar_t ); \
347+ }
348+
349+ if (A_avg_M <= 4 ) {
350+ using policy = w16a16_policy_m_16;
351+ W16A16LauncherCallER (policy);
352+ } else {
353+ using policy = w16a16_policy;
354+ W16A16LauncherCallER (policy);
355+ }
356+ #undef W16A16LauncherCallER
357+ }
358+ #undef MoEGEMMLauncherCallER
359+ return ptr_D;
360+ }
361+ } // namespace MoE