
Commit b96e105

Add grid_sample example to metal_kernel docs (#1352)

* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`
* add grid sample to docs
* zero_outputs -> init_value
* add missing header for linux

1 parent 3b4d548 commit b96e105

6 files changed: +337 -15 lines changed

docs/src/dev/custom_metal_kernels.rst

Lines changed: 292 additions & 2 deletions
@@ -43,7 +43,8 @@ The full function signature will be generated using:

 * The keys and shapes/dtypes of ``inputs``
   In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
   so we will add ``const device float16_t* inp`` to the signature.
-  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience.
+  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
+  in ``source``.
 * The keys and values of ``output_shapes`` and ``output_dtypes``
   In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
   so we add ``device float16_t* out``.
@@ -73,7 +74,7 @@ Putting this all together, the generated function signature for ``myexp`` is as

     template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

-You can print the generated code for a ``mx.fast.metal_kernel`` by passing ``verbose=True`` when you call it.
+Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
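For example (illustrative, not part of the diff; it assumes the ``kernel`` and input ``a`` from the ``myexp`` example above), such a call looks like:

.. code-block:: python

    outputs = kernel(
        inputs={"inp": a},
        template={"T": a.dtype},
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes={"out": a.shape},
        output_dtypes={"out": a.dtype},
        verbose=True,  # print the generated Metal source to stdout
    )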

Using Shape/Strides
-------------------
@@ -121,3 +122,292 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely

    a = a[::2]
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))

Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

    def grid_sample_ref(x, grid):
        N, H_in, W_in, _ = x.shape
        ix = ((grid[..., 0] + 1) * W_in - 1) / 2
        iy = ((grid[..., 1] + 1) * H_in - 1) / 2

        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)

        ix_ne = ix_nw + 1
        iy_ne = iy_nw

        ix_sw = ix_nw
        iy_sw = iy_nw + 1

        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        nw = (ix_se - ix) * (iy_se - iy)
        ne = (ix - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix) * (iy - iy_ne)
        se = (ix - ix_nw) * (iy - iy_nw)

        I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
        I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
        I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
        I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]

        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

        return output
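The ``grid`` coordinates are normalized to ``[-1, 1]``: a value of ``-1`` maps to pixel coordinate ``-0.5`` and ``+1`` to ``W_in - 0.5``, which is why the out-of-bounds corners are masked above. As a quick sanity check (the shapes below are illustrative, not from the original docs):

.. code-block:: python

    x = mx.random.normal(shape=(2, 16, 16, 3))                    # NHWC input
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 8, 8, 2))  # (x, y) in [-1, 1]
    out = grid_sample_ref(x, grid)
    assert out.shape == (2, 8, 8, 3)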
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

    @mx.custom_function
    def grid_sample(x, grid):
        assert x.ndim == 4, "`x` must be 4D."
        assert grid.ndim == 4, "`grid` must be 4D."

        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape
        out_shape = (B, gN, gM, C)

        assert D == 2, "Last dim of `grid` must be size 2."

        source = """
            uint elem = thread_position_in_grid.x;
            int H = x_shape[1];
            int W = x_shape[2];
            int C = x_shape[3];
            int gH = grid_shape[1];
            int gW = grid_shape[2];

            int w_stride = C;
            int h_stride = W * w_stride;
            int b_stride = H * h_stride;

            uint grid_idx = elem / C * 2;
            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

            int ix_nw = floor(ix);
            int iy_nw = floor(iy);

            int ix_ne = ix_nw + 1;
            int iy_ne = iy_nw;

            int ix_sw = ix_nw;
            int iy_sw = iy_nw + 1;

            int ix_se = ix_nw + 1;
            int iy_se = iy_nw + 1;

            T nw = (ix_se - ix) * (iy_se - iy);
            T ne = (ix - ix_sw) * (iy_sw - iy);
            T sw = (ix_ne - ix) * (iy - iy_ne);
            T se = (ix - ix_nw) * (iy - iy_nw);

            int batch_idx = elem / C / gH / gW * b_stride;
            int channel_idx = elem % C;
            int base_idx = batch_idx + channel_idx;

            T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
            T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
            T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
            T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

            // Zero the corners that fall outside the image so they
            // contribute nothing to the interpolated value.
            I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
            I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
            I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
            I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

            out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
        """
        kernel = mx.fast.metal_kernel(
            name="grid_sample",
            source=source,
        )
        outputs = kernel(
            inputs={"x": x, "grid": grid},
            template={"T": x.dtype},
            output_shapes={"out": out_shape},
            output_dtypes={"out": x.dtype},
            grid=(np.prod(out_shape), 1, 1),  # `np` is numpy (assumed imported)
            threadgroup=(256, 1, 1),
        )
        return outputs["out"]
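Before timing it, a minimal check (shapes assumed) that the fused kernel agrees with the reference:

.. code-block:: python

    x = mx.random.normal(shape=(2, 16, 16, 3))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 8, 8, 2))
    assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)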
For a reasonably sized input such as:

.. code-block:: python

    x.shape = (8, 1024, 1024, 64)
    grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``
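A sketch of how such a comparison might be timed; the ``timeit`` helper here is an assumption, not part of the original docs:

.. code-block:: python

    import time

    def timeit(fn, *args):
        mx.eval(fn(*args))  # warm up
        tic = time.perf_counter()
        for _ in range(10):
            mx.eval(fn(*args))  # mx.eval forces the lazy computation
        return (time.perf_counter() - tic) / 10 * 1e3  # ms per call

    x = mx.random.normal(shape=(8, 1024, 1024, 64))
    grid = mx.random.uniform(low=-1, high=1, shape=(8, 256, 256, 2))
    print(timeit(grid_sample_ref, x, grid), timeit(grid_sample, x, grid))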
Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backward pass requires atomically updating ``x_grad``/``grid_grad``, so it
needs a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
    Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
    Designate all of the kernel outputs as ``atomic`` in the function signature.
    This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
    See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backward pass as follows:

.. code-block:: python

    @grid_sample.vjp
    def grid_sample_vjp(primals, cotangent, _):
        x, grid = primals
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape

        assert D == 2, "Last dim of `grid` must be size 2."

        source = """
            uint elem = thread_position_in_grid.x;
            int H = x_shape[1];
            int W = x_shape[2];
            int C = x_shape[3];
            // Pad C to the nearest larger simdgroup size multiple
            int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

            int gH = grid_shape[1];
            int gW = grid_shape[2];

            int w_stride = C;
            int h_stride = W * w_stride;
            int b_stride = H * h_stride;

            uint grid_idx = elem / C_padded * 2;
            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

            int ix_nw = floor(ix);
            int iy_nw = floor(iy);

            int ix_ne = ix_nw + 1;
            int iy_ne = iy_nw;

            int ix_sw = ix_nw;
            int iy_sw = iy_nw + 1;

            int ix_se = ix_nw + 1;
            int iy_se = iy_nw + 1;

            T nw = (ix_se - ix) * (iy_se - iy);
            T ne = (ix - ix_sw) * (iy_sw - iy);
            T sw = (ix_ne - ix) * (iy - iy_ne);
            T se = (ix - ix_nw) * (iy - iy_nw);

            int batch_idx = elem / C_padded / gH / gW * b_stride;
            int channel_idx = elem % C_padded;
            int base_idx = batch_idx + channel_idx;

            T gix = T(0);
            T giy = T(0);
            if (channel_idx < C) {
                int cot_index = elem / C_padded * C + channel_idx;
                T cot = cotangent[cot_index];
                if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                    int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                    T I_nw = x[offset];
                    gix -= I_nw * (iy_se - iy) * cot;
                    giy -= I_nw * (ix_se - ix) * cot;
                }
                if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                    int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                    T I_ne = x[offset];
                    gix += I_ne * (iy_sw - iy) * cot;
                    giy -= I_ne * (ix - ix_sw) * cot;
                }
                if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                    int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                    T I_sw = x[offset];
                    gix -= I_sw * (iy - iy_ne) * cot;
                    giy += I_sw * (ix_ne - ix) * cot;
                }
                if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                    int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                    T I_se = x[offset];
                    gix += I_se * (iy - iy_nw) * cot;
                    giy += I_se * (ix - ix_nw) * cot;
                }
            }

            T gix_mult = W / 2;
            T giy_mult = H / 2;

            // Reduce across each simdgroup first.
            // This is much faster than relying purely on atomics.
            gix = simd_sum(gix);
            giy = simd_sum(giy);

            if (thread_index_in_simdgroup == 0) {
                atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
                atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
            }
        """
        kernel = mx.fast.metal_kernel(
            name="grid_sample_grad",
            source=source,
            atomic_outputs=True,
        )
        # Pad the output channels to the simdgroup size
        # so that our `simd_sum`s don't overlap.
        simdgroup_size = 32
        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
        grid_size = B * gN * gM * C_padded
        outputs = kernel(
            inputs={"x": x, "grid": grid, "cotangent": cotangent},
            template={"T": x.dtype},
            output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
            output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
            grid=(grid_size, 1, 1),
            threadgroup=(256, 1, 1),
            init_value=0,
        )
        return outputs["x_grad"], outputs["grid_grad"]
There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``
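With the vjp registered, MLX's function transformations use it automatically; a minimal sketch (shapes assumed) of differentiating through ``grid_sample``:

.. code-block:: python

    def loss(x, grid):
        return grid_sample(x, grid).sum()

    dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)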

mlx/backend/metal/custom_kernel.cpp

Lines changed: 7 additions & 2 deletions
@@ -12,12 +12,17 @@ void CustomKernel::eval_gpu(
     std::vector<array>& outputs) {
   auto& s = stream();

+  std::vector<array> copies;
+
   for (auto& out : outputs) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    if (init_value_) {
+      array init = array(init_value_.value(), out.dtype());
+      copy_gpu(init, out, CopyType::Scalar, s);
+      copies.push_back(init);
+    }
   }

-  std::vector<array> copies;
-
   auto check_input = [&copies, &s, this](const array& x) -> const array {
     bool no_copy = x.flags().row_contiguous;
     if (!ensure_row_contiguous_ || no_copy) {

mlx/fast.cpp

Lines changed: 13 additions & 3 deletions
@@ -949,6 +949,7 @@ void write_signature(
     std::map<std::string, Dtype>& output_dtypes,
     std::optional<std::map<std::string, TemplateArg>> template_args,
     std::vector<CustomKernelShapeInfo>& shape_infos,
+    bool atomic_outputs,
     std::ostringstream& kernel_source) {
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
@@ -1042,8 +1043,14 @@
   }
   // Add outputs
   for (const auto& [name, dtype] : output_dtypes) {
-    kernel_source << " device " << get_type_string(dtype) << "* " << name
-                  << " [[buffer(" << index << ")]]";
+    kernel_source << " device ";
+    auto type_string = get_type_string(dtype);
+    if (atomic_outputs) {
+      kernel_source << "atomic<" << type_string << ">";
+    } else {
+      kernel_source << type_string;
+    }
+    kernel_source << "* " << name << " [[buffer(" << index << ")]]";
     if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
       kernel_source << "," << std::endl;
     } else {
@@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
     std::tuple<int, int, int> grid,
     std::tuple<int, int, int> threadgroup,
     std::optional<std::map<std::string, TemplateArg>> template_args,
+    std::optional<float> init_value,
     bool verbose,
     StreamOrDevice s_) {
   validate_output_shapes(output_shapes, output_dtypes);
@@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
       output_dtypes,
       template_args,
       shape_infos,
+      atomic_outputs_,
       kernel_source);

   if (needs_template) {
@@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
           grid,
           threadgroup,
           shape_infos,
-          ensure_row_contiguous_),
+          ensure_row_contiguous_,
+          init_value),
       in_arrs);

   int i = 0;

0 commit comments