Skip to content

Commit 72be44f

Browse files
authored
sycl : fix reorder function; add fp32/fp16 in build script (#24578)
1 parent 8872ab5 commit 72be44f

3 files changed

Lines changed: 123 additions & 78 deletions

File tree

examples/sycl/build.sh

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,45 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: MIT
55

6+
# Print the command-line usage of build.sh (accepted precision values and --help) to stdout.
print_usage() {
7+
echo "Usage: ./build.sh [fp32|fp16] [--help]"
8+
echo ""
9+
echo "Options:"
10+
echo " fp32 Build with FP32 precision (default)"
11+
echo " fp16 Build with FP16 precision (faster for long-prompt inference)"
12+
echo " --help Print this help message"
13+
}
14+
15+
PRECISION=fp32
16+
17+
# Parse command-line arguments: --help prints usage and exits 0, fp32/fp16 select
# the build precision (the last one given wins), anything else is a fatal error.
for arg in "$@"; do
18+
case "$arg" in
19+
--help)
20+
print_usage
21+
exit 0
22+
;;
23+
fp32|fp16)
24+
PRECISION="$arg"
25+
;;
26+
*)
27+
# Diagnostics go to stderr so they are not mixed into captured stdout.
echo "Error: unknown option '$arg'" >&2
28+
print_usage >&2
29+
exit 1
30+
;;
31+
esac
32+
done
33+
634
mkdir -p build
735
cd build
836
source /opt/intel/oneapi/setvars.sh
937

10-
#for FP16
11-
#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_OPENSSL=OFF # faster for long-prompt inference
12-
13-
#for FP32
14-
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_OPENSSL=OFF
38+
# Configure the build with cmake; the two branches differ only in -DGGML_SYCL_F16=ON.
if [ "$PRECISION" = "fp16" ]; then
39+
# FP16 build: passes -DGGML_SYCL_F16=ON
40+
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_OPENSSL=OFF # faster for long-prompt inference
41+
else
42+
# FP32 build (default): GGML_SYCL_F16 is not passed
43+
cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_OPENSSL=OFF
44+
fi
1545

1646
#build example/main
1747
#cmake --build . --config Release --target main

examples/sycl/win-build-sycl.bat

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,23 @@
33
:: Copyright (C) 2024 Intel Corporation
44
:: SPDX-License-Identifier: MIT
55

6+
:: Print usage and exit successfully when the first argument is --help.
:: %~1 strips surrounding quotes so a quoted "--help" is recognised too.
IF /I "%~1"=="--help" (
7+
echo Usage: win-build-sycl.bat [fp32^|fp16] [--help]
8+
echo.
9+
echo Options:
10+
echo fp32 Build with FP32 precision ^(default^)
11+
echo fp16 Build with FP16 precision ^(faster for long-prompt inference^)
12+
echo --help Print this help message
13+
exit /B 0
14+
)
15+
16+
:: Read the precision from the first argument; %~1 strips surrounding quotes so
:: a quoted "fp16" is accepted, and the quoted SET form avoids stray trailing spaces.
SET "PRECISION=%~1"
17+
IF "%PRECISION%"=="" SET "PRECISION=fp32"
18+
IF /I NOT "%PRECISION%"=="fp32" IF /I NOT "%PRECISION%"=="fp16" (
19+
>&2 echo Error: invalid value '%PRECISION%'. Use 'fp32' or 'fp16'.
20+
>&2 echo Usage: win-build-sycl.bat [fp32^|fp16] [--help]
21+
exit /B 1
22+
)
623

724
IF not exist build (mkdir build)
825
cd build
@@ -11,12 +28,14 @@ if %errorlevel% neq 0 goto ERROR
1128
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
1229
if %errorlevel% neq 0 goto ERROR
1330

14-
:: for FP16
15-
:: faster for long-prompt inference
16-
:: cmake -G "MinGW Makefiles" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DGGML_SYCL_F16=ON
17-
18-
:: for FP32
19-
cmake -G "Ninja" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
31+
IF /I "%PRECISION%"=="fp16" (
32+
REM for FP16
33+
REM faster for long-prompt inference
REM NOTE review: this branch uses the MinGW Makefiles generator and sets no C compiler, unlike the FP32 branch - confirm this is intended
34+
cmake -G "MinGW Makefiles" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DGGML_SYCL_F16=ON
35+
) ELSE (
36+
REM for FP32
37+
cmake -G "Ninja" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
38+
)
2039
if %errorlevel% neq 0 goto ERROR
2140

2241
:: build all binary

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 63 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -662,13 +662,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
662662
GGML_ASSERT(ncols % QK4_0 == 0);
663663
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
664664
constexpr size_t num_subgroups = WARP_SIZE;
665-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
666-
667-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
668-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
665+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
666+
const sycl::range<3> block_nums(1, 1, block_num_y);
667+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
669668

670669
stream->submit([&](sycl::handler & cgh) {
671-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
670+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
672671
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
673672
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
674673
nd_item);
@@ -683,13 +682,13 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols(
683682
const int stride_col_y_bytes, const int stride_col_dst,
684683
dpct::queue_ptr stream) {
685684
GGML_ASSERT(ncols % QK4_0 == 0);
686-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
687-
constexpr size_t num_subgroups = 16;
688-
GGML_ASSERT(block_num_y % num_subgroups == 0);
689-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
690-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
685+
constexpr size_t num_subgroups = WARP_SIZE;
686+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
687+
const sycl::range<3> block_nums(1, 1, block_num_y);
688+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
689+
691690
stream->submit([&](sycl::handler & cgh) {
692-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
691+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
693692
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
694693
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>(
695694
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
@@ -1080,13 +1079,12 @@ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy,
10801079
GGML_ASSERT(ncols % QK8_0 == 0);
10811080
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
10821081
constexpr size_t num_subgroups = WARP_SIZE;
1083-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1084-
1085-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
1086-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1082+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1083+
const sycl::range<3> block_nums(1, 1, block_num_y);
1084+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
10871085

10881086
stream->submit([&](sycl::handler & cgh) {
1089-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1087+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
10901088
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
10911089
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
10921090
nd_item);
@@ -1101,13 +1099,13 @@ static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols(
11011099
const int stride_col_y_bytes, const int stride_col_dst,
11021100
dpct::queue_ptr stream) {
11031101
GGML_ASSERT(ncols % QK8_0 == 0);
1104-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1105-
constexpr size_t num_subgroups = 16;
1106-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1107-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1108-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1102+
constexpr size_t num_subgroups = WARP_SIZE;
1103+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1104+
const sycl::range<3> block_nums(1, 1, block_num_y);
1105+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1106+
11091107
stream->submit([&](sycl::handler & cgh) {
1110-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1108+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
11111109
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
11121110
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>(
11131111
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
@@ -1289,13 +1287,12 @@ static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy,
12891287

12901288
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
12911289
constexpr size_t num_subgroups = WARP_SIZE;
1292-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1293-
1294-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1295-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1290+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1291+
const sycl::range<3> block_nums(1, 1, block_num_y);
1292+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
12961293

12971294
stream->submit([&](sycl::handler & cgh) {
1298-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1295+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
12991296
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
13001297
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows,
13011298
nd_item);
@@ -1310,13 +1307,13 @@ static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols(
13101307
const int stride_col_y_bytes, const int stride_col_dst,
13111308
dpct::queue_ptr stream) {
13121309
GGML_ASSERT(ncols % QK_K == 0);
1313-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1314-
constexpr size_t num_subgroups = 16;
1315-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1316-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1317-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1310+
constexpr size_t num_subgroups = WARP_SIZE;
1311+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1312+
const sycl::range<3> block_nums(1, 1, block_num_y);
1313+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1314+
13181315
stream->submit([&](sycl::handler & cgh) {
1319-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1316+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
13201317
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
13211318
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>(
13221319
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
@@ -1457,13 +1454,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
14571454

14581455
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
14591456
constexpr size_t num_subgroups = WARP_SIZE;
1460-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1461-
1462-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1463-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1457+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1458+
const sycl::range<3> block_nums(1, 1, block_num_y);
1459+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
14641460

14651461
stream->submit([&](sycl::handler & cgh) {
1466-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1462+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
14671463
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
14681464
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
14691465
nrows, nd_item);
@@ -1478,13 +1474,14 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols(
14781474
const int stride_col_y_bytes, const int stride_col_dst,
14791475
dpct::queue_ptr stream) {
14801476
GGML_ASSERT(ncols % QK_K == 0);
1481-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1482-
constexpr size_t num_subgroups = 16;
1483-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1484-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1485-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1477+
1478+
constexpr size_t num_subgroups = WARP_SIZE;
1479+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1480+
const sycl::range<3> block_nums(1, 1, block_num_y);
1481+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1482+
14861483
stream->submit([&](sycl::handler & cgh) {
1487-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1484+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
14881485
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
14891486
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>(
14901487
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
@@ -1583,15 +1580,13 @@ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy,
15831580
const int nrows, dpct::queue_ptr stream) {
15841581
GGML_ASSERT(ncols % QK_K == 0);
15851582

1586-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1587-
constexpr size_t num_subgroups = 16;
1588-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1589-
1590-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1591-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1583+
constexpr size_t num_subgroups = WARP_SIZE;
1584+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1585+
const sycl::range<3> block_nums(1, 1, block_num_y);
1586+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
15921587

15931588
stream->submit([&](sycl::handler & cgh) {
1594-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1589+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
15951590
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
15961591
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols,
15971592
nrows, nd_item);
@@ -1606,13 +1601,14 @@ static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols(
16061601
const int stride_col_y_bytes, const int stride_col_dst,
16071602
dpct::queue_ptr stream) {
16081603
GGML_ASSERT(ncols % QK_K == 0);
1609-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1610-
constexpr size_t num_subgroups = 16;
1611-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1612-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1613-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1604+
1605+
constexpr size_t num_subgroups = WARP_SIZE;
1606+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1607+
const sycl::range<3> block_nums(1, 1, block_num_y);
1608+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1609+
16141610
stream->submit([&](sycl::handler & cgh) {
1615-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1611+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
16161612
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
16171613
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>(
16181614
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
@@ -1643,13 +1639,13 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
16431639
GGML_ASSERT(ncols % QK_K == 0);
16441640
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
16451641
constexpr size_t num_subgroups = WARP_SIZE;
1646-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1642+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1643+
const sycl::range<3> block_nums(1, 1, block_num_y);
1644+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
16471645

1648-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1649-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
16501646

16511647
stream->submit([&](sycl::handler & cgh) {
1652-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1648+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
16531649
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
16541650
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
16551651
nd_item);
@@ -1664,13 +1660,13 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols(
16641660
const int stride_col_y_bytes, const int stride_col_dst,
16651661
dpct::queue_ptr stream) {
16661662
GGML_ASSERT(ncols % QK_K == 0);
1667-
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1668-
constexpr size_t num_subgroups = 16;
1669-
GGML_ASSERT(block_num_y % num_subgroups == 0);
1670-
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1671-
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1663+
constexpr size_t num_subgroups = WARP_SIZE;
1664+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups);
1665+
const sycl::range<3> block_nums(1, 1, block_num_y);
1666+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1667+
16721668
stream->submit([&](sycl::handler & cgh) {
1673-
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1669+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
16741670
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
16751671
mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>(
16761672
vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);

0 commit comments

Comments
 (0)