Commit 2225374

Adds mx.fast.layer_norm (#870)

1 parent 105d236
File tree

11 files changed, +600 -8 lines
mlx/backend/metal/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ set(
   "quantized"
   "random"
   "rms_norm"
+  "layer_norm"
   "rope"
   "scan"
   "scaled_dot_product_attention"
mlx/backend/metal/kernels/layer_norm.metal (file name inferred from the "layer_norm" CMake entry above)

Lines changed: 251 additions & 0 deletions (new file; all lines added)

// Copyright © 2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;
  float thread_x[N_READS];

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumx += thread_x[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumx += thread_x[i];
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        sumx2 += xi * xi;
        sumx += xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          sumx2 += xi * xi;
          sumx += xi;
        }
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = (x[r + i] - mean) * normalizer;
        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[r + i] - mean) * normalizer;
          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
        }
      }
    }
  }
}

// clang-format off
#define instantiate_layer_norm_single_row(name, itype)         \
  template [[host_name("layer_norm" #name)]] [[kernel]] void   \
  layer_norm_single_row<itype>(                                \
      const device itype* x,                                   \
      const device itype* w,                                   \
      const device itype* b,                                   \
      device itype* out,                                       \
      constant float& eps,                                     \
      constant uint& axis_size,                                \
      constant uint& w_stride,                                 \
      constant uint& b_stride,                                 \
      uint gid [[thread_position_in_grid]],                    \
      uint lid [[thread_position_in_threadgroup]],             \
      uint simd_lane_id [[thread_index_in_simdgroup]],         \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm_looped(name, itype)                   \
  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void  \
  layer_norm_looped<itype>(                                          \
      const device itype* x,                                         \
      const device itype* w,                                         \
      const device itype* b,                                         \
      device itype* out,                                             \
      constant float& eps,                                           \
      constant uint& axis_size,                                      \
      constant uint& w_stride,                                       \
      constant uint& b_stride,                                       \
      uint gid [[thread_position_in_grid]],                          \
      uint lid [[thread_position_in_threadgroup]],                   \
      uint lsize [[threads_per_threadgroup]],                        \
      uint simd_lane_id [[thread_index_in_simdgroup]],               \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm(name, itype)        \
  instantiate_layer_norm_single_row(name, itype)   \
  instantiate_layer_norm_looped(name, itype)

instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on
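
Note on the two kernels above: each makes a single pass over a row, with every thread accumulating a partial sum and sum of squares, reduced first within each simdgroup (simd_sum) and then across the threadgroup through the local_sumx / local_sumx2 scratch arrays. Written out, the statistics they compute (a restatement of the code, not text from the commit):

    \mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
    \sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \mu^2, \qquad
    y_i = w_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + b_i

with n = axis_size. local_normalizer caches 1/sqrt(sigma^2 + eps), computed once per row with metal::precise::rsqrt and then read by every thread in the group.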
mlx/backend/metal/normalization.cpp (file name inferred from the rename in the first CMake diff)

Lines changed: 87 additions & 0 deletions

@@ -95,4 +95,91 @@ void RMSNorm::eval_gpu(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }
 
+void LayerNorm::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto& out = outputs[0];
+
+  // Make sure that the last dimension is contiguous
+  std::vector<array> copies;
+  auto check_input = [&copies, &s](const array& x) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
+      return x;
+    } else {
+      array x_copy(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      copies.push_back(x_copy);
+      return x_copy;
+    }
+  };
+  const array& x = check_input(inputs[0]);
+  const array& w = inputs[1];
+  const array& b = inputs[2];
+
+  if (x.is_donatable()) {
+    out.move_shared_buffer(x);
+  } else {
+    out.set_data(
+        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        x.data_size(),
+        x.strides(),
+        x.flags());
+  }
+
+  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  int n_rows = x.data_size() / axis_size;
+
+  const int simd_size = 32;
+  const int n_reads = RMS_N_READS;
+  const int looped_limit = RMS_LOOPED_LIMIT;
+  std::string op_name = "layer_norm";
+  if (axis_size > looped_limit) {
+    op_name += "_looped";
+  }
+  op_name += type_to_name(out);
+  auto compute_encoder = d.get_command_encoder(s.index);
+  {
+    auto kernel = d.get_kernel(op_name);
+
+    MTL::Size grid_dims, group_dims;
+    if (axis_size <= looped_limit) {
+      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
+      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
+      size_t threadgroup_size = simd_size * simds_needed;
+      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    } else {
+      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    }
+
+    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
+    uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
+    compute_encoder->setComputePipelineState(kernel);
+    set_array_buffer(
+        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
+    set_array_buffer(compute_encoder, w, 1);
+    set_array_buffer(compute_encoder, b, 2);
+    set_array_buffer(compute_encoder, out, 3);
+    compute_encoder->setBytes(&eps_, sizeof(float), 4);
+    compute_encoder->setBytes(&axis_size, sizeof(int), 5);
+    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
+    compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+}
+
 } // namespace mlx::core::fast
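
LayerNorm::eval_gpu dispatches one threadgroup per row (the single-row kernel up to RMS_LOOPED_LIMIT elements, the looped kernel beyond that), so a quick cross-check against the same computation built from ordinary array ops is a natural smoke test. A hedged sketch; layer_norm_ref, the shapes, and the tolerance are mine, and the looped-path threshold is a constant defined elsewhere in the tree, not in this diff:

    import mlx.core as mx

    def layer_norm_ref(x, w, b, eps):
        # Same statistics the kernels accumulate: mean and mean of squares.
        mean = mx.mean(x, axis=-1, keepdims=True)
        var = mx.mean(mx.square(x), axis=-1, keepdims=True) - mx.square(mean)
        return (x - mean) * mx.rsqrt(var + eps) * w + b

    x = mx.random.normal((8, 8192))   # wide rows, likely the looped variant
    w, b = mx.ones((8192,)), mx.zeros((8192,))

    out = mx.fast.layer_norm(x, w, b, eps=1e-5)
    print(mx.allclose(out, layer_norm_ref(x, w, b, 1e-5), atol=1e-4))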

mlx/backend/no_metal/primitives.cpp

Lines changed: 1 addition & 0 deletions

@@ -102,6 +102,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)
 
 namespace fast {
+NO_GPU_MULTI(LayerNorm)
 NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
