Skip to content

Commit 9d523a2

Browse files
committed
Fix flax error for nnx.vmap
1 parent 5954188 commit 9d523a2

19 files changed

Lines changed: 816 additions & 67 deletions

examples/00_getting_started.ipynb

Lines changed: 62 additions & 19 deletions
Large diffs are not rendered by default.

examples/11_life.ipynb

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@
187187
],
188188
"source": [
189189
"states = jnp.concatenate([state_init[None], states])\n",
190-
"frames = jax.vmap(ca.render)(states)\n",
190+
"frames = nnx.vmap(\n",
191+
"\tlambda ca, state: ca.render(state),\n",
192+
"\tin_axes=(None, 0),\n",
193+
")(ca, states)\n",
191194
"\n",
192195
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
193196
]
@@ -282,7 +285,10 @@
282285
],
283286
"source": [
284287
"states = jnp.concatenate([state_init[None], states])\n",
285-
"frames = jax.vmap(ca.render)(states)\n",
288+
"frames = nnx.vmap(\n",
289+
"\tlambda ca, state: ca.render(state),\n",
290+
"\tin_axes=(None, 0),\n",
291+
")(ca, states)\n",
286292
"\n",
287293
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
288294
]

examples/20_lenia.ipynb

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,10 @@
292292
],
293293
"source": [
294294
"states = jnp.concatenate([state_init[None], states])\n",
295-
"frames = jax.vmap(ca.render)(states)\n",
295+
"frames = nnx.vmap(\n",
296+
"\tlambda ca, state: ca.render(state),\n",
297+
"\tin_axes=(None, 0),\n",
298+
")(ca, states)\n",
296299
"\n",
297300
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
298301
]
@@ -813,7 +816,10 @@
813816
],
814817
"source": [
815818
"states = jnp.concatenate([state_init[None], states])\n",
816-
"frames = jax.vmap(ca.render)(states)\n",
819+
"frames = nnx.vmap(\n",
820+
"\tlambda ca, state: ca.render(state),\n",
821+
"\tin_axes=(None, 0),\n",
822+
")(ca, states)\n",
817823
"\n",
818824
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
819825
]
@@ -1023,7 +1029,10 @@
10231029
],
10241030
"source": [
10251031
"states = jnp.concatenate([state_init[None], states])\n",
1026-
"frames = jax.vmap(ca.render)(states)\n",
1032+
"frames = nnx.vmap(\n",
1033+
"\tlambda ca, state: ca.render(state),\n",
1034+
"\tin_axes=(None, 0),\n",
1035+
")(ca, states)\n",
10271036
"\n",
10281037
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
10291038
]
@@ -1246,7 +1255,10 @@
12461255
],
12471256
"source": [
12481257
"states = jnp.concatenate([state_init[None], states])\n",
1249-
"frames = jax.vmap(ca.render)(states)\n",
1258+
"frames = nnx.vmap(\n",
1259+
"\tlambda ca, state: ca.render(state),\n",
1260+
"\tin_axes=(None, 0),\n",
1261+
")(ca, states)\n",
12501262
"\n",
12511263
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
12521264
]

examples/21_flow_lenia.ipynb

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,10 @@
287287
],
288288
"source": [
289289
"states = jnp.concatenate([state_init[None], states])\n",
290-
"frames = jax.vmap(ca.render)(states)\n",
290+
"frames = nnx.vmap(\n",
291+
"\tlambda ca, state: ca.render(state),\n",
292+
"\tin_axes=(None, 0),\n",
293+
")(ca, states)\n",
291294
"\n",
292295
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
293296
]
@@ -481,7 +484,10 @@
481484
],
482485
"source": [
483486
"states = jnp.concatenate([state_init[None], states])\n",
484-
"frames = jax.vmap(ca.render)(states)\n",
487+
"frames = nnx.vmap(\n",
488+
"\tlambda ca, state: ca.render(state),\n",
489+
"\tin_axes=(None, 0),\n",
490+
")(ca, states)\n",
485491
"\n",
486492
"mediapy.show_video(frames, width=256, height=256)"
487493
]
@@ -701,7 +707,10 @@
701707
],
702708
"source": [
703709
"states = jnp.concatenate([state_init[None], states])\n",
704-
"frames = jax.vmap(ca.render)(states)\n",
710+
"frames = nnx.vmap(\n",
711+
"\tlambda ca, state: ca.render(state),\n",
712+
"\tin_axes=(None, 0),\n",
713+
")(ca, states)\n",
705714
"\n",
706715
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
707716
]
@@ -921,7 +930,10 @@
921930
],
922931
"source": [
923932
"states = jnp.concatenate([state_init[None], states])\n",
924-
"frames = jax.vmap(ca.render)(states)\n",
933+
"frames = nnx.vmap(\n",
934+
"\tlambda ca, state: ca.render(state),\n",
935+
"\tin_axes=(None, 0),\n",
936+
")(ca, states)\n",
925937
"\n",
926938
"mediapy.show_video(frames, width=256, height=256, codec=\"gif\")"
927939
]

examples/22_particle_lenia.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@
230230
],
231231
"source": [
232232
"states = jnp.concatenate([state_init[None], states])\n",
233-
"frames = jax.vmap(lambda x: ca.render(x, resolution=512, particle_radius=0.3))(states)\n",
233+
"frames = nnx.vmap(\n",
234+
"\tlambda ca, state: ca.render(state, resolution=512, particle_radius=0.3),\n",
235+
"\tin_axes=(None, 0),\n",
236+
")(ca, states)\n",
234237
"\n",
235238
"mediapy.show_video(frames, width=256, height=256, fps=600)"
236239
]
@@ -257,7 +260,7 @@
257260
"name": "python",
258261
"nbconvert_exporter": "python",
259262
"pygments_lexer": "ipython3",
260-
"version": "3.12.8"
263+
"version": "3.13.3"
261264
}
262265
},
263266
"nbformat": 4,

examples/30_particle_life.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@
209209
],
210210
"source": [
211211
"states = jax.tree.map(lambda x, xs: jnp.concatenate([x[None], xs]), state_init, states)\n",
212-
"frames = jax.vmap(lambda state: ca.render(state, particle_radius=0.003))(states)\n",
212+
"frames = nnx.vmap(\n",
213+
"\tlambda ca, state: ca.render(state, particle_radius=0.003),\n",
214+
"\tin_axes=(None, 0),\n",
215+
")(ca, states)\n",
213216
"\n",
214217
"mediapy.show_video(frames, width=512, height=512, fps=int(1 / dt))"
215218
]
@@ -231,7 +234,7 @@
231234
"name": "python",
232235
"nbconvert_exporter": "python",
233236
"pygments_lexer": "ipython3",
234-
"version": "3.12.8"
237+
"version": "3.13.3"
235238
}
236239
},
237240
"nbformat": 4,

examples/31_boids.ipynb

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,10 @@
211211
],
212212
"source": [
213213
"states = jax.tree.map(lambda x, xs: jnp.concatenate([x[None], xs]), state_init, states)\n",
214-
"frames = jax.vmap(lambda state: ca.render(state, boids_size=0.01))(states)\n",
214+
"frames = nnx.vmap(\n",
215+
"\tlambda ca, state: ca.render(state, boids_size=0.01),\n",
216+
"\tin_axes=(None, 0),\n",
217+
")(ca, states)\n",
215218
"\n",
216219
"mediapy.show_video(frames, width=512, height=512, fps=int(1 / dt))"
217220
]
@@ -447,7 +450,10 @@
447450
],
448451
"source": [
449452
"states = jax.tree.map(lambda x, xs: jnp.concatenate([x[None], xs]), state_init, states)\n",
450-
"frames = jax.vmap(lambda state: ca.render(state, boids_size=0.01))(states)\n",
453+
"frames = nnx.vmap(\n",
454+
"\tlambda ca, state: ca.render(state, boids_size=0.01),\n",
455+
"\tin_axes=(None, 0),\n",
456+
")(ca, states)\n",
451457
"\n",
452458
"mediapy.show_video(frames, width=512, height=512, fps=int(1 / dt))"
453459
]
@@ -469,7 +475,7 @@
469475
"name": "python",
470476
"nbconvert_exporter": "python",
471477
"pygments_lexer": "ipython3",
472-
"version": "3.12.8"
478+
"version": "3.13.3"
473479
}
474480
},
475481
"nbformat": 4,

examples/40_growing_nca.ipynb

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,15 @@
492492
}
493493
],
494494
"source": [
495-
"frames_final = jax.vmap(ca.render)(state_final)\n",
496-
"frames_final_rgba = jax.vmap(ca.render_rgba)(state_final)\n",
495+
"frames_final = nnx.vmap(\n",
496+
"\tlambda ca, state: ca.render(state),\n",
497+
"\tin_axes=(None, 0),\n",
498+
")(ca, state_final)\n",
499+
"\n",
500+
"frames_final_rgba = nnx.vmap(\n",
501+
"\tlambda ca, state: ca.render_rgba(state),\n",
502+
"\tin_axes=(None, 0),\n",
503+
")(ca, state_final)\n",
497504
"\n",
498505
"mediapy.show_images(frames_final, width=128, height=128)\n",
499506
"mediapy.show_images(frames_final_rgba, width=128, height=128)"
@@ -519,7 +526,10 @@
519526
],
520527
"source": [
521528
"states = jnp.concatenate([state_init[:, None], states], axis=1)\n",
522-
"frames = jax.vmap(ca.render)(states)\n",
529+
"frames = nnx.vmap(\n",
530+
"\tlambda ca, state: ca.render(state),\n",
531+
"\tin_axes=(None, 0),\n",
532+
")(ca, states)\n",
523533
"\n",
524534
"mediapy.show_videos(frames, width=128, height=128, codec=\"gif\")"
525535
]
@@ -541,7 +551,7 @@
541551
"name": "python",
542552
"nbconvert_exporter": "python",
543553
"pygments_lexer": "ipython3",
544-
"version": "3.12.8"
554+
"version": "3.13.3"
545555
}
546556
},
547557
"nbformat": 4,

examples/41_growing_conditional_nca.ipynb

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,20 @@
497497
"outputs": [],
498498
"source": [
499499
"states = jnp.concatenate([state_init[:, None], states], axis=1)\n",
500-
"frames = jax.vmap(ca.render)(states)\n",
501-
"frames_final = jax.vmap(ca.render)(state_final)\n",
502-
"frames_final_rgba = jax.vmap(ca.render_rgba)(state_final)"
500+
"frames = nnx.vmap(\n",
501+
"\tlambda ca, state: ca.render(state),\n",
502+
"\tin_axes=(None, 0),\n",
503+
")(ca, states)\n",
504+
"\n",
505+
"frames_final = nnx.vmap(\n",
506+
"\tlambda ca, state: ca.render(state),\n",
507+
"\tin_axes=(None, 0),\n",
508+
")(ca, state_final)\n",
509+
"\n",
510+
"frames_final_rgba = nnx.vmap(\n",
511+
"\tlambda ca, state: ca.render_rgba(state),\n",
512+
"\tin_axes=(None, 0),\n",
513+
")(ca, state_final)"
503514
]
504515
},
505516
{
@@ -582,7 +593,10 @@
582593
"\n",
583594
"# Visualize\n",
584595
"states = jnp.concatenate([state_init[None], states])\n",
585-
"frames = jax.vmap(ca.render)(states)\n",
596+
"frames = nnx.vmap(\n",
597+
"\tlambda ca, state: ca.render(state),\n",
598+
"\tin_axes=(None, 0),\n",
599+
")(ca, states)\n",
586600
"\n",
587601
"mediapy.show_video(frames, width=128, height=128, codec=\"gif\")"
588602
]
@@ -625,7 +639,10 @@
625639
"\n",
626640
"# Visualize\n",
627641
"states = jnp.concatenate([state_init[None], states])\n",
628-
"frames = jax.vmap(ca.render)(states)\n",
642+
"frames = nnx.vmap(\n",
643+
"\tlambda ca, state: ca.render(state),\n",
644+
"\tin_axes=(None, 0),\n",
645+
")(ca, states)\n",
629646
"\n",
630647
"mediapy.show_video(frames, width=128, height=128, codec=\"gif\")"
631648
]
@@ -685,7 +702,10 @@
685702
}
686703
],
687704
"source": [
688-
"frames = jax.vmap(ca.render_rgba)(final_states)\n",
705+
"frames = nnx.vmap(\n",
706+
"\tlambda ca, state: ca.render_rgba(state),\n",
707+
"\tin_axes=(None, 0),\n",
708+
")(ca, final_states)\n",
689709
"\n",
690710
"mediapy.show_images(frames, width=128, height=128)"
691711
]

examples/42_growing_unsupervised_nca.ipynb

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,10 @@
517517
],
518518
"source": [
519519
"states = jnp.concatenate([state_init[:, None], states], axis=1)\n",
520-
"frames = jax.vmap(ca.render)(states)\n",
520+
"frames = nnx.vmap(\n",
521+
"\tlambda ca, state: ca.render(state),\n",
522+
"\tin_axes=(None, 0),\n",
523+
")(ca, states)\n",
521524
"\n",
522525
"mediapy.show_images(y, width=128, height=128)\n",
523526
"mediapy.show_videos(frames, width=128, height=128, codec=\"gif\")"
@@ -583,7 +586,10 @@
583586
}
584587
],
585588
"source": [
586-
"frames = jax.vmap(ca.render)(state_final)\n",
589+
"frames = nnx.vmap(\n",
590+
"\tlambda ca, state: ca.render(state),\n",
591+
"\tin_axes=(None, 0),\n",
592+
")(ca, state_final)\n",
587593
"\n",
588594
"mediapy.show_images(frames, width=128, height=128)"
589595
]
@@ -605,7 +611,7 @@
605611
"name": "python",
606612
"nbconvert_exporter": "python",
607613
"pygments_lexer": "ipython3",
608-
"version": "3.12.8"
614+
"version": "3.13.3"
609615
}
610616
},
611617
"nbformat": 4,

0 commit comments

Comments
 (0)