-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathbatch_matmul.rs
More file actions
82 lines (71 loc) · 2.72 KB
/
batch_matmul.rs
File metadata and controls
82 lines (71 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::{ones, zeros};
use cutile::error::Error;
use cutile::tensor::{IntoPartition, Partition, Tensor, ToHostVec};
use cutile::tile_kernel::TileKernel;
use std::sync::Arc;
#[cutile::module]
mod my_module {
use cutile::core::*;
/// Tiled batched matrix multiply: for each batch entry, C = A @ B, computed
/// one (BM x BN) output tile per tile block.
///
/// Generic parameters (bound at launch time via `.generics(...)`):
/// - `E`:  element type shared by all three tensors.
/// - `BM`, `BN`: output tile extents along M and N.
/// - `BK`: reduction-tile extent along K.
/// - `K`:  full inner (reduction) dimension of A (last axis) and B (middle axis).
///
/// NOTE(review): the reduction loop runs `K / BK` iterations (integer
/// division), so any remainder when K is not a multiple of BK is silently
/// skipped — callers must keep K divisible by BK.
#[cutile::entry()]
fn batch_matmul<E: ElementType, const BM: i32, const BN: i32, const BK: i32, const K: i32>(
a: &Tensor<E, { [-1, -1, K] }>,
b: &Tensor<E, { [-1, K, -1] }>,
c: &mut Tensor<E, { [1, BM, BN] }>,
) {
// Each tile block is identified by a 3-D id: which batch entry it serves
// and which (M, N) output tile it produces.
let pid: (i32, i32, i32) = get_tile_block_id(); // (batch_idx, m_idx, n_idx)
let batch_idx = pid.0;
let m_idx = pid.1;
let n_idx = pid.2;
// View A as (1, BM, BK) blocks and B as (1, BK, BN) blocks so individual
// tiles can be loaded by block index below.
let a_part: Partition<E, { [1, BM, BK] }> = a.partition(const_shape![1, BM, BK]);
let b_part: Partition<E, { [1, BK, BN] }> = b.partition(const_shape![1, BK, BN]);
// Zero-valued (BM, BN) accumulator in the element type; 0i32 is converted
// to E rather than written as a literal so the kernel stays generic in E.
let acc_val: E = convert_scalar(0i32);
let mut acc: Tile<E, { [BM, BN] }> = broadcast_scalar(acc_val, const_shape![BM, BN]);
// March along the reduction dimension one BK-wide tile at a time
// (K / BK iterations; see the divisibility note above).
for k_idx in 0i32..(K / BK) {
// Loaded blocks carry a leading unit batch axis; drop it to get 2-D tiles.
let a_tile: Tile<E, { [BM, BK] }> = a_part
.load([batch_idx, m_idx, k_idx])
.reshape(const_shape![BM, BK]);
let b_tile: Tile<E, { [BK, BN] }> = b_part
.load([batch_idx, k_idx, n_idx])
.reshape(const_shape![BK, BN]);
// Presumably fused matrix-multiply-accumulate: acc = a_tile @ b_tile + acc
// — confirm against cutile's `mma` contract.
acc = mma(a_tile, b_tile, acc);
}
// Write the finished tile into this block's (1, BM, BN) output view.
// NOTE(review): no explicit index is passed, so the destination block is
// presumably selected implicitly from the tile block id — verify against
// cutile's store semantics.
c.store(acc.reshape(const_shape![1, BM, BN]));
}
}
use my_module::batch_matmul_sync;
/// Runs a batched matmul of two all-ones tensors on the GPU and spot-checks
/// that every inspected output element equals `k` — the dot-product length —
/// within a small floating-point tolerance.
fn main() -> Result<(), Error> {
    // Grab the first CUDA device and a stream to serialize the work on.
    let ctx = CudaContext::new(0)?;
    let stream = ctx.new_stream()?;

    // Problem size: `batch` independent (m x k) @ (k x n) products.
    let batch: usize = 4;
    let m: usize = 128;
    let n: usize = 256;
    let k: usize = 64;
    // Tile extents handed to the kernel as const generics; k must stay a
    // multiple of bk so the kernel's reduction loop covers all of K.
    let bm: i32 = 64;
    let bn: i32 = 64;
    let bk: i32 = 32;

    // All-ones inputs make the expected result a constant: each C element
    // is a dot product of k ones, i.e. exactly k.
    let a: Arc<Tensor<f32>> = ones([batch, m, k]).sync_on(&stream)?.into();
    let b: Arc<Tensor<f32>> = ones([batch, k, n]).sync_on(&stream)?.into();
    // Output is pre-zeroed and split into (1, bm, bn) blocks, one per tile block.
    let c: Partition<Tensor<f32>> = zeros([batch, m, n])
        .sync_on(&stream)?
        .partition([1, bm, bn]);

    // Concrete values for the kernel's generics, in declaration order:
    // element type E, then BM, BN, BK, K.
    let mut generics = vec!["f32".to_string()];
    generics.extend([bm, bn, bk].iter().map(|v| v.to_string()));
    generics.push(k.to_string());

    // Launch and wait for completion on the stream.
    let (_a, _b, c) = batch_matmul_sync(a, b, c)
        .generics(generics)
        .sync_on(&stream)?;

    // Copy the result back and verify the first few elements on the host.
    let c_host: Vec<f32> = c.unpartition().to_host_vec().sync_on(&stream)?;
    let expected = k as f32;
    for (idx, value) in c_host.iter().take(10).enumerate() {
        println!("c_host[{idx}] = {value}, expected = {expected}");
        let diff = (value - expected).abs();
        assert!(diff <= 1e-3);
    }
    Ok(())
}