-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathasync_gemm.rs
More file actions
102 lines (97 loc) · 2.96 KB
/
async_gemm.rs
File metadata and controls
102 lines (97 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use cuda_async::device_operation::*;
use cutile;
use cutile::api::{self, copy_to_host};
use cutile::candle_core::WithDType;
use cutile::half::f16;
use cutile::num_traits::identities::*;
use cutile::tensor::{Tensor, ToHostVec, Unpartition};
use cutile::tile_kernel::{IntoDeviceOperationPartition, TileKernel};
use my_module::gemm_apply;
use std::fmt::Debug;
use std::sync::Arc;
#[cutile::module]
mod my_module {
use cutile::core::*;
#[cutile::entry()]
fn gemm<
E1: ElementType,
E2: ElementType,
const BM: i32,
const BN: i32,
const BK: i32,
const K: i32,
>(
z: &mut Tensor<E1, { [BM, BN] }>,
x: &Tensor<E2, { [-1, K] }>,
y: &Tensor<E2, { [K, -1] }>,
) {
let part_x = x.partition(const_shape![BM, BK]);
let part_y = y.partition(const_shape![BK, BN]);
let pid = get_tile_block_id();
let mut tile_z = z.load();
for i in 0i32..(K / BK) {
let tile_x = part_x.load([pid.0, i]);
let tile_y = part_y.load([i, pid.1]);
tile_z = mma(tile_x, tile_y, tile_z);
continue;
}
z.store(tile_z);
}
}
fn gemm<T1: WithDType + Debug, T2: WithDType + Debug>(
x: Arc<Tensor<T2>>,
y: Arc<Tensor<T2>>,
) -> impl DeviceOperation<Output = Tensor<T1>> {
let (m, n, k) = (
x.shape[0] as usize,
y.shape[1] as usize,
x.shape[1] as usize,
);
let (bm, bn, bk) = (16, 16, 8);
let generics = [
T1::DTYPE.as_str().to_string(),
T2::DTYPE.as_str().to_string(),
bm.to_string(),
bn.to_string(),
bk.to_string(),
k.to_string(),
];
let z = api::zeros([m, n]); // impl DeviceOperation
let args = zip!(
z.partition([bm, bn]),
x.device_operation(),
y.device_operation()
);
let (z, _x, _y) = args.apply(gemm_apply).generics(generics.to_vec()).unzip();
z.unpartition()
}
#[tokio::main(flavor = "multi_thread", worker_threads = 16)]
async fn main() -> Result<(), cuda_async::error::DeviceError> {
type IN = f16;
type OUT = f32;
let (m, n, k) = (64, 64, 16);
let x = api::randn(IN::zero(), IN::one(), [m, k]).arc().await?; // impl DeviceOperation
let y = api::randn(IN::zero(), IN::one(), [k, n]).arc().await?; // impl DeviceOperation
let z = gemm::<OUT, IN>(x.clone(), y.clone()).await?;
let z_host: Vec<OUT> = z.to_host_vec().await?;
let x_host = copy_to_host(&x).await?;
let y_host = copy_to_host(&y).await?;
let answer_host: Vec<f16> = x_host
.matmul(&y_host)
.unwrap()
.reshape(((),))
.unwrap()
.to_vec1()
.unwrap();
for i in 0..(m * n) as usize {
println!(
"z_host[{i}] == answer_host[{i}]? {} == {}",
z_host[i], answer_host[i]
);
}
Ok(())
}