@@ -37,14 +37,20 @@ struct Alibi {
3737 const int col_idx_offset = col_idx_offset_ + (lane_id % 4 ) * 2 ;
3838 if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
3939 #pragma unroll
40- for (int nj = 0 ; nj < size<1 , 1 >(tensor); ++nj ) {
41- const int col_idx_base = col_idx_offset + nj * 8 ;
40+ for (int mi = 0 ; mi < size<0 , 1 >(tensor); ++mi ) {
41+ const int row_idx_base = row_idx_offset + mi * warp_row_stride ;
4242 #pragma unroll
43- for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
44- const int col_idx = col_idx_base + j;
43+ for (int i = 0 ; i < size<0 , 0 >(tensor); ++i) {
44+ const int row_idx = row_idx_base + i * 8 ;
45+ const int col_idx_limit_right = std::min (max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
4546 #pragma unroll
46- for (int mi = 0 ; mi < size<0 >(tensor); ++mi) {
47- tensor (mi, make_coord (j, nj)) += alibi_slope * col_idx;
47+ for (int nj = 0 ; nj < size<1 , 1 >(tensor); ++nj) {
48+ const int col_idx_base = col_idx_offset + nj * 8 ;
49+ #pragma unroll
50+ for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
51+ const int col_idx = col_idx_base + j;
52+ tensor (make_coord (i, mi), make_coord (j, nj)) += ((col_idx == (col_idx_limit_right - 1 )) ? 0 : alibi_slope);
53+ }
4854 }
4955 }
5056 }
@@ -61,7 +67,7 @@ struct Alibi {
6167 #pragma unroll
6268 for (int j = 0 ; j < size<1 , 0 >(tensor); ++j) {
6369 const int col_idx = col_idx_base + j;
64- tensor (make_coord (i, mi), make_coord (j, nj)) -= alibi_slope * abs ( row_idx + max_seqlen_k - max_seqlen_q - col_idx);
70+ tensor (make_coord (i, mi), make_coord (j, nj)) += ((( row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0 ) ? 0 : alibi_slope );
6571 }
6672 }
6773 }
0 commit comments