|
497 | 497 | "outputs": [], |
498 | 498 | "source": [ |
499 | 499 | "states = jnp.concatenate([state_init[:, None], states], axis=1)\n", |
500 | | - "frames = jax.vmap(ca.render)(states)\n", |
501 | | - "frames_final = jax.vmap(ca.render)(state_final)\n", |
502 | | - "frames_final_rgba = jax.vmap(ca.render_rgba)(state_final)" |
| 500 | + "frames = nnx.vmap(\n", |
| 501 | + "\tlambda ca, state: ca.render(state),\n", |
| 502 | + "\tin_axes=(None, 0),\n", |
| 503 | + ")(ca, states)\n", |
| 504 | + "\n", |
| 505 | + "frames_final = nnx.vmap(\n", |
| 506 | + "\tlambda ca, state: ca.render(state),\n", |
| 507 | + "\tin_axes=(None, 0),\n", |
| 508 | + ")(ca, state_final)\n", |
| 509 | + "\n", |
| 510 | + "frames_final_rgba = nnx.vmap(\n", |
| 511 | + "\tlambda ca, state: ca.render_rgba(state),\n", |
| 512 | + "\tin_axes=(None, 0),\n", |
| 513 | + ")(ca, state_final)" |
503 | 514 | ] |
504 | 515 | }, |
505 | 516 | { |
|
582 | 593 | "\n", |
583 | 594 | "# Visualize\n", |
584 | 595 | "states = jnp.concatenate([state_init[None], states])\n", |
585 | | - "frames = jax.vmap(ca.render)(states)\n", |
| 596 | + "frames = nnx.vmap(\n", |
| 597 | + "\tlambda ca, state: ca.render(state),\n", |
| 598 | + "\tin_axes=(None, 0),\n", |
| 599 | + ")(ca, states)\n", |
586 | 600 | "\n", |
587 | 601 | "mediapy.show_video(frames, width=128, height=128, codec=\"gif\")" |
588 | 602 | ] |
|
625 | 639 | "\n", |
626 | 640 | "# Visualize\n", |
627 | 641 | "states = jnp.concatenate([state_init[None], states])\n", |
628 | | - "frames = jax.vmap(ca.render)(states)\n", |
| 642 | + "frames = nnx.vmap(\n", |
| 643 | + "\tlambda ca, state: ca.render(state),\n", |
| 644 | + "\tin_axes=(None, 0),\n", |
| 645 | + ")(ca, states)\n", |
629 | 646 | "\n", |
630 | 647 | "mediapy.show_video(frames, width=128, height=128, codec=\"gif\")" |
631 | 648 | ] |
|
685 | 702 | } |
686 | 703 | ], |
687 | 704 | "source": [ |
688 | | - "frames = jax.vmap(ca.render_rgba)(final_states)\n", |
| 705 | + "frames = nnx.vmap(\n", |
| 706 | + "\tlambda ca, state: ca.render_rgba(state),\n", |
| 707 | + "\tin_axes=(None, 0),\n", |
| 708 | + ")(ca, final_states)\n", |
689 | 709 | "\n", |
690 | 710 | "mediapy.show_images(frames, width=128, height=128)" |
691 | 711 | ] |
|
0 commit comments