Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`
* add grid sample to docs
* zero_outputs -> init_value
* add missing header for linux
          out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample",
          source=source,
      )
      outputs = kernel(
          inputs={"x": x, "grid": grid},
          template={"T": x.dtype},
          output_shapes={"out": out_shape},
          output_dtypes={"out": x.dtype},
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs["out"]

For a reasonably sized input such as:

.. code-block:: python

  x.shape = (8, 1024, 1024, 64)
  grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``
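
To make the launch configuration concrete, the sketch below works out how many threads and threadgroups the ``kernel(...)`` call dispatches for these shapes, and checks the quoted speed-up. It assumes the output shape is ``(B, gN, gM, C)``, that is, one output element per grid point and channel; that definition is not shown in this excerpt, and the variable names here are illustrative only.

```python
import math

# Shapes from the example above.
x_shape = (8, 1024, 1024, 64)   # (B, H, W, C)
grid_shape = (8, 256, 256, 2)   # (B, gN, gM, 2)

# Assumption: the kernel writes one value per (batch, grid row, grid col, channel).
B, _, _, C = x_shape
_, gN, gM, _ = grid_shape
out_shape = (B, gN, gM, C)

total_threads = math.prod(out_shape)   # mirrors grid=(np.prod(out_shape), 1, 1)
threads_per_group = 256                # mirrors threadgroup=(256, 1, 1)
num_threadgroups = math.ceil(total_threads / threads_per_group)

speedup = 55.7 / 6.7                   # timings quoted above, in milliseconds

print(f"threads launched: {total_threads:,}")
print(f"threadgroups:     {num_threadgroups:,}")
print(f"speed-up:         {speedup:.1f}x")
```

Under that assumption each thread produces a single output element, so the thread count grows with the grid resolution and channel count rather than with the input image size.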

Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
    Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
    Designate all of the kernel outputs as ``atomic`` in the function signature.
    This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
    See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
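
As a rough CPU analogue of why both options are needed, the sketch below implements the backward pass of a 1D linear interpolation in plain Python. It is not the Metal kernel, and ``linear_sample`` and ``linear_sample_vjp`` are hypothetical names used only for illustration. The gradient buffer starts at zero (the role played by ``init_value=0``), and several samples add into the same input index, which is the accumulation that has to be atomic once each sample is handled by a different GPU thread.

```python
import math

def linear_sample(x, positions):
    """Forward pass: linearly interpolate x at fractional positions."""
    out = []
    for p in positions:
        i0 = int(math.floor(p))
        i1 = min(i0 + 1, len(x) - 1)
        w1 = p - i0
        out.append((1.0 - w1) * x[i0] + w1 * x[i1])
    return out

def linear_sample_vjp(x, positions, cotangent):
    """Backward pass w.r.t. x: scatter-add each sample's contribution."""
    x_grad = [0.0] * len(x)  # zero-initialized, like init_value=0
    for p, g in zip(positions, cotangent):
        i0 = int(math.floor(p))
        i1 = min(i0 + 1, len(x) - 1)
        w1 = p - i0
        # Several samples can hit the same index, so contributions accumulate.
        # On a GPU these updates come from different threads, hence atomics.
        x_grad[i0] += (1.0 - w1) * g
        x_grad[i1] += w1 * g
    return x_grad

x = [1.0, 2.0, 4.0, 8.0]
positions = [0.25, 0.5, 0.75, 2.5]  # the first three all touch x[0] and x[1]
cotangent = [1.0, 1.0, 1.0, 1.0]

x_grad = linear_sample_vjp(x, positions, cotangent)
print(x_grad)
```

If ``x_grad`` were not zeroed first, entries that no sample touches would hold arbitrary values; if the ``+=`` updates raced without atomics, contributions to shared indices could be lost.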

We can then implement the backwards pass as follows:

.. code-block:: python

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          // Pad C to the nearest larger simdgroup size multiple
          int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;