-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathasync_mlp_fused.rs
More file actions
136 lines (125 loc) · 4.63 KB
/
async_mlp_fused.rs
File metadata and controls
136 lines (125 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#![feature(type_alias_impl_trait)]
use cuda_async::device_context;
use cuda_async::device_context::global_policy;
use cuda_async::device_operation::*;
use cuda_async::launch::AsyncKernelLaunch;
use cuda_async::scheduling_policies::WithDeviceId;
use cuda_core::LaunchConfig;
use cutile::tensor::{Tensor, ToHostVec};
use cutile::tile_kernel::IntoDeviceOperationPartition;
use cutile::{api, error::Error};
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile::ModuleOperation;
use cutile_compiler::cuda_tile_runtime_utils::{compile_module, get_gpu_name};
use std::sync::Arc;
// Device-side code for this example. `#[cutile::module]` processes the whole
// module; the `_module_asts()` function imported further down (presumably
// generated by the macro) is what `main` passes to `CUDATileModules::new`.
#[cutile::module]
pub mod my_kernels {
    use cutile::core::*;

    // Rank-1 ReLU over a tile of length `D`: max(0, x), taken against a tile
    // filled with 0.0 (assumes `max_tile` is elementwise — confirm in cutile docs).
    fn relu<const D: i32>(input: Tile<f32, { [D] }>) -> Tile<f32, { [D] }> {
        let zero_tile: Tile<f32, { [D] }> = constant(0.0f32, const_shape![D]);
        max_tile(zero_tile, input)
    }

    // Fused "data x w0 x w1" kernel; each tile block computes one BM-long slice
    // of `out`. Assuming `mma(a, b, acc)` returns a * b + acc, block `m` computes
    //
    //     out[m] = relu(out[m] + sum_k (sum_n data[m, n] * w0[n, k]) * w1[k])
    //
    // with data[m, n] a BM x BN tile, w0[n, k] a BN x BK tile and w1[k] a
    // BK-long slice reshaped to BK x 1. The incoming contents of `out` are part
    // of the sum because `out` is loaded as the initial accumulator.
    //
    // `N` and `K` are supplied by the host (`main` passes the column counts of
    // `data` and `w0`). The loops run `N / BN` and `K / BK` times using integer
    // division, so any remainder columns are silently skipped.
    //
    // NOTE(review): ReLU is applied only to the final result, not to the hidden
    // activation `data * w0`, so the kernel computes relu(data * w0 * w1). If a
    // conventional two-layer MLP relu(data * w0) * w1 is intended, the activation
    // belongs on `tile_data_x_w0` after the inner `n` loop, where each hidden
    // block is complete. TODO confirm intent.
    #[cutile::entry()]
    fn fused_mlp<const BM: i32, const BN: i32, const BK: i32, const N: i32, const K: i32>(
        out: &mut Tensor<f32, { [BM] }>,
        data: &Tensor<f32, { [-1, -1] }>,
        w0: &Tensor<f32, { [-1, K] }>,
        w1: &Tensor<f32, { [K] }>,
    ) {
        // Tiled views: data in BM x BN tiles, w0 in BN x BK tiles, w1 in BK slices.
        let part_data = data.partition(const_shape![BM, BN]);
        let part_w0 = w0.partition(const_shape![BN, BK]);
        let part_w1 = w1.partition(const_shape![BK]);
        // Only the x component of the tile block id is used; it selects the
        // row block of `data` this block works on.
        let pid: (i32, i32, i32) = get_tile_block_id();
        let m = pid.0;
        // Reshaped to a BM x 1 column so it can serve as the accumulator of
        // the second mma below.
        let mut tile_out = out.load().reshape(const_shape![BM, 1]);
        for k in 0i32..(K / BK) {
            // TODO (hme): Infer type from const.
            // Hidden block k (BM x BK), accumulated over the n tiles.
            let mut tile_data_x_w0: Tile<f32, { [BM, BK] }> = constant(0.0, const_shape![BM, BK]);
            for n in 0i32..(N / BN) {
                let tile_data = part_data.load([m, n]);
                let tile_w0 = part_w0.load([n, k]);
                tile_data_x_w0 = mma(tile_data, tile_w0, tile_data_x_w0);
            }
            // Contract hidden block k with w1[k] (as BK x 1) into the BM x 1 output.
            let tile_w1 = part_w1.load([k]).reshape(const_shape![BK, 1]);
            tile_out = mma(tile_data_x_w0, tile_w1, tile_out);
        }
        // Flatten back to BM, apply ReLU and write this block's slice of `out`.
        out.store(relu(tile_out.reshape(const_shape![BM])));
    }
}
// Stand-in for a real input loader: returns a device operation that fills a
// tensor of shape `batch_size` (the full shape, one extent per dimension)
// with random values from `api::randn`.
fn load_data<const RANK: usize>(
    batch_size: [usize; RANK],
) -> impl DeviceOperation<Output = Tensor<f32>> {
    // Presumably the mean and standard deviation of the sampled values.
    let (mean, std_dev) = (0.0, 1.0);
    api::randn(mean, std_dev, batch_size)
}
use my_kernels::_module_asts;
#[tokio::main(flavor = "multi_thread", worker_threads = 16)]
async fn main() -> Result<(), Error> {
    // Problem size: `data` is m x n, `w0` is n x k, `w1` has length k and
    // `out` has length m.
    let (m, n, k) = (16, 16, 16);
    // Tile sizes, passed to the kernel as its BM / BN / BK const generics.
    let (bm, bn, bk) = (4, 4, 4);
    // The kernel loops `N / BN` and `K / BK` times (integer division) and the
    // launch grid below has `m / bm` blocks, so a size that is not a multiple
    // of its tile size would silently drop the remainder. Fail loudly instead.
    assert_eq!(m % bm, 0, "m ({m}) must be a multiple of bm ({bm})");
    assert_eq!(n % bn, 0, "n ({n}) must be a multiple of bn ({bn})");
    assert_eq!(k % bk, 0, "k ({k}) must be a multiple of bk ({bk})");

    // Data. Inputs are materialized and wrapped with `.arc()` so they can be
    // passed via `push_arg_arc`; `out` is zero-filled and split into `bm`-long
    // pieces, matching the kernel's `out: &mut Tensor<f32, { [BM] }>` parameter.
    let data = load_data([m, n]).arc().await?;
    let w0 = api::randn(0.0f32, 1.0, [n, k]).arc().await?;
    let w1 = api::randn(0.0f32, 1.0, [k]).arc().await?;
    let out = api::zeros::<1, f32>([m]).partition([bm]).await?;

    // Resolve the device first so the compiler targets the same GPU that the
    // module is built for and loaded on.
    let device = global_policy(0)?;
    let device_id = device.get_device_id();

    // Compilation
    let module_name = "my_kernels";
    let function_name = "fused_mlp";
    let function_entry = "fused_mlp_entry";
    let modules = CUDATileModules::new(_module_asts())?;
    // Const generic values in the kernel's declaration order: BM, BN, BK, N, K.
    let generics = [
        bm.to_string(),
        bn.to_string(),
        bk.to_string(),
        n.to_string(),
        k.to_string(),
    ];
    // Per-argument strides (row-major, unit stride in the last dimension).
    let stride_args = vec![
        ("out", vec![1]),
        ("data", vec![n as i32, 1]),
        ("w0", vec![k as i32, 1]),
        ("w1", vec![1]),
    ];
    let compiler = CUDATileFunctionCompiler::new(
        &modules,
        module_name,
        function_name,
        &generics,
        &stride_args
            .iter()
            .map(|x| (x.0, x.1.as_slice()))
            .collect::<Vec<_>>(),
        None,
        get_gpu_name(device_id),
    )?;
    let module_op: ModuleOperation = compiler.compile()?;
    println!("{}", module_op.as_operation().to_string());
    let module_filename = compile_module(&module_op, &get_gpu_name(device_id));
    let module = device_context::load_module_from_file(&module_filename, device_id)?;
    let function = Arc::new(
        module
            .load_function(function_entry)
            .expect("Failed to load function from compiled module."),
    );

    // One tile block per `bm`-long slice of `out`; derived from the sizes so
    // it cannot drift out of sync with them.
    let launch_grid = (
        u32::try_from(m / bm).expect("launch grid size must fit in u32"),
        1,
        1,
    );
    let mut kernel_launch = AsyncKernelLaunch::new(function);
    // Argument order must match the kernel signature: out, data, w0, w1.
    kernel_launch
        .push_arg(&out)
        .push_arg_arc(&data)
        .push_arg_arc(&w0)
        .push_arg_arc(&w1)
        .set_launch_config(LaunchConfig {
            grid_dim: launch_grid,
            block_dim: (1, 1, 1),
            shared_mem_bytes: 0,
        });
    kernel_launch.await?;

    // Reassemble the partitioned output and copy it back to the host.
    let host_vec = out.unpartition().to_host_vec().await?;
    println!("{:?}", host_vec);
    Ok(())
}