@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
13231323 uint quad_gid [[quadgroup_index_in_threadgroup]],
13241324 uint quad_lid [[thread_index_in_quadgroup]]) {
13251325 if (batched) {
1326+ int M = x_shape[x_batch_ndims];
13261327 adjust_matrix_offsets<T>(
13271328 x,
13281329 w,
13291330 scales,
13301331 biases,
13311332 y,
1332- out_vec_size,
1333+ out_vec_size * M ,
13331334 x_batch_ndims,
13341335 x_shape,
13351336 x_strides,
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
13741375 uint simd_gid [[simdgroup_index_in_threadgroup]],
13751376 uint simd_lid [[thread_index_in_simdgroup]]) {
13761377 if (batched) {
1378+ int M = x_shape[x_batch_ndims];
13771379 adjust_matrix_offsets<T>(
13781380 x,
13791381 w,
13801382 scales,
13811383 biases,
13821384 y,
1383- out_vec_size,
1385+ out_vec_size * M ,
13841386 x_batch_ndims,
13851387 x_shape,
13861388 x_strides,
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
14251427 uint simd_gid [[simdgroup_index_in_threadgroup]],
14261428 uint simd_lid [[thread_index_in_simdgroup]]) {
14271429 if (batched) {
1430+ int M = x_shape[x_batch_ndims];
14281431 adjust_matrix_offsets<T>(
14291432 x,
14301433 w,
14311434 scales,
14321435 biases,
14331436 y,
1434- out_vec_size,
1437+ out_vec_size * M ,
14351438 x_batch_ndims,
14361439 x_shape,
14371440 x_strides,
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
14761479 uint simd_gid [[simdgroup_index_in_threadgroup]],
14771480 uint simd_lid [[thread_index_in_simdgroup]]) {
14781481 if (batched) {
1482+ int M = x_shape[x_batch_ndims];
14791483 adjust_matrix_offsets<T>(
14801484 x,
14811485 w,
14821486 scales,
14831487 biases,
14841488 y,
1485- out_vec_size,
1489+ out_vec_size * M ,
14861490 x_batch_ndims,
14871491 x_shape,
14881492 x_strides,
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
15271531 uint3 tid [[threadgroup_position_in_grid]],
15281532 uint simd_gid [[simdgroup_index_in_threadgroup]],
15291533 uint simd_lid [[thread_index_in_simdgroup]]) {
1534+ int M = x_shape[x_batch_ndims];
15301535 adjust_matrix_offsets<T>(
15311536 x,
15321537 w,
15331538 scales,
15341539 biases,
15351540 y,
1536- out_vec_size,
1541+ out_vec_size * M ,
15371542 x_batch_ndims,
15381543 x_shape,
15391544 x_strides,
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
17061711 uint3 tid [[threadgroup_position_in_grid]],
17071712 uint simd_gid [[simdgroup_index_in_threadgroup]],
17081713 uint simd_lid [[thread_index_in_simdgroup]]) {
1714+ int M = x_shape[x_batch_ndims];
17091715 adjust_matrix_offsets<T>(
17101716 x,
17111717 w,
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
17141720 lhs_indices,
17151721 rhs_indices,
17161722 y,
1717- out_vec_size,
1723+ out_vec_size * M ,
17181724 batch_ndims,
17191725 batch_shape,
17201726 lhs_strides,
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
17671773 uint3 tid [[threadgroup_position_in_grid]],
17681774 uint simd_gid [[simdgroup_index_in_threadgroup]],
17691775 uint simd_lid [[thread_index_in_simdgroup]]) {
1776+ int M = x_shape[x_batch_ndims];
17701777 adjust_matrix_offsets<T>(
17711778 x,
17721779 w,
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
17751782 lhs_indices,
17761783 rhs_indices,
17771784 y,
1778- out_vec_size,
1785+ out_vec_size * M ,
17791786 batch_ndims,
17801787 batch_shape,
17811788 lhs_strides,
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
18281835 uint3 tid [[threadgroup_position_in_grid]],
18291836 uint simd_gid [[simdgroup_index_in_threadgroup]],
18301837 uint simd_lid [[thread_index_in_simdgroup]]) {
1838+ int M = x_shape[x_batch_ndims];
18311839 adjust_matrix_offsets<T>(
18321840 x,
18331841 w,
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
18361844 lhs_indices,
18371845 rhs_indices,
18381846 y,
1839- out_vec_size,
1847+ out_vec_size * M ,
18401848 batch_ndims,
18411849 batch_shape,
18421850 lhs_strides,
0 commit comments