|
303 | 303 | " # type (which corresponds to numpy's `float32` type), and it must be a\n", |
304 | 304 | " # static parameter (i.e. not a JAX array).\n", |
305 | 305 | " eps=np.float32(eps),\n", |
306 | | - " # The `vectorized` parameter controls this function's behavior under `vmap`\n", |
| 306 | + " # The `vmap_method` parameter controls this function's behavior under `vmap`\n", |
307 | 307 | " # as discussed below.\n", |
308 | | - " vectorized=True,\n", |
| 308 | + " vmap_method=\"broadcast_fullrank\",\n", |
309 | 309 | " )\n", |
310 | 310 | "\n", |
311 | 311 | "\n", |
|
325 | 325 | "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", |
326 | 326 | "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", |
327 | 327 | "\n", |
328 | | - "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", |
| 328 | + "The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", |
329 | 329 | "\n", |
330 | 330 | "```{tip}\n", |
331 | 331 | "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n", |
|
336 | 336 | "(ffi-call-vmap)=\n", |
337 | 337 | "### Batching with `vmap`\n", |
338 | 338 | "\n", |
339 | | - "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n", |
340 | | - "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", |
341 | | - "This default implementation is general purpose, but it doesn't parallelize very well.\n", |
342 | | - "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n", |
| 339 | + "{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n", |
| 340 | + "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n", |
343 | 341 | "\n", |
344 | | - "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n", |
| 342 | + "The simplest `vmap_method` is `\"sequential\"`.\n", |
| 343 | + "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", |
| 344 | + "This implementation is general purpose, but it doesn't parallelize very well.\n", |
| 345 | + "Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n", |
| 346 | + "\n", |
| 347 | + "In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n", |
| 348 | + "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n", |
345 | 349 | "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n", |
346 | 350 | "\n", |
347 | 351 | "```python\n", |
348 | 352 | "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n", |
349 | 353 | "```\n", |
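| | + "\n", |
| | + "Since actually running this `ffi_call` requires the compiled extension library, the following is a self-contained sketch of the same contract using {func}`~jax.pure_callback`, which accepts the same `vmap_method` parameter; the NumPy `host_rms_norm` below is only an illustrative stand-in for the foreign function:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import jax\n", |
| | + "import jax.numpy as jnp\n", |
| | + "import numpy as np\n", |
| | + "\n", |
| | + "def host_rms_norm(x, eps=1e-5):\n", |
| | + "    # Stand-in for the foreign function; leading axes are batch axes.\n", |
| | + "    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)\n", |
| | + "\n", |
| | + "def rms_norm_callback(x):\n", |
| | + "    return jax.pure_callback(\n", |
| | + "        host_rms_norm, jax.ShapeDtypeStruct(x.shape, x.dtype), x,\n", |
| | + "        vmap_method=\"broadcast_fullrank\")\n", |
| | + "\n", |
| | + "xs = jnp.arange(15.0).reshape(3, 5)\n", |
| | + "assert jnp.allclose(jax.vmap(rms_norm_callback)(xs),\n", |
| | + "                    jnp.stack([rms_norm_callback(x) for x in xs]))\n", |
| | + "```\n", |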
350 | 354 | "\n", |
351 | | - "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:" |
| 355 | + "```{tip}\n", |
| 356 | + "Note that things get a bit more complicated when we have multiple input arguments.\n", |
| 357 | + "For simplicity, we will use the `\"broadcast_fullrank\"` method throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n", |
| 358 | + "The documentation for {func}`~jax.pure_callback` includes some examples of this, and a small sketch follows this tip.\n", |
| 359 | + "```\n", |
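| | + "\n", |
| | + "As a rough sketch of the difference (again using {func}`~jax.pure_callback` as a stand-in, with a hypothetical `host_add` function), the comments below show the shapes we expect the host function to receive under each method:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import jax\n", |
| | + "import jax.numpy as jnp\n", |
| | + "import numpy as np\n", |
| | + "\n", |
| | + "def host_add(x, y):\n", |
| | + "    print(f\"x.shape={x.shape}, y.shape={y.shape}\")\n", |
| | + "    return np.add(x, y)\n", |
| | + "\n", |
| | + "def add_with(vmap_method):\n", |
| | + "    def add(x, y):\n", |
| | + "        return jax.pure_callback(\n", |
| | + "            host_add, jax.ShapeDtypeStruct(x.shape, x.dtype), x, y,\n", |
| | + "            vmap_method=vmap_method)\n", |
| | + "    return add\n", |
| | + "\n", |
| | + "xs = jnp.arange(6.0).reshape(3, 2)\n", |
| | + "y = jnp.ones(2)\n", |
| | + "# \"broadcast_fullrank\": every argument is broadcast to the full batch\n", |
| | + "# shape, so we expect x.shape=(3, 2) and y.shape=(3, 2) on the host.\n", |
| | + "jax.vmap(add_with(\"broadcast_fullrank\"), in_axes=(0, None))(xs, y)\n", |
| | + "# \"broadcast\": unbatched arguments only gain a leading axis of size 1,\n", |
| | + "# so we expect x.shape=(3, 2) and y.shape=(1, 2) on the host.\n", |
| | + "jax.vmap(add_with(\"broadcast\"), in_axes=(0, None))(xs, y)\n", |
| | + "```\n", |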
| 360 | + "\n", |
| 361 | + "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:" |
352 | 362 | ] |
353 | 363 | }, |
354 | 364 | { |
|
380 | 390 | "cell_type": "markdown", |
381 | 391 | "metadata": {}, |
382 | 392 | "source": [ |
383 | | - "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:" |
| 393 | + "Using `vmap_method=\"sequential\"`, `vmap`ping an `ffi_call` will fall back on a {func}`~jax.lax.scan` with the `ffi_call` in the body:" |
384 | 394 | ] |
385 | 395 | }, |
386 | 396 | { |
|
389 | 399 | "metadata": {}, |
390 | 400 | "outputs": [], |
391 | 401 | "source": [ |
392 | | - "def rms_norm_not_vectorized(x, eps=1e-5):\n", |
| 402 | + "def rms_norm_sequential(x, eps=1e-5):\n", |
393 | 403 | " return jex.ffi.ffi_call(\n", |
394 | 404 | " \"rms_norm\",\n", |
395 | 405 | " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", |
396 | 406 | " x,\n", |
397 | 407 | " eps=np.float32(eps),\n", |
398 | | - " vectorized=False, # This is the default behavior\n", |
| 408 | + " vmap_method=\"sequential\",\n", |
399 | 409 | " )\n", |
400 | 410 | "\n", |
401 | 411 | "\n", |
402 | | - "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" |
| 412 | + "jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)" |
403 | 413 | ] |
404 | 414 | }, |
405 | 415 | { |
406 | 416 | "cell_type": "markdown", |
407 | 417 | "metadata": {}, |
408 | 418 | "source": [ |
409 | | - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." |
| 419 | + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).\n", |
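| | + "\n", |
| | + "To give a rough idea of that interface, here is a minimal sketch of registering a custom batching rule with `jax.custom_batching.custom_vmap` for a hypothetical unary `my_func` (an `ffi_call` could take the place of the unbatched implementation):\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import jax\n", |
| | + "import jax.numpy as jnp\n", |
| | + "\n", |
| | + "@jax.custom_batching.custom_vmap\n", |
| | + "def my_func(x):\n", |
| | + "    # Unbatched implementation; an `ffi_call` could live here instead.\n", |
| | + "    return jnp.sin(x)\n", |
| | + "\n", |
| | + "@my_func.def_vmap\n", |
| | + "def my_func_vmap_rule(axis_size, in_batched, x):\n", |
| | + "    # `x` arrives with a leading batch axis when `in_batched[0]` is True;\n", |
| | + "    # return the batched result and which outputs carry a batch axis.\n", |
| | + "    x_batched, = in_batched\n", |
| | + "    return jnp.sin(x), x_batched\n", |
| | + "\n", |
| | + "jax.vmap(my_func)(jnp.arange(4.0))\n", |
| | + "```" |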
410 | 420 | ] |
411 | 421 | }, |
412 | 422 | { |
|
454 | 464 | " ),\n", |
455 | 465 | " x,\n", |
456 | 466 | " eps=np.float32(eps),\n", |
457 | | - " vectorized=True,\n", |
| 467 | + " vmap_method=\"broadcast_fullrank\",\n", |
458 | 468 | " )\n", |
459 | 469 | " return y, (res, x)\n", |
460 | 470 | "\n", |
|
471 | 481 | " res,\n", |
472 | 482 | " x,\n", |
473 | 483 | " ct,\n", |
474 | | - " vectorized=True,\n", |
| 484 | + " vmap_method=\"broadcast_fullrank\",\n", |
475 | 485 | " ),\n", |
476 | 486 | " )\n", |
477 | 487 | "\n", |
|
561 | 571 | " out_type,\n", |
562 | 572 | " x,\n", |
563 | 573 | " eps=np.float32(eps),\n", |
564 | | - " vectorized=True,\n", |
| 574 | + " vmap_method=\"broadcast_fullrank\",\n", |
565 | 575 | " )\n", |
566 | 576 | "\n", |
567 | 577 | " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n", |
|