Skip to content

Commit 29c5ee0

Browse files
authored
[cutlass] support fp8/int4/mxfp4 weights grouped gemm (vllm-project#88)
* xe grouped gemm bf16/fp16/fp8 Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * format Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * remove useless code Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * format Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * support w4a16 without scales Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * support w4a16 with scales Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * support mxfp4 Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * optimize int4/mxfp4 Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * refactor code Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * fix bug Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * refine code Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * add policy for better performance Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * refactor code Signed-off-by: mayuyuace <qiming1.zhang@intel.com> * finetune policy for small m Signed-off-by: mayuyuace <qiming1.zhang@intel.com> --------- Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
1 parent 46b027f commit 29c5ee0

8 files changed

Lines changed: 1480 additions & 0 deletions

File tree

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 Intel Corporation. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice,
9+
*this list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22+
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23+
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24+
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25+
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26+
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27+
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28+
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29+
*POSSIBILITY OF SUCH DAMAGE.
30+
*
31+
**************************************************************************************************/
32+
/*! \file
33+
\brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM
34+
35+
*/
36+
#include <torch/all.h>
37+
#include "utils.h"
38+
39+
#include <cute/tensor.hpp>
40+
#include <random>
41+
42+
#include <cute/util/compat.hpp>
43+
#include <sycl/ext/intel/experimental/grf_size_properties.hpp>
44+
#include <sycl/sycl.hpp>
45+
46+
#include <cute/tensor.hpp>
47+
48+
#include "cutlass/kernel_hardware_info.h"
49+
#include "cutlass/platform/platform.h"
50+
#include "cutlass/tensor_ref.h"
51+
#include "cutlass/util/GPU_Clock.hpp"
52+
#include "cutlass/util/device_memory.h"
53+
#include "cutlass/util/initialize_block.hpp"
54+
#include "cutlass/util/reference/device/gemm_complex.h"
55+
#include "cutlass/util/reference/device/tensor_compare.h"
56+
#include "cutlass/util/reference/host/tensor_fill.h"
57+
#include "cutlass/util/sycl_event_manager.hpp"
58+
59+
#include "xe_gemm_policy.hpp"
60+
#include "xe_grouped_gemm.hpp"
61+
62+
#pragma clang diagnostic ignored "-Wpass-failed"
63+
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
64+
65+
namespace MoE {
using namespace cute;

// Type tag used only to give each instantiated SYCL kernel a unique name.
// Parameters mirror the launch configuration: element types of A, B, scales
// and D, the A/B layout tags ('R'/'C'), and the tiling policy class.
template <typename, typename, typename, typename, char, char, class>
class GemmCuteName;
71+
72+
template <
73+
char layoutA,
74+
char layoutB,
75+
class policy,
76+
typename ElementA,
77+
typename ElementB,
78+
typename ElementS,
79+
typename ElementBI,
80+
typename ElementD>
81+
void MoEGEMMLauncher(
82+
sycl::queue& stream,
83+
const ElementA* activations,
84+
const ElementB* weights,
85+
const ElementS* scales,
86+
const ElementBI* bias,
87+
ElementD* outputs,
88+
const int gemm_n,
89+
const int gemm_k,
90+
const int* num_rows_per_expert_device,
91+
const int num_experts,
92+
const int group_size,
93+
int32_t* atomic_buffer) {
94+
using ElementA_non_CV = cutlass::platform::remove_cv_t<ElementA>;
95+
auto op = XE_DPAS_TT<8, float, ElementA_non_CV>{};
96+
97+
using WGTile = typename policy::WGTile;
98+
using SGLayout = typename policy::SGLayout;
99+
using MMA = typename TiledMMAHelper<
100+
MMA_Atom<decltype(op)>,
101+
Layout<WGTile>,
102+
SGLayout>::TiledMMA;
103+
auto mma = MMA{};
104+
105+
int sm_count =
106+
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
107+
auto MaxThreadsPerWorkgroup = size(mma);
108+
109+
static constexpr int MaxThreadsPerSM = 512;
110+
111+
TORCH_CHECK(
112+
MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0,
113+
"MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup")
114+
115+
sycl::range<3> local(1, 1, MaxThreadsPerWorkgroup);
116+
sycl::range<3> global(
117+
1, sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1);
118+
119+
namespace syclex = sycl::ext::oneapi::experimental;
120+
namespace intelex = sycl::ext::intel::experimental;
121+
122+
syclex::properties kernel_props{
123+
syclex::sub_group_size<16>, intelex::grf_size<256>};
124+
125+
using GmemTiledCopyA = typename policy::GmemTiledCopyA;
126+
using GmemTiledCopyB = typename policy::GmemTiledCopyB;
127+
using GmemTiledCopyD = typename policy::GmemTiledCopyD;
128+
129+
auto event = stream.submit([&](sycl::handler& cgh) {
130+
sycl::local_accessor<int32_t, 1> local_mem(sycl::range<1>(1), cgh);
131+
cgh.parallel_for<GemmCuteName<
132+
ElementA,
133+
ElementB,
134+
ElementS,
135+
ElementD,
136+
layoutA,
137+
layoutB,
138+
policy>>(
139+
sycl::nd_range<3>{global * local, local}, kernel_props, [=](auto) {
140+
MoE::MoEGEMM<
141+
GmemTiledCopyA,
142+
GmemTiledCopyB,
143+
GmemTiledCopyD,
144+
layoutA,
145+
layoutB,
146+
'R'>(
147+
activations,
148+
weights,
149+
scales,
150+
bias,
151+
outputs,
152+
mma,
153+
num_rows_per_expert_device,
154+
num_experts,
155+
group_size,
156+
gemm_n,
157+
gemm_k,
158+
atomic_buffer,
159+
local_mem);
160+
});
161+
});
162+
EventManager::getInstance().addEvent(event);
163+
}
164+
165+
// Grouped (MoE) GEMM entry point dispatching to the CUTLASS Xe kernels.
//
// Per expert e: D[e] = A[e] @ B[e] (+ bias[e]), with B optionally
// dequantized through `ptr_scales` (fp8 per-expert, int4/mxfp4 per-group).
//
// Arguments:
//   ptr_A   - activations, 2D [Total_M, K], contiguous, bf16/fp16.
//   ptr_B   - weights, 3D, contiguous. [num_experts, K, N] for
//             bf16/fp16/fp8; for int4/mxfp4 the two 4-bit values are packed
//             per byte and the dims are read as [num_experts, N, K/2].
//   ptr_scales - optional dequant scales:
//             * fp8:        1D [num_experts]
//             * int4/mxfp4: 3D with size(1) == N and size(2) == K/group_size
//   ptr_bias - optional 2D [num_experts, N], same dtype as ptr_A.
//   ptr_D   - output, 2D [Total_M, N], contiguous, same dtype as ptr_A.
//   num_rows_per_expert_device - per-expert row counts on the device.
//   N, K    - GEMM dimensions shared by all experts.
//   is_B_int4 / is_B_mxfp4 - 4-bit weight-format selectors.
//
// Returns ptr_D (written in place).
at::Tensor cutlass_xe_grouped_gemm(
    at::Tensor& ptr_A,
    at::Tensor& ptr_B,
    const c10::optional<at::Tensor>& ptr_scales,
    const c10::optional<at::Tensor>& ptr_bias,
    at::Tensor& ptr_D,
    at::Tensor& num_rows_per_expert_device,
    int64_t N,
    int64_t K,
    int64_t num_experts,
    bool is_B_int4,
    bool is_B_mxfp4) {
  auto& dpcpp_queue =
      at::xpu::getCurrentXPUStream(ptr_A.device().index()).queue();
  auto A_dtype = ptr_A.dtype();
  auto B_dtype = ptr_B.dtype();
  bool is_weight_fp8 =
      ((B_dtype == at::kFloat8_e4m3fn) || (B_dtype == at::kFloat8_e5m2));

  TORCH_CHECK(N % 32 == 0, "N must be divisible by 32");

  TORCH_CHECK(ptr_A.dim() == 2, "ptr_A must be 2D [Total_M, K]");
  TORCH_CHECK(ptr_B.dim() == 3, "ptr_B must be 3D [num_experts, K, N]");
  TORCH_CHECK(ptr_D.dim() == 2, "ptr_D must be 2D [Total_M, N]");
  if (ptr_bias.has_value()) {
    TORCH_CHECK(ptr_bias->dim() == 2, "ptr_bias must be 2D [num_experts, N]");
  }

  TORCH_CHECK(ptr_A.is_contiguous(), "ptr_A must be contiguous");
  TORCH_CHECK(ptr_B.is_contiguous(), "ptr_B must be contiguous");
  TORCH_CHECK(ptr_D.is_contiguous(), "ptr_D must be contiguous");
  if (ptr_bias.has_value()) {
    TORCH_CHECK(ptr_bias->is_contiguous(), "ptr_bias must be contiguous");
  }

  int A_total_M = ptr_A.size(0);
  int A_K = ptr_A.size(1);

  int B_E = ptr_B.size(0);
  int B_K = ptr_B.size(1);
  int B_N = ptr_B.size(2);
  if (is_B_int4 || is_B_mxfp4) {
    // 4-bit weights pack two values per byte and are stored transposed,
    // so the logical K and N come from the opposite dims.
    B_K = ptr_B.size(2) * 2;
    B_N = ptr_B.size(1);
  }

  int D_total_M = ptr_D.size(0);
  int D_N = ptr_D.size(1);
  int group_size = -1;  // -1: B is not group-quantized (w16a16 / fp8 paths)
  // Average rows per expert; drives the small-M policy selection below.
  int A_avg_M = A_total_M / num_experts;

  TORCH_CHECK(B_E == num_experts, "ptr_B.size(0) must match num_experts");
  TORCH_CHECK(A_total_M == D_total_M, "ptr_A.size(0) must match ptr_D.size(0)");
  TORCH_CHECK(A_K == B_K && B_K == K, "ptr_A.size(1) must match ptr_B.size(1)");
  TORCH_CHECK(B_N == D_N && D_N == N, "ptr_B.size(2) must match ptr_D.size(1)");
  if (ptr_bias.has_value()) {
    TORCH_CHECK(
        ptr_bias->size(0) == num_experts,
        "ptr_bias.size(0) must match num_experts");
    TORCH_CHECK(ptr_bias->size(1) == N, "ptr_bias.size(1) must match N");
  }

  // NOTE(review): at::empty does not zero-initialize; assumes the kernel
  // initializes this counter before use -- confirm against MoE::MoEGEMM.
  at::Tensor atomic_buffer =
      at::empty({static_cast<long>(1)}, ptr_A.options().dtype(at::kInt));

// Common launcher call. Bias and D are cast to ElementA because every
// instantiation below uses the activation dtype for both.
#define MoEGEMMLauncherCallER(                                                \
    LayoutA, LayoutB, Policy, ElementA, ElementB, ElementS)                   \
  MoEGEMMLauncher<LayoutA, LayoutB, Policy>(                                  \
      dpcpp_queue,                                                            \
      reinterpret_cast<ElementA*>(ptr_A.data_ptr()),                          \
      reinterpret_cast<ElementB*>(ptr_B.data_ptr()),                          \
      ptr_scales.has_value()                                                  \
          ? reinterpret_cast<ElementS*>(ptr_scales->data_ptr())               \
          : static_cast<ElementS*>(nullptr),                                  \
      ptr_bias.has_value() ? reinterpret_cast<ElementA*>(ptr_bias->data_ptr()) \
                           : static_cast<ElementA*>(nullptr),                 \
      reinterpret_cast<ElementA*>(ptr_D.data_ptr()),                          \
      N,                                                                      \
      K,                                                                      \
      reinterpret_cast<int*>(num_rows_per_expert_device.data_ptr()),          \
      num_experts,                                                            \
      group_size,                                                             \
      static_cast<int*>(atomic_buffer.data_ptr()));

  if (is_B_int4 || is_B_mxfp4) {
    TORCH_CHECK(
        ptr_scales.has_value(), "w4a16/mxfp4 grouped gemm must have scales");
    TORCH_CHECK(ptr_scales->is_contiguous(), "ptr_scales must be contiguous");
    // The checks below pin the scale layout to [num_experts, N, group_num]
    // with group_num = K / group_size.
    TORCH_CHECK(
        ptr_scales->dim() == 3,
        "ptr_scales of int4 must be 3D [num_experts, N, group_num]");
    TORCH_CHECK(
        ptr_scales->size(0) == num_experts,
        "ptr_scales.size(0) of int4 must match num_experts");
    TORCH_CHECK(
        K % ptr_scales->size(2) == 0,
        "K must be divisible by ptr_scales.size(2) of int4");
    TORCH_CHECK(
        ptr_scales->size(1) == N, "ptr_scales.size(1) of int4 must match N");
    int group_num = ptr_scales->size(2);
    group_size = K / group_num;

    TORCH_CHECK(
        group_size == 32 || group_size == 64 || group_size == 128 ||
            group_size == 256,
        "group_size must be 32, 64, 128 or 256");

// int4 uses activation-typed scales; mxfp4 uses uint8 (E8M0) scales.
#define W4A16LauncherCallER(policy)                                           \
  if (is_B_int4) {                                                            \
    if (A_dtype == at::kBFloat16) {                                           \
      using scalar_t = bfloat16_t;                                            \
      MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, scalar_t);   \
    } else if (A_dtype == at::kHalf) {                                        \
      using scalar_t = half_t;                                                \
      MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, scalar_t);   \
    }                                                                         \
  } else if (is_B_mxfp4) {                                                    \
    if (A_dtype == at::kBFloat16) {                                           \
      using scalar_t = bfloat16_t;                                            \
      MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, uint8_t);    \
    } else if (A_dtype == at::kHalf) {                                        \
      using scalar_t = half_t;                                                \
      MoEGEMMLauncherCallER('R', 'C', policy, scalar_t, uint8_t, uint8_t);    \
    }                                                                         \
  }

    // Smaller work-group tiles for small average M per expert.
    if (A_avg_M <= 32) {
      using policy = w4a16_policy_m_16;
      W4A16LauncherCallER(policy);
    } else if (A_avg_M <= 128) {
      using policy = w4a16_policy_m_32;
      W4A16LauncherCallER(policy);
    } else {
      using policy = w4a16_policy;
      W4A16LauncherCallER(policy);
    }
#undef W4A16LauncherCallER
  } else if (is_weight_fp8) {
    TORCH_CHECK(ptr_scales.has_value(), "w8a16 grouped gemm must have scales");
    TORCH_CHECK(ptr_scales->is_contiguous(), "ptr_scales must be contiguous");
    TORCH_CHECK(
        ptr_scales->dim() == 1, "ptr_scales of fp8 must be 1D [num_experts]");
    TORCH_CHECK(
        ptr_scales->size(0) == num_experts,
        "ptr_scales.size(0) of fp8 must match num_experts");

#define W8A16LauncherCallER(policy)                                           \
  if (B_dtype == at::kFloat8_e4m3fn && A_dtype == at::kHalf) {                \
    using scalar_t = half_t;                                                  \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, float_e4m3_t, scalar_t); \
  } else if (B_dtype == at::kFloat8_e5m2 && A_dtype == at::kHalf) {           \
    using scalar_t = half_t;                                                  \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, float_e5m2_t, scalar_t); \
  } else if (B_dtype == at::kFloat8_e4m3fn && A_dtype == at::kBFloat16) {     \
    using scalar_t = bfloat16_t;                                              \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, float_e4m3_t, scalar_t); \
  } else if (B_dtype == at::kFloat8_e5m2 && A_dtype == at::kBFloat16) {       \
    using scalar_t = bfloat16_t;                                              \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, float_e5m2_t, scalar_t); \
  }

    if (A_avg_M <= 32) {
      using policy = w8a16_policy_m_16;
      W8A16LauncherCallER(policy);
    } else if (A_avg_M <= 128) {
      using policy = w8a16_policy_m_32;
      W8A16LauncherCallER(policy);
    } else {
      using policy = w8a16_policy;
      W8A16LauncherCallER(policy);
    }
#undef W8A16LauncherCallER
  } else {
    TORCH_CHECK(
        !ptr_scales.has_value(), "w16a16 grouped gemm must not have scales");

#define W16A16LauncherCallER(policy)                                          \
  if (A_dtype == at::kBFloat16) {                                             \
    using scalar_t = bfloat16_t;                                              \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, scalar_t, scalar_t);    \
  } else if (A_dtype == at::kHalf) {                                          \
    using scalar_t = half_t;                                                  \
    MoEGEMMLauncherCallER('R', 'R', policy, scalar_t, scalar_t, scalar_t);    \
  }

    if (A_avg_M <= 4) {
      using policy = w16a16_policy_m_16;
      W16A16LauncherCallER(policy);
    } else {
      using policy = w16a16_policy;
      W16A16LauncherCallER(policy);
    }
#undef W16A16LauncherCallER
  }
#undef MoEGEMMLauncherCallER
  return ptr_D;
}
361+
} // namespace MoE

0 commit comments

Comments
 (0)