-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathflash_attention_causal.rs
More file actions
305 lines (273 loc) · 10.7 KB
/
flash_attention_causal.rs
File metadata and controls
305 lines (273 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
extern crate core;
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::{randn_f32, zeros};
use cutile::error::Error;
use cutile::tensor::{IntoPartition, Partition, Tensor, ToHostVec};
use cutile::tile_kernel::TileKernel;
use std::sync::Arc;
#[cutile::module]
mod my_module {
use cutile::core::*;
#[cutile::entry(print_ir=false,
unchecked_accesses=false,
optimization_hints = (
tensor_dim_factor = 16,
sm_120 = (num_cta_in_cga=1,),
))]
fn fmha<
const BM: i32, // Query sequence tile size.
const BN: i32, // KV sequence tile size.
const D: i32, // Head dimension.
const H: i32, // Number of query heads.
const CAUSAL: i32,
const EVEN_K: i32,
>(
q: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, h, m, d)
k: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, hkv, n, d)
v: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, hkv, n, d)
out: &mut Tensor<f32, { [1, BM, D] }>, // (b*h, m, d)
qk_scale: f32,
input_pos: i32,
query_group_size: i32,
) {
let pid: (i32, i32, i32) = get_tile_block_id(); // (b*h, m/BM, 1)
let batch_idx = pid.0 / H;
let q_head_idx = pid.0 % H;
let q_m_idx = pid.1;
let kv_head_idx = q_head_idx / query_group_size;
let two: Tile<f32, { [] }> = constant(2.0f32, const_shape![]);
let log2: f32 = tile_to_scalar(log(two));
let qk_scale: Tile<f32, { [BM, BN] }> =
broadcast_scalar(qk_scale / log2, const_shape![BM, BN]);
let mut m_i: Tile<f32, { [BM, 1] }> = constant(f32::NEG_INFINITY, const_shape![BM, 1]);
let mut l_i: Tile<f32, { [BM, 1] }> = constant(0.0f32, const_shape![BM, 1]);
let mut acc: Tile<f32, { [BM, D] }> = constant(0.0f32, const_shape![BM, D]);
let q_part: Partition<f32, { [1, 1, BM, D] }> = q.partition(const_shape![1, 1, BM, D]);
let tq: Tile<f32, { [BM, D] }> = q_part
.load([batch_idx, q_head_idx, q_m_idx, 0i32])
.reshape(const_shape![BM, D]);
let k_seqlen: i32 = get_shape_dim(k.shape(), 2i32);
let m_end: i32 = input_pos + (q_m_idx + 1i32) * BM;
let mut mask_start: i32 = k_seqlen / BN;
let mut tc: i32 = ceil_div(k_seqlen, BN);
if CAUSAL == 1i32 {
mask_start = (input_pos + q_m_idx * BM) / BN;
let k_seqlen_tiles = k_seqlen / BN;
mask_start = min(mask_start, k_seqlen_tiles);
tc = ceil_div(min(m_end, k_seqlen), BN);
}
let k_part = k.partition(const_shape![1, 1, BN, D]);
let v_part = v.partition(const_shape![1, 1, BN, D]);
let transpose: Array<{ [1, 0] }> = Array::<{ [1, 0] }> {
dims: &[1i32, 0i32],
};
let offs_n_tile: Tile<i32, { [BN] }> = iota(const_shape![BN]);
let offs_n_tile: Tile<i32, { [BM, BN] }> = offs_n_tile
.reshape(const_shape![1, BN])
.broadcast(const_shape![BM, BN]);
let offs_m_iota: Tile<i32, { [BM] }> = iota(const_shape![BM]);
let offs_m_iota = offs_m_iota.reshape(const_shape![BM, 1]);
let offs_m: Tile<i32, { [BM, 1] }> =
broadcast_scalar(q_m_idx * BM + input_pos, const_shape![BM, 1]) + offs_m_iota;
let offs_m: Tile<i32, { [BM, BN] }> = offs_m.broadcast(const_shape![BM, BN]);
let k_seqlen_tile: Tile<i32, { [BM, BN] }> = k_seqlen.broadcast(const_shape![BM, BN]);
let mask_true: Tile<f32, { [BM, BN] }> = constant(0.0f32, const_shape![BM, BN]);
let mask_false: Tile<f32, { [BM, BN] }> = constant(f32::NEG_INFINITY, const_shape![BM, BN]);
for j in 0i32..tc {
let k_tile: Tile<f32, { [BN, D] }> = k_part
.load([batch_idx, kv_head_idx, j, 0i32])
.reshape(const_shape![BN, D]);
let k_tile_trans: Tile<f32, { [D, BN] }> = permute(k_tile, transpose);
let mut qk: Tile<f32, { [BM, BN] }> = constant(0.0f32, const_shape![BM, BN]);
qk = mma(tq, k_tile_trans, qk);
if (CAUSAL == 1i32 || EVEN_K == 0i32) && j >= mask_start {
let offs_n: Tile<i32, { [BM, BN] }> =
broadcast_scalar(j * BN, const_shape![BM, BN]) + offs_n_tile;
let mut mask: Tile<bool, { [BM, BN] }> = constant(true, const_shape![BM, BN]);
if EVEN_K == 0i32 {
let lt_res: Tile<bool, { [BM, BN] }> = lt_tile(offs_n, k_seqlen_tile);
mask = mask & lt_res;
}
if CAUSAL == 1i32 {
let ge_res: Tile<bool, { [BM, BN] }> = ge_tile(offs_m, offs_n);
mask = mask & ge_res;
}
qk = qk + select(mask, mask_true, mask_false);
}
qk = qk * qk_scale;
let qk_max: Tile<f32, { [BM] }> = reduce_max(qk, 1);
let qk_max: Tile<f32, { [BM, 1] }> = qk_max.reshape(const_shape![BM, 1]);
let m_ij: Tile<f32, { [BM, 1] }> = max_tile(m_i, qk_max);
let qk = qk - m_ij.broadcast(const_shape![BM, BN]);
let p: Tile<f32, { [BM, BN] }> = exp2(qk);
let l_ij: Tile<f32, { [BM] }> = reduce_sum(p, 1);
let l_ij: Tile<f32, { [BM, 1] }> = l_ij.reshape(const_shape![BM, 1]);
let alpha: Tile<f32, { [BM, 1] }> = exp2(m_i - m_ij);
l_i = l_i * alpha + l_ij;
acc = acc * alpha.broadcast(const_shape![BM, D]);
let v_tile: Tile<f32, { [BN, D] }> = v_part
.load([batch_idx, kv_head_idx, j, 0i32])
.reshape(const_shape![BN, D]);
acc = mma(p, v_tile, acc);
m_i = m_ij;
}
out.store(
true_div(acc, l_i.broadcast(const_shape![BM, D])).reshape(const_shape![1, BM, D]),
);
}
}
use my_module::fmha_sync;
fn idx4(
a: usize,
b: usize,
c: usize,
d: usize,
_bsz: usize,
hsz: usize,
m: usize,
dsz: usize,
) -> usize {
(((a * hsz + b) * m + c) * dsz) + d
}
fn fmha_ref_cpu(
q: &[f32],
k: &[f32],
v: &[f32],
bsz: usize,
hsz: usize,
hkv: usize,
m: usize,
dsz: usize,
qk_scale: f32,
input_pos: usize,
query_group_size: usize,
causal: bool,
) -> Vec<f32> {
let mut out = vec![0.0f32; bsz * hsz * m * dsz];
let mut scores = vec![0.0f32; m];
for b in 0..bsz {
for h in 0..hsz {
let kv_h = h / query_group_size;
debug_assert!(kv_h < hkv);
for i in 0..m {
let mut max_logit = f32::NEG_INFINITY;
for (j, score) in scores.iter_mut().enumerate().take(m) {
if causal && j > i + input_pos {
*score = f32::NEG_INFINITY;
continue;
}
let mut dot = 0.0f32;
for dd in 0..dsz {
dot += q[idx4(b, h, i, dd, bsz, hsz, m, dsz)]
* k[idx4(b, kv_h, j, dd, bsz, hkv, m, dsz)];
}
*score = dot * qk_scale;
max_logit = max_logit.max(*score);
}
let mut denom = 0.0f32;
for score in scores.iter_mut().take(m) {
*score = (*score - max_logit).exp();
denom += *score;
}
for dd in 0..dsz {
let mut val = 0.0f32;
for (j, score) in scores.iter().enumerate().take(m) {
let p = *score / denom;
val += p * v[idx4(b, kv_h, j, dd, bsz, hkv, m, dsz)];
}
out[idx4(b, h, i, dd, bsz, hsz, m, dsz)] = val;
}
}
}
}
out
}
fn run_attention_fmha(causal: bool) -> Result<(), Error> {
let ctx = CudaContext::new(0)?;
let stream = ctx.new_stream()?;
let (batch, heads, heads_kv, seq_len, head_dim) = (2usize, 8usize, 8usize, 64usize, 32usize);
let (bm, bn) = (32i32, 32i32);
let bbh = 1i32;
let input_pos = if causal { 5i32 } else { 0i32 };
let even_k = seq_len as i32 % bn == 0;
let qk_scale = 1.0 / f32::sqrt(head_dim as f32);
let query_group_size = (heads / heads_kv) as i32;
let q: Arc<Tensor<f32>> = randn_f32(0.0, 1.0, [batch, heads, seq_len, head_dim], Some(7))
.sync_on(&stream)?
.into();
let k: Arc<Tensor<f32>> = randn_f32(0.0, 1.0, [batch, heads_kv, seq_len, head_dim], Some(11))
.sync_on(&stream)?
.into();
let v: Arc<Tensor<f32>> = randn_f32(0.0, 1.0, [batch, heads_kv, seq_len, head_dim], Some(13))
.sync_on(&stream)?
.into();
let out: Partition<Tensor<f32>> = zeros([batch * heads, seq_len, head_dim])
.sync_on(&stream)?
.partition([bbh, bm, head_dim as i32]);
let generics = vec![
bm.to_string(),
bn.to_string(),
head_dim.to_string(),
heads.to_string(),
(causal as i32).to_string(),
(even_k as i32).to_string(),
];
let (_, _, _, out, _, _, _): (_, _, _, Partition<Tensor<f32>>, _, _, _) = fmha_sync(
q.clone(),
k.clone(),
v.clone(),
out,
qk_scale,
input_pos,
query_group_size,
)
.generics(generics)
.sync_on(&stream)?;
let out_host: Vec<f32> = out.unpartition().to_host_vec().sync_on(&stream)?;
let q_host: Vec<f32> = q.to_host_vec().sync_on(&stream)?;
let k_host: Vec<f32> = k.to_host_vec().sync_on(&stream)?;
let v_host: Vec<f32> = v.to_host_vec().sync_on(&stream)?;
let ref_host: Vec<f32> = fmha_ref_cpu(
&q_host,
&k_host,
&v_host,
batch,
heads,
heads_kv,
seq_len,
head_dim,
qk_scale,
input_pos as usize,
query_group_size as usize,
causal,
);
let atol = 2e-3f32;
let rtol = 2e-3f32;
for iterand in out_host.iter().zip(ref_host.iter()).enumerate() {
let (i, (got, exp)): (usize, (&f32, &f32)) = iterand;
let ok = (got - exp).abs() <= atol + rtol * exp.abs();
if !ok {
println!("FMHA causal mismatch at idx={i}: got={got}, exp={exp}");
}
}
println!(
"attention_fmha ({}) check passed: shape=[{}, {}, {}, {}]",
if causal { "causal" } else { "non-causal" },
batch,
heads,
seq_len,
head_dim
);
Ok(())
}
fn main() -> Result<(), Error> {
run_attention_fmha(false)?;
run_attention_fmha(true)?;
Ok(())
}