-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathrms_norm.rs
More file actions
90 lines (84 loc) · 3.85 KB
/
rms_norm.rs
File metadata and controls
90 lines (84 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::{randn, zeros};
use cutile::error::Error;
use cutile::tensor::{IntoPartition, Partition, Tensor, ToHostVec};
use cutile::tile_kernel::TileKernel;
use std::sync::Arc;
#[cutile::module]
mod my_module {
use cutile::core::*;
#[cutile::entry()]
fn rms_norm<const N: i32, const BLOCK_SIZE: i32>(
x: &Tensor<f32, { [-1, N] }>,
w: &Tensor<f32, { [N] }>,
out: &mut Tensor<f32, { [1, N] }>,
eps: f32,
) {
let tile_shape: Shape<{ [1, BLOCK_SIZE] }> = const_shape![1, BLOCK_SIZE];
let num_tiles: i32 = N / BLOCK_SIZE;
// The launch grid is (M, 1, 1).
// row is a pid in [0, M).
let pid: (i32, i32, i32) = get_tile_block_id();
let row = pid.0;
let x_part: Partition<f32, { [1, BLOCK_SIZE] }> = x.partition(tile_shape);
// TODO (hme): Parse 0.0f32 syntax properly.
let mut rms: Tile<f32, { [1, BLOCK_SIZE] }> = constant(0.0, tile_shape);
for j in 0i32..num_tiles {
let tx: Tile<f32, { [1, BLOCK_SIZE] }> = x_part.load([row, j]);
rms = rms + tx * tx;
}
// TODO (hme): Try to make this something like:
// let rms = (1.0 / (rms.sum(/*axis=*/1, /*keepdims=*/true) / N + eps).sqrt()).broadcast(tile_shape);
let rms: Tile<f32, { [1] }> = reduce_sum(rms, 1i32);
let rms: Tile<f32, { [] }> = rms.reshape(const_shape![]);
let rms: f32 = tile_to_scalar(rms);
let n: f32 = convert_scalar(N);
let rms: f32 = 1.0f32 / (rms / n + eps);
let rms: Tile<f32, { [] }> = sqrt(scalar_to_tile(rms), "negative_inf");
let rms: f32 = tile_to_scalar(rms);
let rms: Tile<f32, { [1, BLOCK_SIZE] }> = rms.broadcast(tile_shape);
let w_part: Partition<f32, { [BLOCK_SIZE] }> = w.partition(const_shape![BLOCK_SIZE]);
// TODO (hme): This is a safety leak. If this partition goes out of scope, we can partition out again,
// and any memory ops will not succeed tokens corresponding to write operations (since those will also be dropped).
let mut out_part: PartitionMut<f32, { [1, BLOCK_SIZE] }> =
unsafe { out.partition_mut(tile_shape) };
for j in 0i32..num_tiles {
let tx: Tile<f32, { [1, BLOCK_SIZE] }> = x_part.load([row, j]);
let tw: Tile<f32, { [1, BLOCK_SIZE] }> = w_part.load([j]).reshape(tile_shape);
let tout: Tile<f32, { [1, BLOCK_SIZE] }> = tx * rms * tw;
unsafe { out_part.store(tout, [0i32, j]) };
}
}
}
use my_module::rms_norm_sync;
fn main() -> Result<(), Error> {
// Create a context. Device 0 is associated with the context.
let ctx = CudaContext::new(0)?;
// Create a new stream on which we run CUDA operations.
let stream = ctx.new_stream()?;
let (m, n) = (4, 8);
let block_size = 2;
let generics = vec![n.to_string(), block_size.to_string()];
let eps: f32 = 1e-8; // A sufficiently small number.
let x: Arc<Tensor<f32>> = randn(0.0, 1.0, [m, n]).sync_on(&stream)?.into();
let w: Arc<Tensor<f32>> = randn(0.0, 1.0, [n]).sync_on(&stream)?.into();
let out: Partition<Tensor<f32>> = zeros([m, n]).sync_on(&stream)?.partition([1, n as i32]);
let (_x, _w, out, _eps) = rms_norm_sync(x, w, out, eps)
.generics(generics)
.sync_on(&stream)?;
let out_host: Vec<f32> = out.unpartition().to_host_vec().sync_on(&stream)?;
for i in (0..out_host.len()).step_by(8) {
let x = out_host[i..i + 8].to_vec();
println!("{x:?}");
// let sum: f32 = x.iter().sum();
// println!("layer_norm(x).sum(axis=1) = {}", sum);
// assert!(sum.epsilon_close(1.0));
}
Ok(())
}