Skip to content

Commit adcbb91

Browse files
authored
Fix for NAX overflow. (#3092)
1 parent b56782b commit adcbb91

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -157,10 +157,14 @@ template <
157157
const short tm = SM * (simd_group_id / WN);
158158
const short tn = SN * (simd_group_id % WN);
159159

160-
const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
160+
const int sgp_sm_int =
161+
align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
162+
const short sgp_sm = short(sgp_sm_int);
161163
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
162164

163-
const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
165+
const int sgp_sn_int =
166+
align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
167+
const short sgp_sn = short(sgp_sn_int);
164168
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
165169

166170
A += transpose_a ? tm : (tm * params->lda);

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,10 +53,14 @@ gather_mm_rhs_nax(
5353
const short tm = SM * (simd_group_id / WN);
5454
const short tn = SN * (simd_group_id % WN);
5555

56-
const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
56+
const int sgp_sm_int =
57+
align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
58+
const short sgp_sm = short(sgp_sm_int);
5759
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
5860

59-
const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
61+
const int sgp_sn_int =
62+
align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
63+
const short sgp_sn = short(sgp_sn_int);
6064
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
6165

6266
A += transpose_a ? tm : (tm * params->lda);

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,10 +86,14 @@ template <
8686
const short tm = SM * (simd_group_id / WN);
8787
const short tn = SN * (simd_group_id % WN);
8888

89-
const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
89+
const int sgp_sm_int =
90+
align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
91+
const short sgp_sm = short(sgp_sm_int);
9092
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
9193

92-
const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
94+
const int sgp_sn_int =
95+
align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
96+
const short sgp_sn = short(sgp_sn_int);
9397
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
9498

9599
A += transpose_a ? tm : (tm * params->lda);

0 commit comments

Comments (0)