|
| 1 | +// Copyright © 2024 Apple Inc. |
| 2 | + |
| 3 | +#include <metal_common> |
| 4 | +#include <metal_simdgroup> |
| 5 | + |
| 6 | +#include "mlx/backend/metal/kernels/bf16.h" |
| 7 | +#include "mlx/backend/metal/kernels/defines.h" |
| 8 | +#include "mlx/backend/metal/kernels/utils.h" |
| 9 | + |
| 10 | +using namespace metal; |
| 11 | + |
| 12 | +template <typename T, int N_READS = RMS_N_READS> |
| 13 | +[[kernel]] void layer_norm_single_row( |
| 14 | + const device T* x, |
| 15 | + const device T* w, |
| 16 | + const device T* b, |
| 17 | + device T* out, |
| 18 | + constant float& eps, |
| 19 | + constant uint& axis_size, |
| 20 | + constant uint& w_stride, |
| 21 | + constant uint& b_stride, |
| 22 | + uint gid [[threadgroup_position_in_grid]], |
| 23 | + uint lid [[thread_position_in_threadgroup]], |
| 24 | + uint simd_lane_id [[thread_index_in_simdgroup]], |
| 25 | + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { |
| 26 | + float sumx = 0; |
| 27 | + float sumx2 = 0; |
| 28 | + float thread_x[N_READS]; |
| 29 | + |
| 30 | + constexpr int SIMD_SIZE = 32; |
| 31 | + |
| 32 | + threadgroup float local_sumx[SIMD_SIZE]; |
| 33 | + threadgroup float local_sumx2[SIMD_SIZE]; |
| 34 | + threadgroup float local_mean[1]; |
| 35 | + threadgroup float local_normalizer[1]; |
| 36 | + |
| 37 | + x += gid * axis_size + lid * N_READS; |
| 38 | + w += w_stride * lid * N_READS; |
| 39 | + b += b_stride * lid * N_READS; |
| 40 | + |
| 41 | + if (lid * N_READS + N_READS <= axis_size) { |
| 42 | + for (int i = 0; i < N_READS; i++) { |
| 43 | + thread_x[i] = x[i]; |
| 44 | + sumx2 += thread_x[i] * thread_x[i]; |
| 45 | + sumx += thread_x[i]; |
| 46 | + } |
| 47 | + } else { |
| 48 | + for (int i = 0; i < N_READS; i++) { |
| 49 | + if ((lid * N_READS + i) < axis_size) { |
| 50 | + thread_x[i] = x[i]; |
| 51 | + sumx2 += thread_x[i] * thread_x[i]; |
| 52 | + sumx += thread_x[i]; |
| 53 | + } |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + sumx = simd_sum(sumx); |
| 58 | + sumx2 = simd_sum(sumx2); |
| 59 | + |
| 60 | + // Initialize shared memory |
| 61 | + if (simd_group_id == 0) { |
| 62 | + local_sumx[simd_lane_id] = 0; |
| 63 | + local_sumx2[simd_lane_id] = 0; |
| 64 | + } |
| 65 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 66 | + |
| 67 | + // Write simd accumulations into shared memory |
| 68 | + if (simd_lane_id == 0) { |
| 69 | + local_sumx[simd_group_id] = sumx; |
| 70 | + local_sumx2[simd_group_id] = sumx2; |
| 71 | + } |
| 72 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 73 | + |
| 74 | + // Accumulate over simd groups |
| 75 | + if (simd_group_id == 0) { |
| 76 | + sumx = simd_sum(local_sumx[simd_lane_id]); |
| 77 | + sumx2 = simd_sum(local_sumx2[simd_lane_id]); |
| 78 | + if (simd_lane_id == 0) { |
| 79 | + float mean = sumx / axis_size; |
| 80 | + float variance = sumx2 / axis_size - mean * mean; |
| 81 | + |
| 82 | + local_mean[0] = mean; |
| 83 | + local_normalizer[0] = metal::precise::rsqrt(variance + eps); |
| 84 | + } |
| 85 | + } |
| 86 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 87 | + |
| 88 | + float mean = local_mean[0]; |
| 89 | + float normalizer = local_normalizer[0]; |
| 90 | + |
| 91 | + // Write the outputs |
| 92 | + out += gid * axis_size + lid * N_READS; |
| 93 | + if (lid * N_READS + N_READS <= axis_size) { |
| 94 | + for (int i = 0; i < N_READS; i++) { |
| 95 | + thread_x[i] = (thread_x[i] - mean) * normalizer; |
| 96 | + out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i]; |
| 97 | + } |
| 98 | + } else { |
| 99 | + for (int i = 0; i < N_READS; i++) { |
| 100 | + if ((lid * N_READS + i) < axis_size) { |
| 101 | + thread_x[i] = (thread_x[i] - mean) * normalizer; |
| 102 | + out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i]; |
| 103 | + } |
| 104 | + } |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +template <typename T, int N_READS = RMS_N_READS> |
| 109 | +[[kernel]] void layer_norm_looped( |
| 110 | + const device T* x, |
| 111 | + const device T* w, |
| 112 | + const device T* b, |
| 113 | + device T* out, |
| 114 | + constant float& eps, |
| 115 | + constant uint& axis_size, |
| 116 | + constant uint& w_stride, |
| 117 | + constant uint& b_stride, |
| 118 | + uint gid [[threadgroup_position_in_grid]], |
| 119 | + uint lid [[thread_position_in_threadgroup]], |
| 120 | + uint lsize [[threads_per_threadgroup]], |
| 121 | + uint simd_lane_id [[thread_index_in_simdgroup]], |
| 122 | + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { |
| 123 | + float sumx = 0; |
| 124 | + float sumx2 = 0; |
| 125 | + |
| 126 | + constexpr int SIMD_SIZE = 32; |
| 127 | + |
| 128 | + threadgroup float local_sumx[SIMD_SIZE]; |
| 129 | + threadgroup float local_sumx2[SIMD_SIZE]; |
| 130 | + threadgroup float local_mean[1]; |
| 131 | + threadgroup float local_normalizer[1]; |
| 132 | + |
| 133 | + x += gid * axis_size + lid * N_READS; |
| 134 | + w += w_stride * lid * N_READS; |
| 135 | + b += b_stride * lid * N_READS; |
| 136 | + |
| 137 | + for (uint r = 0; r < axis_size; r += lsize * N_READS) { |
| 138 | + if (r + lid * N_READS + N_READS <= axis_size) { |
| 139 | + for (int i = 0; i < N_READS; i++) { |
| 140 | + float xi = x[i + r]; |
| 141 | + sumx2 += xi * xi; |
| 142 | + sumx += xi; |
| 143 | + } |
| 144 | + } else { |
| 145 | + for (int i = 0; i < N_READS; i++) { |
| 146 | + if ((r + lid * N_READS + i) < axis_size) { |
| 147 | + float xi = x[i + r]; |
| 148 | + sumx2 += xi * xi; |
| 149 | + sumx += xi; |
| 150 | + } |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + sumx = simd_sum(sumx); |
| 156 | + sumx2 = simd_sum(sumx2); |
| 157 | + |
| 158 | + // Initialize shared memory |
| 159 | + if (simd_group_id == 0) { |
| 160 | + local_sumx[simd_lane_id] = 0; |
| 161 | + local_sumx2[simd_lane_id] = 0; |
| 162 | + } |
| 163 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 164 | + |
| 165 | + // Write simd accumulations into shared memory |
| 166 | + if (simd_lane_id == 0) { |
| 167 | + local_sumx[simd_group_id] = sumx; |
| 168 | + local_sumx2[simd_group_id] = sumx2; |
| 169 | + } |
| 170 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 171 | + |
| 172 | + // Accumulate over simd groups |
| 173 | + if (simd_group_id == 0) { |
| 174 | + sumx = simd_sum(local_sumx[simd_lane_id]); |
| 175 | + sumx2 = simd_sum(local_sumx2[simd_lane_id]); |
| 176 | + if (simd_lane_id == 0) { |
| 177 | + float mean = sumx / axis_size; |
| 178 | + float variance = sumx2 / axis_size - mean * mean; |
| 179 | + |
| 180 | + local_mean[0] = mean; |
| 181 | + local_normalizer[0] = metal::precise::rsqrt(variance + eps); |
| 182 | + } |
| 183 | + } |
| 184 | + threadgroup_barrier(mem_flags::mem_threadgroup); |
| 185 | + |
| 186 | + float mean = local_mean[0]; |
| 187 | + float normalizer = local_normalizer[0]; |
| 188 | + |
| 189 | + // Write the outputs |
| 190 | + out += gid * axis_size + lid * N_READS; |
| 191 | + for (uint r = 0; r < axis_size; r += lsize * N_READS) { |
| 192 | + if (r + lid * N_READS + N_READS <= axis_size) { |
| 193 | + for (int i = 0; i < N_READS; i++) { |
| 194 | + float xi = (x[r + i] - mean) * normalizer; |
| 195 | + out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)]; |
| 196 | + } |
| 197 | + } else { |
| 198 | + for (int i = 0; i < N_READS; i++) { |
| 199 | + if ((r + lid * N_READS + i) < axis_size) { |
| 200 | + float xi = (x[r + i] - mean) * normalizer; |
| 201 | + out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)]; |
| 202 | + } |
| 203 | + } |
| 204 | + } |
| 205 | + } |
| 206 | +} |
| 207 | + |
| 208 | + |
| 209 | +// clang-format off |
| 210 | +#define instantiate_layer_norm_single_row(name, itype) \ |
| 211 | + template [[host_name("layer_norm" #name)]] [[kernel]] void \ |
| 212 | + layer_norm_single_row<itype>( \ |
| 213 | + const device itype* x, \ |
| 214 | + const device itype* w, \ |
| 215 | + const device itype* b, \ |
| 216 | + device itype* out, \ |
| 217 | + constant float& eps, \ |
| 218 | + constant uint& axis_size, \ |
| 219 | + constant uint& w_stride, \ |
| 220 | + constant uint& b_stride, \ |
| 221 | + uint gid [[thread_position_in_grid]], \ |
| 222 | + uint lid [[thread_position_in_threadgroup]], \ |
| 223 | + uint simd_lane_id [[thread_index_in_simdgroup]], \ |
| 224 | + uint simd_group_id [[simdgroup_index_in_threadgroup]]); |
| 225 | + |
| 226 | +#define instantiate_layer_norm_looped(name, itype) \ |
| 227 | + template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \ |
| 228 | + layer_norm_looped<itype>( \ |
| 229 | + const device itype* x, \ |
| 230 | + const device itype* w, \ |
| 231 | + const device itype* b, \ |
| 232 | + device itype* out, \ |
| 233 | + constant float& eps, \ |
| 234 | + constant uint& axis_size, \ |
| 235 | + constant uint& w_stride, \ |
| 236 | + constant uint& b_stride, \ |
| 237 | + uint gid [[thread_position_in_grid]], \ |
| 238 | + uint lid [[thread_position_in_threadgroup]], \ |
| 239 | + uint lsize [[threads_per_threadgroup]], \ |
| 240 | + uint simd_lane_id [[thread_index_in_simdgroup]], \ |
| 241 | + uint simd_group_id [[simdgroup_index_in_threadgroup]]); |
| 242 | + |
| 243 | +#define instantiate_layer_norm(name, itype) \ |
| 244 | + instantiate_layer_norm_single_row(name, itype) \ |
| 245 | + instantiate_layer_norm_looped(name, itype) |
| 246 | + |
| 247 | +instantiate_layer_norm(float32, float) |
| 248 | +instantiate_layer_norm(float16, half) |
| 249 | +instantiate_layer_norm(bfloat16, bfloat16_t) |
| 250 | + // clang-format on |
| 251 | + |
0 commit comments