-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathgemm.rs
More file actions
71 lines (66 loc) · 2.15 KB
/
gemm.rs
File metadata and controls
71 lines (66 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use cuda_async::device_operation::*;
use cuda_core::CudaContext;
use cutile;
use cutile::api;
use cutile::candle_core::WithDType;
use cutile::error::Error;
use cutile::tensor::*;
use cutile::tile_kernel::*;
use my_module::gemm_sync;
use std::fmt::Debug;
// Device-side code for the GEMM example. Everything in this module is handed to the
// `cutile::module` macro; the host code in this file calls it through
// `my_module::gemm_sync` (presumably generated by the macro from `gemm` -- TODO confirm).
#[cutile::module]
mod my_module {
use cutile::core::*;
// Tiled matrix-multiply kernel: accumulates products of `x` and `y` tiles into `z`.
//
// Generic parameters:
//   E          - element type of all three tensors.
//   BM, BN, BK - tile extents: `z` is BM x BN, `x` tiles are BM x BK, `y` tiles are BK x BN.
//   K          - shared inner dimension of `x` (columns) and `y` (rows).
//
// Arguments:
//   z - output tensor of shape [BM, BN]; it is read first and then overwritten.
//   x - left operand of shape [-1, K] (the -1 presumably marks a dynamic extent -- TODO confirm).
//   y - right operand of shape [K, -1].
//
// NOTE(review): `print_ir = true` looks like a debugging switch that dumps the generated
// IR -- confirm against the cutile docs before relying on it.
#[cutile::entry(print_ir = true)]
fn gemm<E: ElementType, const BM: i32, const BN: i32, const BK: i32, const K: i32>(
z: &mut Tensor<E, { [BM, BN] }>,
x: &Tensor<E, { [-1, K] }>,
y: &Tensor<E, { [K, -1] }>,
) {
// View the operands as grids of BM x BK and BK x BN tiles respectively.
let part_x = x.partition(const_shape![BM, BK]);
let part_y = y.partition(const_shape![BK, BN]);
// Only the first two components of the block id are used below:
// pid.0 selects the tile row of `x`, pid.1 selects the tile column of `y`.
let pid: (i32, i32, i32) = get_tile_block_id();
// The accumulator starts from the current contents of `z`, not from zero.
let mut tile_z = load_tile_mut(z);
// Walk the inner dimension in K / BK steps.
// NOTE(review): assumes K is a multiple of BK; any remainder would be dropped by the
// integer division -- confirm callers guarantee this.
for i in 0i32..(K / BK) {
let tile_x = part_x.load([pid.0, i]);
let tile_y = part_y.load([i, pid.1]);
// Presumably a matrix multiply-accumulate: tile_x * tile_y + tile_z -- TODO confirm.
tile_z = mma(tile_x, tile_y, tile_z);
}
// Write the accumulated tile back to the output.
z.store(tile_z);
}
}
/// Runs the tiled `gemm` kernel on CUDA device 0 and prints the first ten outputs.
///
/// `x` (m x k) and `y` (k x n) are filled with ones and `z` (m x n) with zeros, so
/// every element of the result is expected to equal `k`; that value is printed next
/// to each output as `answer`.
///
/// `T` is the element type used for all three tensors; its dtype name is passed to
/// the kernel as the first generic argument.
///
/// # Errors
///
/// Returns the first error produced by any of the `?`-propagated steps: creating the
/// CUDA context or stream, materialising the tensors, launching the kernel, or
/// copying the result back to the host.
fn gemm<T: WithDType + Debug>() -> Result<(), Error> {
    let ctx = CudaContext::new(0)?;
    let stream = ctx.new_stream()?;

    // Tile extents, forwarded to the kernel as its BM, BN and BK const generics.
    let (bm, bn, bk) = (16, 16, 8);
    // Every problem dimension spans 2^10 tiles: m = n = 16_384 and k = 8_192.
    // With 4-byte elements such as f32 that is 1 GiB for z and 512 MiB each for x and y.
    let scale = 2usize.pow(10);
    let (m, n, k) = (
        scale * bm as usize,
        scale * bn as usize,
        scale * bk as usize,
    );

    // Order must match the kernel's generic parameter list: E, BM, BN, BK, K.
    let generics = vec![
        T::DTYPE.as_str().to_string(),
        bm.to_string(),
        bn.to_string(),
        bk.to_string(),
        k.to_string(),
    ];

    // z is partitioned into bm x bn tiles, matching the kernel's `z` tile shape.
    let z = api::zeros([m, n]).partition([bm, bn]).sync_on(&stream)?;
    let x = api::ones([m, k]).arc().sync_on(&stream)?;
    let y = api::ones([k, n]).arc().sync_on(&stream)?;

    // x, y and generics are not needed afterwards, so they are moved rather than cloned.
    let launcher = gemm_sync(z, x, y);
    let (z, _x, _y) = launcher.generics(generics).sync_on(&stream)?;
    let z_host: Vec<T> = z.unpartition().to_host_vec().sync_on(&stream)?;

    // `take(10)` keeps the spot check panic-free even if the result were shorter.
    for (i, value) in z_host.iter().take(10).enumerate() {
        println!("z_host[{i}] = {} answer = {}", value, k);
    }
    Ok(())
}
/// Entry point: runs the GEMM example with `f32` elements and forwards any error.
fn main() -> Result<(), Error> {
    gemm::<f32>()?;
    Ok(())
}