Skip to content

Commit 9c425e3

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
[pmap] Cache computing effective mesh devices.
Clean up / combine logic for computing global axis size / mesh devices. Improving the `jax.jit(jax.shard_map)` implementation of `jax.pmap`. PiperOrigin-RevId: 858891987
1 parent b1760ad commit 9c425e3

File tree

5 files changed

+469
-59
lines changed

5 files changed

+469
-59
lines changed

jax/_src/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1095,16 +1095,21 @@ pytype_strict_library(
10951095
srcs = ["pmap.py"],
10961096
deps = [
10971097
":api",
1098+
":api_util",
1099+
":config",
10981100
":core",
1101+
":dtypes",
10991102
":lax",
11001103
":mesh",
1104+
":random",
11011105
":shard_map",
1106+
":sharding_impls",
11021107
":stages",
11031108
":traceback_util",
11041109
":tree_util",
11051110
":util",
11061111
":xla_bridge",
1107-
],
1112+
] + py_deps("numpy"),
11081113
)
11091114

11101115
pytype_strict_library(

0 commit comments

Comments
 (0)