Commit 4298fae
[pmap] Inline _prepare_pmap and clean up unused structs.
Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`.
PiperOrigin-RevId: 8588895831 parent b1760ad commit 4298fae
File tree
5 files changed
+441
-55
lines changed- jax/_src
- tests
- multiprocess
5 files changed
+441
-55
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1095 | 1095 | | |
1096 | 1096 | | |
1097 | 1097 | | |
| 1098 | + | |
| 1099 | + | |
1098 | 1100 | | |
| 1101 | + | |
1099 | 1102 | | |
1100 | 1103 | | |
| 1104 | + | |
1101 | 1105 | | |
| 1106 | + | |
1102 | 1107 | | |
1103 | 1108 | | |
1104 | 1109 | | |
1105 | 1110 | | |
1106 | 1111 | | |
1107 | | - | |
| 1112 | + | |
1108 | 1113 | | |
1109 | 1114 | | |
1110 | 1115 | | |
| |||
0 commit comments