Skip to content

Commit dc9c815

Browse files
danielsuo authored and KfacJaxDev committed
[kfac_jax] Prepare for jax_pmap_shmap_merge=True.
Condition on `jax_pmap_shmap_merge` to explicitly grab the first shard rather than rely on `x[0]`. See https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays for more information.

NOTE: `kfac_jax.utils.index_if_not_scalar` has been removed and inlined into `kfac_jax.utils.get_first`.
- If you need to grab the first shard of a semantically replicated array (i.e., each shard is really a replica), use `kfac_jax.utils.get_first`.
- If you really are trying to index into an array, use the usual array indexing operators (i.e., `x[0]`).

PiperOrigin-RevId: 846329906
1 parent c0d25ad commit dc9c815

File tree

3 files changed

+25
-18
lines changed

3 files changed

+25
-18
lines changed

examples/training.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,8 @@ def format_path_entry(entry: Any) -> str:
520520
logging.info("%s %s %s", "=" * 20, "Parameters", "=" * 20)
521521
for path, var in jax.tree_util.tree_flatten_with_path(self._params)[0]:
522522
# Because of pmap
523-
var = var[0]
523+
var = kfac_jax.utils.get_first(var)
524+
assert isinstance(var, Array) # params are always arrays
524525
logging.info(
525526
"%s - %s, %s",
526527
"-".join(format_path_entry(p) for p in path),
@@ -541,7 +542,8 @@ def format_path_entry(entry: Any) -> str:
541542
# For __class__ entries
542543
continue
543544
# Because of pmap
544-
var = var[0]
545+
var = kfac_jax.utils.get_first(var)
546+
assert isinstance(var, Array) # optimizer state entries are always arrays
545547
logging.info(
546548
"%s - %s, %s",
547549
"/".join(format_path_entry(p) for p in path),
@@ -671,7 +673,10 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
671673
for i in range(gathered_stat.shape[0]):
672674
stats[f"{name}_{i}"] = jnp.array([gathered_stat[i]])
673675

674-
stats = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), stats)
676+
if jax.config.jax_pmap_shmap_merge:
677+
stats = kfac_jax.utils.get_first(stats)
678+
else:
679+
stats = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), stats)
675680

676681
self._python_step += 1
677682
stats["progress"] = self.progress(self._python_step)

kfac_jax/_src/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
psum_if_pmap = parallel.psum_if_pmap
7474
pmap_mean = parallel.pmap_mean
7575
pmap_sum = parallel.pmap_sum
76-
index_if_not_scalar = parallel.index_if_not_scalar
7776
get_first = parallel.get_first
7877
get_mean = parallel.get_mean
7978
get_sum = parallel.get_sum

kfac_jax/_src/utils/parallel.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,27 @@ def psum_if_pmap(obj: TArrayTree, axis_name: str | None) -> TArrayTree:
8282
pmap_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")
8383

8484

85-
def index_if_not_scalar(value: Numeric, index: int = 0) -> Numeric:
86-
"""Index `value` at axis 0 if it is not a scalar, otherwise return it."""
87-
88-
if isinstance(value, Array):
85+
def get_first(obj: TArrayTree) -> TArrayTree:
86+
"""Index the PyTree leaves `x` of `obj` by `x[0]` if they are not scalars."""
8987

90-
if value.ndim > 0:
91-
return value[index]
92-
else:
93-
return value
88+
def _get_first(value: Numeric) -> Numeric:
89+
if isinstance(value, Array):
9490

95-
elif isinstance(value, (float, int)):
96-
return value
97-
raise ValueError("The input should be an instance of `Numeric`.")
91+
if value.ndim > 0:
92+
if jax.config.jax_pmap_shmap_merge:
93+
shard_data = value.addressable_shards[0].data
94+
if not value.sharding.is_fully_replicated:
95+
return shard_data.squeeze(0)
96+
return shard_data
97+
return value[0]
98+
else:
99+
return value
98100

101+
elif isinstance(value, (float, int)):
102+
return value
103+
raise ValueError("The input should be an instance of `Numeric`.")
99104

100-
def get_first(obj: TArrayTree) -> TArrayTree:
101-
"""Index the PyTree leaves `x` of `obj` by `x[0]` if they are not scalars."""
102-
return jax.tree_util.tree_map(index_if_not_scalar, obj)
105+
return jax.tree_util.tree_map(_get_first, obj)
103106

104107

105108
def get_mean(obj: TArrayTree) -> TArrayTree:

0 commit comments

Comments
 (0)