-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtensor_permute.rs
More file actions
153 lines (141 loc) · 5.43 KB
/
tensor_permute.rs
File metadata and controls
153 lines (141 loc) · 5.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
extern crate core;
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::{arange, zeros, DeviceOperationReshape};
use cutile::error::Error;
use cutile::tensor::{CopyToHost, IntoPartition, Partition, Tensor};
use cutile::tile_kernel::TileKernel;
use std::sync::Arc;
use cutile::candle_core;
use cutile::utils::pretty_print_matrix;
use my_module::tensor_permute_sync;
// Device code: cutile compiles the items in this module into tile kernels.
// NOTE(review): only `//` comments are used inside this module so that nothing
// extra (such as `#[doc]` attributes produced by `///`) is handed to the cutile macros.
#[cutile::module]
mod my_module {
    use cutile::core::*;

    // Copies one tile of the rank-4 `src` into the rank-3 `dst`. It reads `src`
    // through a partition view whose dims are reordered by `DIM_MAP`, then
    // collapses the two leading tile dims (`BB`, `BH`) into one (`BBH`).
    //
    // Generic parameters (`main` supplies them, in this order, via `.generics`):
    // - `T`: element type.
    // - `BBH`: leading extent of the `dst` tile; `main` passes `BB * BH`.
    // - `BB`, `BH`, `BM`, `BD`: tile extents for the dims indexed below as
    //   b, h, m, d (batch, heads, sequence, hidden size per `main`'s comments).
    // - `DIM_MAP`: permutation handed to `partition_permuted`. Presumably permuted
    //   dim `i` is `src` dim `DIM_MAP[i]`, matching the host-side `permute` that
    //   `main` checks against — confirm against the cutile docs.
    //
    // Launch grid: `main` launches with `dst.grid()`, where `dst` is partitioned
    // into `[BBH, BD, BM]` tiles.
    //
    // Safety: the `unsafe` contract is not stated in this file.
    // NOTE(review): presumably callers must pass generics that match the tile
    // shape `dst` was partitioned with — confirm.
    //
    // NOTE(review): two assumptions are baked in and not checked here:
    // - `h` is read from `src` dim 1 (un-permuted). That is only the right
    //   extent for the second permuted dim when `DIM_MAP[1] == 1`.
    // - A `[BB, BH]` block of (b, h) pairs maps onto `BBH` contiguous `dst` rows
    //   only when `BB == 1` or `BH == h`. Other choices would appear to write
    //   rows to the wrong place.
    #[cutile::entry(
        print_ir=true,
        unchecked_accesses=false,
        optimization_hints = (tensor_dim_factor = 16,)
    )]
    unsafe fn tensor_permute<
        T: ElementType,
        const BBH: i32,
        const BB: i32,
        const BH: i32,
        const BM: i32,
        const BD: i32,
        const DIM_MAP: [i32; 4],
    >(
        src: &Tensor<T, { [-1, -1, -1, -1] }>,
        dst: &mut Tensor<T, { [BBH, BD, BM] }>,
    ) {
        // Block id = (dst row-block, d-block, m-block), i.e. indices into a grid of
        // (b/BB * h/BH, d/BD, m/BM), in the same order as the `[BBH, BD, BM]`
        // dst tile (see the `d_idx`/`m_idx` assignments below).
        let pid: (i32, i32, i32) = get_tile_block_id();
        // Tile dimensions BB and BH are collapsed into a single grid dimension.
        // The per-dimension partition indices are recovered from it as follows.
        let h = get_shape_dim(src.shape(), 1i32);
        let b_idx = pid.0 / (h / BH); // in [0, b/BB)
        let h_idx = pid.0 % (h / BH); // in [0, h/BH)
        let d_idx = pid.1; // in [0, d/BD)
        let m_idx = pid.2; // in [0, m/BM)
        // Uncomment for debugging, but choose smaller shapes (smaller launch grid).
        // cuda_tile_print!("b_idx={}, h_idx={}, m_idx={}, d_idx={}\n", b_idx, h_idx, m_idx, d_idx);
        // cuda_tile_print!("BB={}, BH={}, BM={}, BD={}\n", BB, BH, BM, BD);
        // NOTE(review): `b`, `m`, `d` are not bound in this kernel (only `h` is).
        // They would need to be read with `get_shape_dim` before the next line compiles.
        // cuda_tile_print!("b={}, h={}, m={}, d={}\n", b, h, m, d);
        // `dim_map` is the permutation of `src`'s dims given by `DIM_MAP`.
        let dim_map = const_array!(DIM_MAP);
        // The tile shape passed to `partition_permuted` is in *permuted* dim order.
        let src_part: Partition<T, { [BB, BH, BD, BM] }> =
            src.partition_permuted(const_shape![BB, BH, BD, BM], dim_map);
        // Load as if the partition were laid out according to `dim_map`. With
        // `main`'s DIM_MAP = [0, 1, 3, 2], the last two dims (m, d) are swapped.
        let src_tile: Tile<T, { [BB, BH, BD, BM] }> = src_part.load([b_idx, h_idx, d_idx, m_idx]);
        // The loaded tile is already in permuted order. Collapse BB x BH into BBH
        // to match the `dst` tile shape.
        let src_tile = src_tile.reshape(const_shape![BBH, BD, BM]);
        // Further fused operations on `src_tile` would go here. The tile is stored
        // to `dst` unchanged so that `main` can check the permutation.
        dst.store(src_tile);
    }
}
/// Batch size (`b`) of the example input tensor.
const BATCH: usize = 4;
/// Number of (query) heads (`h`).
const N_HEADS: usize = 32;
/// Sequence length (`m`).
const N_CTX: usize = 1_024;
/// Hidden size per head (`d`).
const HEAD_DIM: usize = 64;
/// Runs `tensor_permute` on an `arange`-filled `[b, h, m, d]` tensor and checks
/// the device result against a host-side candle `permute` + `reshape`.
///
/// # Errors
/// Propagates any [`Error`] from CUDA context/stream creation, tensor creation,
/// the kernel launch, or device-to-host copies.
///
/// # Panics
/// Panics if the tile shape / permutation violates the kernel's assumptions, or
/// if any output slice differs from the reference.
fn main() -> Result<(), Error> {
    let ctx = CudaContext::new(0)?;
    let stream = ctx.new_stream()?;
    let b = BATCH; // = batch size.
    let h = N_HEADS; // = number of heads (query).
    let m = N_CTX; // = sequence length.
    let d = HEAD_DIM; // = hidden size.
    // Tile shape over the un-permuted source dims [b, h, m, d].
    let partition = [1, 16, 128, 32];
    // Swap the last two dims: [b, h, m, d] -> [b, h, d, m].
    let dim_map = [0, 1, 3, 2];

    // The kernel divides every source dim by its tile extent, so it must divide exactly.
    assert!(
        [b, h, m, d]
            .iter()
            .zip(&partition)
            .all(|(dim, tile)| dim % tile == 0),
        "tile shape {partition:?} must evenly divide input shape {:?}",
        [b, h, m, d]
    );
    // The kernel reads `h` from src dim 1 and stores BB*BH rows as one contiguous
    // dst block. That only lines up when the leading two dims stay in place and the
    // collapsed (b, h) rows are contiguous (BB == 1 or BH == h).
    assert!(
        dim_map[0] == 0 && dim_map[1] == 1 && (partition[0] == 1 || partition[1] == h),
        "unsupported partition {partition:?} / dim_map {dim_map:?}"
    );

    let bbh = partition[dim_map[0]] * partition[dim_map[1]];
    let partition_shape_rank3 = [
        bbh as i32,
        partition[dim_map[2]] as i32,
        partition[dim_map[3]] as i32,
    ];
    let src: Arc<Tensor<f32>> = arange(b * h * m * d)
        .reshape([b, h, m, d])
        .sync_on(&stream)?
        .into();
    let dst: Partition<Tensor<f32>> = zeros([b * h, d, m])
        .sync_on(&stream)?
        .partition(partition_shape_rank3);
    // Kernel generics in declaration order: T, BBH, BB, BH, BM, BD, DIM_MAP.
    let generics: Vec<String> = std::iter::once("f32".to_string())
        .chain(
            std::iter::once(&bbh)
                .chain(&partition)
                .chain(&dim_map)
                .map(ToString::to_string),
        )
        .collect();
    let grid = dst.grid();
    println!("in shape = {:?}", src.shape);
    println!("in tile = {:?}", partition);
    println!("out shape = {:?}", [b * h, d, m]);
    println!("out tile = {:?}", partition_shape_rank3);
    println!("grid = {:?}", grid);
    println!("generics: {:?}", generics);
    // SAFETY: NOTE(review) — the kernel's `unsafe` contract is not spelled out in this
    // file. Presumably `generics` must describe the tile shape `dst` was partitioned
    // with; both are derived from `partition`/`dim_map` above, and the asserts above
    // check the kernel's indexing assumptions.
    let (src, dst) = unsafe { tensor_permute_sync(src, dst) }
        .generics(generics)
        .sync_on(&stream)?;
    let out_host: candle_core::Tensor = dst.unpartition().copy_to_host().sync_on(&stream)?;
    // Reference: the same permutation done by candle, reshaped to the kernel's
    // [b * h, d, m] output layout.
    let answer_host = src
        .copy_to_host()
        .sync_on(&stream)?
        .permute((dim_map[0], dim_map[1], dim_map[2], dim_map[3]))
        .unwrap()
        .reshape((b * h, d, m))
        .unwrap();
    for i in 0..(b * h) {
        // `expect` does not format its message; use `panic!` so `{i}` and the
        // underlying error are actually reported.
        let answer_mat = answer_host
            .get_on_dim(0, i)
            .unwrap_or_else(|e| panic!("Failed to get {i} on dim 0: {e}"));
        let out_mat = out_host
            .get_on_dim(0, i)
            .unwrap_or_else(|e| panic!("Failed to get {i} on dim 0: {e}"));
        // Element-wise |answer - out|, flattened for the tolerance check.
        let abs_diff = (&answer_mat - &out_mat)
            .unwrap()
            .abs()
            .unwrap()
            .reshape((m * d,))
            .unwrap();
        let check = abs_diff
            .to_vec1::<f32>()
            .unwrap()
            .iter()
            .all(|&x| x <= 1e-4);
        if !check {
            println!("Output:");
            pretty_print_matrix::<f32>(&out_mat);
            println!("Answer:");
            pretty_print_matrix::<f32>(&answer_mat);
            assert!(check, "output check failed.");
        }
    }
    Ok(())
}