Skip to content

Commit 1605eaa

Browse files
authored
misc: Remove unused k_smem_offset_w update in MLA kernel (#878)
The variables `ckv_smem_offset_w` and `kpe_smem_offset_w` are never used after being updated in the current loop, and they are redefined and recomputed in the next loop. I think these updates are redundant code.
1 parent 68a0378 commit 1605eaa

File tree

1 file changed

+0
-4
lines changed

1 file changed

+0
-4
lines changed

include/flashinfer/attention/mla_fa2.cuh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ __device__ __forceinline__ void load_kv(
210210
warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
211211
ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
212212
q < kv_bound);
213-
ckv_smem_offset_w = ckv_smem.template advance_offset_by_column<8>(ckv_smem_offset_w, mma_d);
214213
ckv_ptr += 8 * upcast_size<DTypeKV>();
215214
}
216215

@@ -220,7 +219,6 @@ __device__ __forceinline__ void load_kv(
220219
warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
221220
kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
222221
q < kv_bound);
223-
kpe_smem_offset_w = kpe_smem.template advance_offset_by_column<8>(kpe_smem_offset_w, mma_d);
224222
kpe_ptr += 8 * upcast_size<DTypeKV>();
225223
}
226224
}
@@ -245,7 +243,6 @@ __device__ __forceinline__ void load_kv(
245243
8 * mma_d + lane_idx % 8);
246244
ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
247245
q < kv_bound);
248-
ckv_smem_offset_w = ckv_smem.template advance_offset_by_column<8>(ckv_smem_offset_w, mma_d);
249246
ckv_ptr += 8 * upcast_size<DTypeKV>();
250247
}
251248

@@ -256,7 +253,6 @@ __device__ __forceinline__ void load_kv(
256253
8 * mma_d + lane_idx % 8);
257254
kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
258255
q < kv_bound);
259-
kpe_smem_offset_w = kpe_smem.template advance_offset_by_column<8>(kpe_smem_offset_w, mma_d);
260256
kpe_ptr += 8 * upcast_size<DTypeKV>();
261257
}
262258
}

0 commit comments

Comments
 (0)