File tree Expand file tree Collapse file tree 3 files changed +18
-6
lines changed
mlx/backend/metal/kernels/steel/gemm/kernels Expand file tree Collapse file tree 3 files changed +18
-6
lines changed Original file line number Diff line number Diff line change @@ -157,10 +157,14 @@ template <
157157 const short tm = SM * (simd_group_id / WN);
158158 const short tn = SN * (simd_group_id % WN);
159159
160- const short sgp_sm = align_M ? SM : min (SM, short (params->M - (c_row + tm)));
160+ const int sgp_sm_int =
161+ align_M ? int (SM) : min (int (SM), params->M - (c_row + tm));
162+ const short sgp_sm = short (sgp_sm_int);
161163 const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
162164
163- const short sgp_sn = align_N ? SN : min (SN, short (params->N - (c_col + tn)));
165+ const int sgp_sn_int =
166+ align_N ? int (SN) : min (int (SN), params->N - (c_col + tn));
167+ const short sgp_sn = short (sgp_sn_int);
164168 const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
165169
166170 A += transpose_a ? tm : (tm * params->lda );
Original file line number Diff line number Diff line change @@ -53,10 +53,14 @@ gather_mm_rhs_nax(
5353 const short tm = SM * (simd_group_id / WN);
5454 const short tn = SN * (simd_group_id % WN);
5555
56- const short sgp_sm = align_M ? SM : min (SM, short (params->M - (c_row + tm)));
56+ const int sgp_sm_int =
57+ align_M ? int (SM) : min (int (SM), params->M - (c_row + tm));
58+ const short sgp_sm = short (sgp_sm_int);
5759 const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
5860
59- const short sgp_sn = align_N ? SN : min (SN, short (params->N - (c_col + tn)));
61+ const int sgp_sn_int =
62+ align_N ? int (SN) : min (int (SN), params->N - (c_col + tn));
63+ const short sgp_sn = short (sgp_sn_int);
6064 const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
6165
6266 A += transpose_a ? tm : (tm * params->lda );
Original file line number Diff line number Diff line change @@ -86,10 +86,14 @@ template <
8686 const short tm = SM * (simd_group_id / WN);
8787 const short tn = SN * (simd_group_id % WN);
8888
89- const short sgp_sm = align_M ? SM : min (SM, short (params->M - (c_row + tm)));
89+ const int sgp_sm_int =
90+ align_M ? int (SM) : min (int (SM), params->M - (c_row + tm));
91+ const short sgp_sm = short (sgp_sm_int);
9092 const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
9193
92- const short sgp_sn = align_N ? SN : min (SN, short (params->N - (c_col + tn)));
94+ const int sgp_sn_int =
95+ align_N ? int (SN) : min (int (SN), params->N - (c_col + tn));
96+ const short sgp_sn = short (sgp_sn_int);
9397 const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
9498
9599 A += transpose_a ? tm : (tm * params->lda );
You can’t perform that action at this time.
0 commit comments