-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathflash_attention.rs
More file actions
256 lines (230 loc) · 10.2 KB
/
flash_attention.rs
File metadata and controls
256 lines (230 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
extern crate core;
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::{randn_f32, zeros};
use cutile::error::Error;
use cutile::tensor::{CopyToHost, IntoPartition, Partition, Tensor};
use cutile::tile_kernel::TileKernel;
use std::sync::Arc;
#[cutile::module]
mod my_module {
use cutile::core::*;
#[cutile::entry(print_ir=false,
unchecked_accesses=false,
optimization_hints = (
tensor_dim_factor = 16,
sm_120 = (num_cta_in_cga=1,),
))]
fn fmha<
const BM: i32, // Q sequence length partition size.
const BN: i32, // KV Sequence length partition size.
const D: i32, // Hidden size (weights).
>(
q: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, h, m, d)
k: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, hkv, n, d) where n == m
v: &Tensor<f32, { [-1, -1, -1, -1] }>, // (b, hkv, n, d) where n == m
out: &mut Tensor<f32, { [1, BM, D] }>, // (b*h, m, d)
qk_scale: f32,
query_group_size: i32,
) {
let pid: (i32, i32, i32) = get_tile_block_id(); // (b*h, m/bm, 1)
let h = get_shape_dim(q.shape(), 1i32);
let batch_idx = pid.0 / h; // \in [0, b)
let q_head_idx = pid.0 % h; // \in [0, h)
let q_m_idx = pid.1; // \in [0, m/bm)
let kv_head_idx = q_head_idx / query_group_size;
// This lets us use exp2 vs exp.
let two: Tile<f32, { [] }> = constant(2.0f32, const_shape![]);
let log2: f32 = tile_to_scalar(log(two));
let qk_scale: f32 = qk_scale / log2;
let qk_scale: Tile<f32, { [BM, BN] }> = qk_scale.broadcast(const_shape![BM, BN]);
// mask us needed for causal only.
// let mask_true: Tile<f32, {[BM, BN]}> = constant(0.0f32, const_shape![BM, BN]);
// let mask_false: Tile<f32, {[BM, BN]}> = constant(f32::NEG_INFINITY, const_shape![BM, BN]);
// offset is needed for causal only.
// let offs_n_tile: Tile<i32, {[BN]}> = iota(const_shape![BN]);
// let offs_n_tile: Tile<i32, {[BM, BN]}> = offs_n_tile.reshape(const_shape![1, BN])
// .broadcast(const_shape![BM, BN]);
// let offs_m: i32 = q_m_idx * BM;
// let offs_m: Tile<i32, {[BM]}> = offs_m.broadcast(const_shape![BM]);
// let m_arange: Tile<i32, {[BM]}> = iota(const_shape![BM]);
// let offs_m: Tile<i32, {[BM]}> = offs_m + m_arange;
// let offs_m: Tile<i32, {[BM, BN]}> = offs_m
// .reshape(const_shape![BM, 1])
// .broadcast(const_shape![BM, BN]);
// m and l are for softmax.
let mut m_i: Tile<f32, { [BM, 1] }> = constant(f32::NEG_INFINITY, const_shape![BM, 1]);
let mut l_i: Tile<f32, { [BM, 1] }> = constant(0.0f32, const_shape![BM, 1]);
// This is the output tile.
let mut acc: Tile<f32, { [BM, D] }> = constant(0.0f32, const_shape![BM, D]);
// We load just one query block per process.
let q_part: Partition<f32, { [1, 1, BM, D] }> = q.partition(const_shape![1, 1, BM, D]);
let tq: Tile<f32, { [1, 1, BM, D] }> = q_part.load([batch_idx, q_head_idx, q_m_idx, 0i32]);
let tq: Tile<f32, { [BM, D] }> = tq.reshape(const_shape![BM, D]);
let n: i32 = get_shape_dim(k.shape(), 2i32);
let num_tiles: i32 = ceil_div(n, BN);
// let mask_start: i32 = n / BN;
let k_part = k.partition(const_shape![1, 1, BN, D]); // permuted after loading.
let v_part = v.partition(const_shape![1, 1, BN, D]);
let transpose: Array<{ [1, 0] }> = Array::<{ [1, 0] }> {
dims: &[1i32, 0i32],
};
// j corresponds to tile index along key / value seq len dim.
for j in 0i32..num_tiles {
// cuda_tile_print!("batch_idx={}, kv_head_idx={}, q_m_idx={}, j={}\n",
// batch_idx, kv_head_idx, q_m_idx, j);
// Compute q @ k^T.
let k_tile: Tile<f32, { [BN, D] }> = k_part
.load([batch_idx, kv_head_idx, j, 0i32])
.reshape(const_shape![BN, D]);
let k_tile_trans: Tile<f32, { [D, BN] }> = permute(k_tile, transpose);
let qk: Tile<f32, { [BM, BN] }> = constant(0.0f32, const_shape![BM, BN]);
let qk: Tile<f32, { [BM, BN] }> = mma(tq, k_tile_trans, qk);
// Apply mask(q @ k^T).
// if j >= mask_start {
// let mask: Tile<bool, {[BM, BN]}> = constant(true, const_shape![BM, BN]);
// let offs_n: i32 = j * BN;
// let offs_n: Tile<i32, {[BM, BN]}> = offs_n.broadcast(const_shape![BM, BN]);
// let offs_n: Tile<i32, {[BM, BN]}> = offs_n + offs_n_tile;
// let mask: Tile<bool, {[BM, BN]}> = mask & ge_tile(offs_m, offs_n); // Causal only.
// let mask: Tile<f32, {[BM, BN]}> = select(mask, mask_true, mask_false);
// qk = qk + mask;
// }
// Apply scale(mask(q @ k^T)).
let qk: Tile<f32, { [BM, BN] }> = qk * qk_scale;
// Recenter before softmax.
let qk_max: Tile<f32, { [BM] }> = reduce_max(qk, 1);
let qk_max: Tile<f32, { [BM, 1] }> = qk_max.reshape(const_shape![BM, 1]);
let m_ij: Tile<f32, { [BM, 1] }> = max_tile(m_i, qk_max);
let qk = qk - m_ij.broadcast(const_shape![BM, BN]);
// Apply softmax(mask(scale(q @ k^T))).
let p: Tile<f32, { [BM, BN] }> = exp2(qk);
let l_ij: Tile<f32, { [BM] }> = reduce_sum(p, 1);
let l_ij: Tile<f32, { [BM, 1] }> = l_ij.reshape(const_shape![BM, 1]);
let alpha: Tile<f32, { [BM, 1] }> = exp2(m_i - m_ij);
l_i = l_i * alpha + l_ij;
let alpha: Tile<f32, { [BM, D] }> = alpha.broadcast(const_shape![BM, D]);
acc = acc * alpha;
// Compute softmax(mask(scale(q @ k^T))) @ v.
let v_tile: Tile<f32, { [1, 1, BN, D] }> =
v_part.load([batch_idx, kv_head_idx, j, 0i32]);
let v_tile: Tile<f32, { [BN, D] }> = v_tile.reshape(const_shape![BN, D]);
acc = mma(p, v_tile, acc);
m_i = m_ij;
}
acc = true_div(acc, l_i.broadcast(const_shape![BM, D]));
let acc = acc.reshape(const_shape![1, BM, D]);
out.store(acc);
}
}
use cutile::candle_core;
use cutile_examples::fmha_ref_exec;
use my_module::fmha_sync;
fn fmha(
b: usize, // batch size.
h: usize, // number of heads (query).
hkv: usize, // number of heads (key/value).
m: usize, // sequence length.
d: usize, // hidden size.
bm: i32, // q seq len part size.
bn: i32, // k, v seq len part size.
bbh: i32, // batch * num_heads part size.
) -> Result<(), Error> {
// Create a context. Device 0 is associated with the context.
let ctx = CudaContext::new(0)?;
// Create a new stream on which we run CUDA operations.
let stream = ctx.new_stream()?;
let seed = 123;
let q: Arc<Tensor<f32>> = randn_f32(0f32, 1., [b, h, m, d], Some(seed))
.sync_on(&stream)?
.into();
let k: Arc<Tensor<f32>> = randn_f32(0f32, 1., [b, hkv, m, d], Some(seed))
.sync_on(&stream)?
.into();
let v: Arc<Tensor<f32>> = randn_f32(0f32, 1., [b, hkv, m, d], Some(seed))
.sync_on(&stream)?
.into();
// launch grid = (b*h, m/bm, 1)
let out: Partition<Tensor<f32>> = zeros([b * h, m, d])
.sync_on(&stream)?
.partition([bbh, bm, d as i32]);
assert_eq!(out.grid()?, ((b * h) as u32, (m / bm as usize) as u32, 1));
let qk_scale = 1.0 / f32::sqrt(q.shape[3] as f32);
// This is always 1.
let num_heads = q.shape[1];
let kv_num_heads = k.shape[1];
assert_eq!(num_heads % kv_num_heads, 0);
let query_group_size = num_heads / kv_num_heads;
let generics = vec![bm.to_string(), bn.to_string(), d.to_string()];
let (_, _, _, out, _, _) = fmha_sync(
q.clone(),
k.clone(),
v.clone(),
out,
qk_scale,
query_group_size,
)
.generics(generics)
.sync_on(&stream)?;
let out_host: candle_core::Tensor = out.unpartition().copy_to_host().sync_on(&stream)?;
let answer_host = fmha_ref_exec(&q, &k, &v, qk_scale)
.reshape(((), m, d))
.expect("Failed to reshape.");
println!("out_host.shape() = {:?}", out_host.shape());
for i in 0..(b * h) {
let answer_mat = answer_host
.get_on_dim(0, i)
.expect("Failed to get {i} on dim 0.");
let out_mat = out_host
.get_on_dim(0, i)
.expect("Failed to get {i} on dim 0.");
let near_zero = (&answer_mat - &out_mat)
.unwrap()
.abs()
.unwrap()
.reshape((m * d,))
.unwrap();
let vec = near_zero.to_vec1::<f32>().unwrap();
let check = vec.iter().all(|x| x.abs() <= 1e-4);
// Looking at out_host[i, 0, :] (1d slice of last dim d, the predicted token at position 0).
let sample_dim = 0;
let sample_idx = 0;
let out_vec = out_mat
.get_on_dim(sample_dim, sample_idx)
.unwrap()
.to_vec1::<f32>()
.unwrap();
let answer_vec = answer_mat
.get_on_dim(sample_dim, sample_idx)
.unwrap()
.to_vec1::<f32>()
.unwrap();
println!("diff near zero? {check}: {:?}", vec[0]);
assert!(
check,
"output check failed. \noutput={:?} \nanswer={:?}",
out_vec, answer_vec
);
}
Ok(())
}
const BATCH: usize = 4;
const N_HEADS: usize = 32;
const HEAD_DIM: usize = 64; // or 128
const N_CTX: usize = 1024; // or some multiple of 1024
fn main() -> Result<(), Error> {
let b = BATCH; // = batch size.
let h = N_HEADS; // = number of heads (query).
let hkv = N_HEADS; // = number of heads (key/value).
let m = N_CTX; // = sequence length.
let d = HEAD_DIM; // = hidden size.
let (bm, bn) = (128, 64); // (q seq len part size, k and v seq len part size)
let bbh = 1; // batch * num_heads part size.
fmha(b, h, hkv, m, d, bm, bn, bbh)?;
Ok(())
}