From 835f0bd7ac77cff9c0d279cf147d66bd238af881 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 1 Apr 2025 16:29:36 +0100 Subject: [PATCH 01/21] modify primitives --- axlearn/common/base_layer.py | 2 +- axlearn/common/base_layer_test.py | 4 ++-- axlearn/common/utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/axlearn/common/base_layer.py b/axlearn/common/base_layer.py index 47234d86f..a09b6597c 100644 --- a/axlearn/common/base_layer.py +++ b/axlearn/common/base_layer.py @@ -532,7 +532,7 @@ def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optiona return FanAxes(in_axis=-2, out_axis=-1) def _remat_name(self, x: Tensor, name: str) -> Tensor: - """Tags 'x' with 'name' using a custom jax.core.Primitive, which is otherwise a no-op. + """Tags 'x' with 'name' using a custom jax.extend.core.Primitive, which is otherwise a no-op. This is useful for custom activation rematerialization policies, as it allows us to filter on tagged points in the jaxpr. diff --git a/axlearn/common/base_layer_test.py b/axlearn/common/base_layer_test.py index 5045061c0..991416f60 100644 --- a/axlearn/common/base_layer_test.py +++ b/axlearn/common/base_layer_test.py @@ -126,7 +126,7 @@ def backward_impl(x): backward() return (x,) - prim = jax.core.Primitive("passthrough_with_callback") + prim = jax.extend.core.Primitive("passthrough_with_callback") prim.def_impl(forward_impl) prim.def_abstract_eval(forward_impl) jax.interpreters.ad.deflinear(prim, backward_impl) @@ -302,7 +302,7 @@ def test_remat_name(self): tagged_params = [el for el in jaxpr.eqns if "name" in el.params] self.assertEqual(len(tagged_params), 1) tagged_param = tagged_params.pop() - self.assertIsInstance(tagged_param.primitive, jax.core.Primitive) + self.assertIsInstance(tagged_param.primitive, jax.extend.core.Primitive) self.assertEqual(tagged_param.primitive.name, "name") self.assertEqual(f"{type(test_module).__name__}.{var_tag}", tagged_param.params.get("name")) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index c60b7bc93..28fd9d5d6 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -49,7 +49,7 @@ from jax._src.mesh import thread_resources from jax._src.tree_util import KeyEntry, KeyPath from jax.ad_checkpoint import Offloadable, Recompute, Saveable -from jax.core import Primitive +from jax.extend.core import Primitive from jax.experimental import mesh_utils, multihost_utils from jax.sharding import PartitionSpec From dd380a4302e91f2088d8ba58d8cde1b69402de8b Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 1 Apr 2025 18:46:17 +0100 Subject: [PATCH 02/21] fix pre-commit error --- axlearn/common/base_layer.py | 3 ++- axlearn/common/utils.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/axlearn/common/base_layer.py b/axlearn/common/base_layer.py index a09b6597c..b721595a4 100644 --- a/axlearn/common/base_layer.py +++ b/axlearn/common/base_layer.py @@ -532,7 +532,8 @@ def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optiona return FanAxes(in_axis=-2, out_axis=-1) def _remat_name(self, x: Tensor, name: str) -> Tensor: - """Tags 'x' with 'name' using a custom jax.extend.core.Primitive, which is otherwise a no-op. + """Tags 'x' with 'name' using a custom jax.extend.core.Primitive, which + is otherwise a no-op. This is useful for custom activation rematerialization policies, as it allows us to filter on tagged points in the jaxpr. 
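For context on the tagging mechanism this docstring describes: upstream JAX exposes an analogous public API built on the same name primitive, and a remat policy can then filter tagged values by name. A minimal sketch using only public JAX calls (the layer shapes and the "hidden" tag are made-up illustrations, not part of this patch):

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def mlp(params, x):
        h = jnp.tanh(x @ params["w1"])
        # Tag the intermediate so a rematerialization policy can refer to it by name.
        h = checkpoint_name(h, "hidden")
        return jnp.sum(h @ params["w2"])

    # Save only the tagged activation; everything else is recomputed on the backward pass.
    policy = jax.checkpoint_policies.save_only_these_names("hidden")
    loss = jax.checkpoint(mlp, policy=policy)

    params = {"w1": jnp.ones((4, 8)), "w2": jnp.ones((8, 2))}
    grads = jax.grad(loss)(params, jnp.ones((3, 4)))

AXLearn's _remat_name plays the same role as checkpoint_name here, which is why save_and_offload_only_these_names_regex in utils.py can match tagged points in the jaxpr by regex.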
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 28fd9d5d6..b637ef013 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -49,8 +49,8 @@ from jax._src.mesh import thread_resources from jax._src.tree_util import KeyEntry, KeyPath from jax.ad_checkpoint import Offloadable, Recompute, Saveable -from jax.extend.core import Primitive from jax.experimental import mesh_utils, multihost_utils +from jax.extend.core import Primitive from jax.sharding import PartitionSpec from axlearn.common import serialization @@ -148,8 +148,7 @@ def sharding(self) -> jax.sharding.Sharding: class RematPolicy(Protocol): - def __call__(self, prim: Primitive, *args: Any, **params: Any) -> Union[RematType, bool]: - ... + def __call__(self, prim: Primitive, *args: Any, **params: Any) -> Union[RematType, bool]: ... def save_and_offload_only_these_names_regex( @@ -1159,7 +1158,7 @@ def per_param_dtype_by_path( """ def fn( - tree: Union[Nested[Tensor], Nested[TensorSpec]] + tree: Union[Nested[Tensor], Nested[TensorSpec]], ) -> Union[Nested[Tensor], Nested[TensorSpec]]: if update_rules is None: return jax.tree.map(lambda x: default_dtype, tree_paths(tree)) @@ -1214,7 +1213,7 @@ def cast_per_param( def canonicalize_per_param_dtype( - param_dtype: Union[jnp.dtype, ConfigOr[PerParamFn[jnp.dtype]]] + param_dtype: Union[jnp.dtype, ConfigOr[PerParamFn[jnp.dtype]]], ) -> ConfigOr[PerParamFn[jnp.dtype]]: """Canonicalize the input `param_dtype` to a consistent format of `ConfigOr[PerParamFn[jnp.dtype]]`, which handles three possible cases: From 5f1640df70d4bc77f9c142c6f148f67f4f933167 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 3 Apr 2025 14:08:30 +0100 Subject: [PATCH 03/21] update jax.tree_map to jax.tree_util.tree_map --- axlearn/common/adapter_flax_test.py | 4 ++-- axlearn/common/array_serialization.py | 8 +++---- axlearn/common/decoding_test.py | 11 ++++----- axlearn/common/gradient_accumulation.py | 6 +++-- axlearn/common/gradient_accumulation_test.py | 2 +- axlearn/common/input_base.py | 2 +- axlearn/common/input_base_test.py | 2 +- axlearn/common/learner_test.py | 2 +- axlearn/common/optimizers.py | 24 +++++++++++++------- axlearn/common/optimizers_test.py | 6 ++++- axlearn/common/state_builder.py | 2 +- axlearn/common/state_builder_test.py | 4 +++- axlearn/common/t5_test.py | 4 ++-- axlearn/common/test_utils.py | 6 +++-- 14 files changed, 49 insertions(+), 34 deletions(-) diff --git a/axlearn/common/adapter_flax_test.py b/axlearn/common/adapter_flax_test.py index 34c14f276..afc01201e 100644 --- a/axlearn/common/adapter_flax_test.py +++ b/axlearn/common/adapter_flax_test.py @@ -260,12 +260,12 @@ def test_sharding(self): ) layer_params = jit_init_state(jax.random.PRNGKey(1)) - jax.tree_map( + jax.tree_util.tree_map( lambda x: self.assertFalse(x.value.is_fully_replicated), layer_params, is_leaf=lambda x: isinstance(x, nn.Partitioned), ) - jax.tree_map( + jax.tree_util.tree_map( lambda x: jax.debug.visualize_array_sharding(x.value), layer_params, is_leaf=lambda x: isinstance(x, nn.Partitioned), diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 1415c022b..6e7c0c6a6 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -166,8 +166,8 @@ async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo]): The .data field of each shard_info is modified in-place. """ # Note: jax.lax.slice_in_dim in _slice_fn will be cached in jit cache after first call. 
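The bulk of this commit is a mechanical substitution of the long-deprecated jax.tree_map alias. A minimal sketch of the replacement spellings (the pytree is a made-up example):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

    # Deprecated upstream (and eventually removed):
    #   zeros = jax.tree_map(jnp.zeros_like, params)
    # Supported replacements:
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    zeros = jax.tree.map(jnp.zeros_like, params)  # jax.tree namespace on recent JAX

Both calls are equivalent for this purpose; later commits in this series converge on the shorter jax.tree.map spelling.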
- shard_data = jax.tree_map(_slice_fn, shard_infos) - shard_data = jax.tree_map(_transfer_to_host, shard_data) + shard_data = jax.tree_util.tree_map(_slice_fn, shard_infos) + shard_data = jax.tree_util.tree_map(_transfer_to_host, shard_data) await asyncio.sleep(0) # Allow other D2Hs to launch. @@ -200,8 +200,7 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]): class TensorstoreSpecModifier: - def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): - ... + def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): ... async def _async_serialize( @@ -429,6 +428,7 @@ async def _run_serializer(): # Copied from (with modifications) # https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429 + # pylint: disable=R0917 def deserialize( self, shardings: Sequence[Union[jax.sharding.Sharding, layout.Layout]], diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index 8e2cbf6fa..d5ca00657 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -221,13 +221,9 @@ def tokens_to_scores( # with no length normalization and length normalization. # bp_scores[0][2] should be nobp_scores[0][0] / (len('START-AA-ENDPAD') ** alpha) # Here len('START-AA-ENDPAD') is 4 since PAD is ignored. - np.testing.assert_almost_equal( - no_bp_scores[0][0] / (4**alpha), bp_scores[0][2], decimal=5 - ) + np.testing.assert_almost_equal(no_bp_scores[0][0] / (4**alpha), bp_scores[0][2], decimal=5) # no_bp_scores[0][2] and bp_scores[0][0] correspond the log probs of 'START-AAB-END' - np.testing.assert_almost_equal( - no_bp_scores[0][2] / (5**alpha), bp_scores[0][0], decimal=5 - ) + np.testing.assert_almost_equal(no_bp_scores[0][2] / (5**alpha), bp_scores[0][0], decimal=5) def test_add_decoding_dim(self): x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) @@ -827,6 +823,7 @@ def tokens_to_scores(tokens, cache): prefix_merger=[None, _TokenSumPrefixMerger()], brevity_penalty=[None, decoding.brevity_penalty_fn(bp_type="hf", alpha=1.0)], ) + # pylint: disable=R0917 def test_beam_search_prefill( self, prompt_length: Sequence[int], @@ -1554,7 +1551,7 @@ def tokens_to_scores( ) # Compare against expected. - target = jnp.asarray(jax.tree_map(vocab.tokenizer.piece_to_id, expected)) + target = jnp.asarray(jax.tree_util.tree_map(vocab.tokenizer.piece_to_id, expected)) self.assertTrue(jnp.all(sequences == target)) # Check that the token scores are 0 for pad_id tokens. diff --git a/axlearn/common/gradient_accumulation.py b/axlearn/common/gradient_accumulation.py index 0acdf7b43..9a2127ce0 100644 --- a/axlearn/common/gradient_accumulation.py +++ b/axlearn/common/gradient_accumulation.py @@ -195,7 +195,7 @@ def reshape_for_scan(x: Tensor): x = x.reshape(minibatch_size, -1, *x.shape[1:]) return jnp.swapaxes(x, 0, 1) - inputs["input_batch"] = jax.tree_map(reshape_for_scan, inputs["input_batch"]) + inputs["input_batch"] = jax.tree_util.tree_map(reshape_for_scan, inputs["input_batch"]) # Create a sample minibatch for the carry buffer creation below ( @@ -323,7 +323,9 @@ def func_bwd(saved_fwd_state, grad_from_later_in_network) -> tuple[Nested[Tensor """Defines backward pass for the custom vjp based gradient computation.""" grad_from_earlier, num_args = saved_fwd_state # Compute the backward pass gradient value. 
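The func_bwd change above sits inside a jax.custom_vjp pair, where the backward rule scales gradients saved during the forward pass by the cotangent arriving from later in the network. A standalone sketch of that pattern (the quadratic function is a made-up stand-in for the real minibatch loss):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x):
        return jnp.sum(x ** 2)

    def f_fwd(x):
        # Save the already-computed gradient as the residual.
        return f(x), 2.0 * x

    def f_bwd(residual, cotangent):
        # Scale the saved gradient by the cotangent from later in the network.
        return (jax.tree.map(lambda g: g * cotangent, residual),)

    f.defvjp(f_fwd, f_bwd)

    print(jax.grad(f)(jnp.arange(3.0)))  # [0. 2. 4.]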
- grad = jax.tree_map(lambda x: x * grad_from_later_in_network.loss, grad_from_earlier) + grad = jax.tree_util.tree_map( + lambda x: x * grad_from_later_in_network.loss, grad_from_earlier + ) # Return gradient along with None so the output length equals to that of primal input. return (grad,) + (None,) * (num_args - 1) diff --git a/axlearn/common/gradient_accumulation_test.py b/axlearn/common/gradient_accumulation_test.py index b6615004c..df76d2a7e 100644 --- a/axlearn/common/gradient_accumulation_test.py +++ b/axlearn/common/gradient_accumulation_test.py @@ -77,7 +77,7 @@ def check_sharding(path, value): value, callback=lambda sharding: callback(path, sharding) ) - jax.tree_map(check_sharding, tree_paths(input_batch), input_batch) + jax.tree_util.tree_map(check_sharding, tree_paths(input_batch), input_batch) return input_batch callback = lambda path, sharding: self.assertEqual(expected[path], sharding.spec) diff --git a/axlearn/common/input_base.py b/axlearn/common/input_base.py index 91fbfcc7e..163a75dfc 100644 --- a/axlearn/common/input_base.py +++ b/axlearn/common/input_base.py @@ -114,7 +114,7 @@ def maybe_constrain(path: str, value: Tensor): "specify `PartitionSpec.UNCONSTRAINED` explicitly." ) - return jax.tree_map(maybe_constrain, tree_paths(input_batch), input_batch) + return jax.tree_util.tree_map(maybe_constrain, tree_paths(input_batch), input_batch) return fn diff --git a/axlearn/common/input_base_test.py b/axlearn/common/input_base_test.py index 951874e74..21ac18486 100644 --- a/axlearn/common/input_base_test.py +++ b/axlearn/common/input_base_test.py @@ -123,7 +123,7 @@ def check_sharding(path, value): @partial(pjit, in_shardings=None) def fn(input_batch): output_batch = ds.dispatch_global_batch(input_batch) - jax.tree_map(check_sharding, tree_paths(output_batch), output_batch) + jax.tree_util.tree_map(check_sharding, tree_paths(output_batch), output_batch) return output_batch fn.lower(input_batch).compile() diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py index 02bf14a7e..dfa62c129 100644 --- a/axlearn/common/learner_test.py +++ b/axlearn/common/learner_test.py @@ -1168,7 +1168,7 @@ def loss_fn(model_params, inputs): result = jax.tree_util.tree_reduce(lambda x, y: x.sum() + y.sum(), model_params) return ForwardOutputs(loss=result, aux={}, output_collection=output_collection) - grads = jax.tree_map(lambda p: jnp.ones_like(p.value), params) + grads = jax.tree_util.tree_map(lambda p: jnp.ones_like(p.value), params) if method == "update": inputs = [ diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index c811c0ce7..0d2219be2 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -163,9 +163,9 @@ def copy_partition( dtype=spec.dtype, shape=spec.shape, mesh_axes=spec.mesh_axes, - memory_kind=memory_kind - if pattern and re.fullmatch(pattern, path) - else spec.memory_kind, + memory_kind=( + memory_kind if pattern and re.fullmatch(pattern, path) else spec.memory_kind + ), ), tree_paths(specs), specs, @@ -468,7 +468,7 @@ def update_fn( path=context.path() if context else None, ) - updates = jax.tree_map( + updates = jax.tree_util.tree_map( # Apply the scaling to each update. 
lambda g, m: g * m, updates, @@ -704,6 +704,7 @@ def sgd_optimizer( ) +# pylint: disable=R0913 def adamw_optimizer( learning_rate: schedule.Schedule, *, @@ -759,6 +760,7 @@ def adamw_optimizer( return chain(*tx) +# pylint: disable=R0913 def adamw_decoupled_optimizer( learning_rate: float, *, @@ -822,6 +824,7 @@ def adamw_decoupled_optimizer( return chain(*tx) +# pylint: disable=R0913 def adam_optimizer( learning_rate: schedule.Schedule, *, @@ -1021,6 +1024,7 @@ def get_scale_partition(param_spec: ParameterSpec) -> OptStateSpec: return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) +# pylint: disable=R0913 def adafactor_optimizer( learning_rate: schedule.Schedule, *, @@ -1324,6 +1328,7 @@ def _moment( new_square_ema = decay * norm_square_ema + (1 - decay) * (val**2) return new_norm_ema, new_square_ema + # pylint: disable=R0913 def _is_valid_step( g_norm: Tensor, drop_norm: Union[float, DropNormThresholdFn], @@ -1642,7 +1647,7 @@ def update_fn(updates, state, params=None): del params mu = optax.update_moment(updates, state.mu, b2, 1) if mu_dtype is not None: - mu = jax.tree_map(lambda x: x.astype(mu_dtype), mu) + mu = jax.tree_util.tree_map(lambda x: x.astype(mu_dtype), mu) count_inc = optax.safe_int32_increment(state.count) updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu) return updates, ScaleByLionState(count=count_inc, mu=mu) @@ -1664,6 +1669,7 @@ def partition_fn( return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) +# pylint: disable=R0917 def lion_optimizer( learning_rate: schedule.Schedule, b1: float, @@ -2104,9 +2110,11 @@ def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState: # memory usage of all states. Moreover, when the optimizer is run, all activations are # released, so we have less memory pressure at that point in time. 
return jax.tree.map( - lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst)) - if re.fullmatch(pattern, path) - else tensor, + lambda path, tensor: ( + jax.device_put(tensor, TransferToMemoryKind(dst)) + if re.fullmatch(pattern, path) + else tensor + ), tree_paths(state), state, ) diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index 49aa348e9..cb51bd0cf 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -684,7 +684,7 @@ def test_weight_scaling(self, optimizer_cfg, param_scale): ) state = optimizer.init(params) - grads = jax.tree_map(jnp.ones_like, opt_param_values(params)) + grads = jax.tree_util.tree_map(jnp.ones_like, opt_param_values(params)) updates, _ = optimizer.update(grads, state=state, params=params) updated_value = optax.apply_updates(opt_param_values(params), updates) @@ -1281,6 +1281,8 @@ def test_scale_by_schedule(self): update_schedule=(0.1,), weight_decay=(1e-4,), ) + + # pylint: disable=R0917 def test_adastar_vs_adamw_decoupled( self, learning_rate, b1, b2, eps, update_schedule, weight_decay ): @@ -1326,6 +1328,7 @@ def test_adastar_vs_adamw_decoupled( clipping_threshold=(None, 1e-2, 1.0), weight_decay=(1e-4,), ) + # pylint: disable=R0917 def test_adastar_vs_adafactor( self, learning_rate, @@ -1423,6 +1426,7 @@ def compute_loss(param_values): weight_decay=3e-4, ), ) + # pylint: disable=R0917 def test_adastar_summaries( self, learning_rate, diff --git a/axlearn/common/state_builder.py b/axlearn/common/state_builder.py index 2f2dfdf6a..38d7f5060 100644 --- a/axlearn/common/state_builder.py +++ b/axlearn/common/state_builder.py @@ -291,7 +291,7 @@ def target_to_source(self, target: Builder.State) -> tuple[Builder.State, Builde def source_to_target(self, source: Builder.State, aux: Builder.State) -> Builder.State: """Source is newly loaded state, aux is original state.""" - new_trainer_state = jax.tree_map( + new_trainer_state = jax.tree_util.tree_map( self._selector, utils.tree_paths(aux.trainer_state), aux.trainer_state, diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index 88c334ba5..0cf470e35 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -701,6 +701,7 @@ def _run_builder( converted_weight = replicate_to_local_data(converted_weight) return source_weight, converted_weight + # pylint: disable=R0917 def _dummy_model_config( self, patch_size: tuple[int, ...], @@ -721,6 +722,7 @@ def _dummy_model_config( dtype=jnp.float32, ) + # pylint: disable=R0917 def _mock_image_config( self, patch_size: tuple[int, ...], @@ -974,7 +976,7 @@ def _init_state(*args): ref_repeat = torch.nn.Linear(in_features=2, out_features=3, bias=True) ref_repeat_params = torch_to_axlearn(ref_repeat) # Tile the params across repeat dim. 
- ref_repeat_params = jax.tree_map( + ref_repeat_params = jax.tree_util.tree_map( lambda x: jnp.tile(x, [repeat_cfg.num_layers] + [1] * x.ndim), ref_repeat_params ) diff --git a/axlearn/common/t5_test.py b/axlearn/common/t5_test.py index 1fef71222..4986f946c 100644 --- a/axlearn/common/t5_test.py +++ b/axlearn/common/t5_test.py @@ -379,12 +379,12 @@ def _forward(model_parameters, forward_input_batch): return dict( prng_key=new_prng_key, - model=jax.tree_map(lambda x, y: x + y, state["model"], grads), + model=jax.tree_util.tree_map(lambda x, y: x + y, state["model"], grads), ) state_partition_specs = dict( prng_key=None, - model=jax.tree_map( + model=jax.tree_util.tree_map( lambda spec: spec.mesh_axes, layer.create_parameter_specs_recursively(), ), diff --git a/axlearn/common/test_utils.py b/axlearn/common/test_utils.py index e828d0394..48eae4b31 100644 --- a/axlearn/common/test_utils.py +++ b/axlearn/common/test_utils.py @@ -469,14 +469,16 @@ def replace_keys(v, mapping): # Complete the param_init_specs to match the params treedef. # Replace with Nones so that jax doesn't treat them as leaves. - params_with_nones = jax.tree_map( + params_with_nones = jax.tree_util.tree_map( partial(replace_keys, mapping={k: None for k in delegates}), params, is_leaf=is_leaf ) _, treedef = jax.tree_util.tree_flatten(params_with_nones) inits_with_nones = jax.tree_util.tree_unflatten(treedef, param_init_specs) # Replace the Nones with a delegate. - return jax.tree_map(partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf) + return jax.tree_util.tree_map( + partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf + ) def read_param_init_specs_recursively( From deb1da9cad3fda2354e50a7e7df35533ae6357c9 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 3 Apr 2025 19:03:01 +0100 Subject: [PATCH 04/21] Update axlearn/common/array_serialization.py Co-authored-by: Ruoming Pang --- axlearn/common/array_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 6e7c0c6a6..ffaaa577b 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -428,7 +428,7 @@ async def _run_serializer(): # Copied from (with modifications) # https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429 - # pylint: disable=R0917 + # pylint: disable=too-many-arguments def deserialize( self, shardings: Sequence[Union[jax.sharding.Sharding, layout.Layout]], From 47848221e540a9b0cb6cc7ae4bcf71a3130a90f5 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 7 Apr 2025 16:02:02 +0100 Subject: [PATCH 05/21] fix black formatting --- axlearn/common/array_serialization.py | 3 ++- axlearn/common/decoding_test.py | 8 ++++++-- axlearn/common/utils.py | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index ffaaa577b..98ec15a54 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -200,7 +200,8 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]): class TensorstoreSpecModifier: - def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): ... + def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): + ... 
async def _async_serialize( diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index d5ca00657..8d7afde61 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -221,9 +221,13 @@ def tokens_to_scores( # with no length normalization and length normalization. # bp_scores[0][2] should be nobp_scores[0][0] / (len('START-AA-ENDPAD') ** alpha) # Here len('START-AA-ENDPAD') is 4 since PAD is ignored. - np.testing.assert_almost_equal(no_bp_scores[0][0] / (4**alpha), bp_scores[0][2], decimal=5) + np.testing.assert_almost_equal( + no_bp_scores[0][0] / (4**alpha), bp_scores[0][2], decimal=5 + ) # no_bp_scores[0][2] and bp_scores[0][0] correspond the log probs of 'START-AAB-END' - np.testing.assert_almost_equal(no_bp_scores[0][2] / (5**alpha), bp_scores[0][0], decimal=5) + np.testing.assert_almost_equal( + no_bp_scores[0][2] / (5**alpha), bp_scores[0][0], decimal=5 + ) def test_add_decoding_dim(self): x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index b637ef013..49ef4c6ed 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -148,7 +148,8 @@ def sharding(self) -> jax.sharding.Sharding: class RematPolicy(Protocol): - def __call__(self, prim: Primitive, *args: Any, **params: Any) -> Union[RematType, bool]: ... + def __call__(self, prim: Primitive, *args: Any, **params: Any) -> Union[RematType, bool]: + ... def save_and_offload_only_these_names_regex( From e1828ee1a93f9634c5099ae109b2446f63cc95a5 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Apr 2025 09:11:15 +0100 Subject: [PATCH 06/21] fix tree_map --- axlearn/common/adapter_flax_test.py | 4 ++-- axlearn/common/array_serialization.py | 6 +++--- axlearn/common/decoding_test.py | 2 +- axlearn/common/gradient_accumulation.py | 6 ++---- axlearn/common/gradient_accumulation_test.py | 2 +- axlearn/common/input_base.py | 2 +- axlearn/common/input_base_test.py | 2 +- axlearn/common/learner_test.py | 2 +- axlearn/common/optimizers.py | 4 ++-- axlearn/common/optimizers_test.py | 2 +- axlearn/common/state_builder.py | 4 ++-- axlearn/common/state_builder_test.py | 2 +- axlearn/common/t5_test.py | 4 ++-- axlearn/common/test_utils.py | 6 ++---- axlearn/common/utils.py | 2 +- 15 files changed, 23 insertions(+), 27 deletions(-) diff --git a/axlearn/common/adapter_flax_test.py b/axlearn/common/adapter_flax_test.py index afc01201e..ed2d2eec1 100644 --- a/axlearn/common/adapter_flax_test.py +++ b/axlearn/common/adapter_flax_test.py @@ -260,12 +260,12 @@ def test_sharding(self): ) layer_params = jit_init_state(jax.random.PRNGKey(1)) - jax.tree_util.tree_map( + jax.tree.map( lambda x: self.assertFalse(x.value.is_fully_replicated), layer_params, is_leaf=lambda x: isinstance(x, nn.Partitioned), ) - jax.tree_util.tree_map( + jax.tree.map( lambda x: jax.debug.visualize_array_sharding(x.value), layer_params, is_leaf=lambda x: isinstance(x, nn.Partitioned), diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 98ec15a54..1350cc921 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -166,8 +166,8 @@ async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo]): The .data field of each shard_info is modified in-place. """ # Note: jax.lax.slice_in_dim in _slice_fn will be cached in jit cache after first call. 
- shard_data = jax.tree_util.tree_map(_slice_fn, shard_infos) - shard_data = jax.tree_util.tree_map(_transfer_to_host, shard_data) + shard_data = jax.tree.map(_slice_fn, shard_infos) + shard_data = jax.tree.map(_transfer_to_host, shard_data) await asyncio.sleep(0) # Allow other D2Hs to launch. @@ -447,7 +447,7 @@ async def _run_deserializer(): # pylint: disable-next=protected-access byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes) - future_arrays = jax.tree_util.tree_map( + future_arrays = jax.tree.map( functools.partial(serialization.async_deserialize, byte_limiter=byte_limiter), shardings, tensorstore_specs, diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index 8d7afde61..689fce667 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -1555,7 +1555,7 @@ def tokens_to_scores( ) # Compare against expected. - target = jnp.asarray(jax.tree_util.tree_map(vocab.tokenizer.piece_to_id, expected)) + target = jnp.asarray(jax.tree.map(vocab.tokenizer.piece_to_id, expected)) self.assertTrue(jnp.all(sequences == target)) # Check that the token scores are 0 for pad_id tokens. diff --git a/axlearn/common/gradient_accumulation.py b/axlearn/common/gradient_accumulation.py index 9a2127ce0..f70b38b08 100644 --- a/axlearn/common/gradient_accumulation.py +++ b/axlearn/common/gradient_accumulation.py @@ -195,7 +195,7 @@ def reshape_for_scan(x: Tensor): x = x.reshape(minibatch_size, -1, *x.shape[1:]) return jnp.swapaxes(x, 0, 1) - inputs["input_batch"] = jax.tree_util.tree_map(reshape_for_scan, inputs["input_batch"]) + inputs["input_batch"] = jax.tree.map(reshape_for_scan, inputs["input_batch"]) # Create a sample minibatch for the carry buffer creation below ( @@ -323,9 +323,7 @@ def func_bwd(saved_fwd_state, grad_from_later_in_network) -> tuple[Nested[Tensor """Defines backward pass for the custom vjp based gradient computation.""" grad_from_earlier, num_args = saved_fwd_state # Compute the backward pass gradient value. - grad = jax.tree_util.tree_map( - lambda x: x * grad_from_later_in_network.loss, grad_from_earlier - ) + grad = jax.tree.map(lambda x: x * grad_from_later_in_network.loss, grad_from_earlier) # Return gradient along with None so the output length equals to that of primal input. return (grad,) + (None,) * (num_args - 1) diff --git a/axlearn/common/gradient_accumulation_test.py b/axlearn/common/gradient_accumulation_test.py index df76d2a7e..0398c14d2 100644 --- a/axlearn/common/gradient_accumulation_test.py +++ b/axlearn/common/gradient_accumulation_test.py @@ -77,7 +77,7 @@ def check_sharding(path, value): value, callback=lambda sharding: callback(path, sharding) ) - jax.tree_util.tree_map(check_sharding, tree_paths(input_batch), input_batch) + jax.tree.map(check_sharding, tree_paths(input_batch), input_batch) return input_batch callback = lambda path, sharding: self.assertEqual(expected[path], sharding.spec) diff --git a/axlearn/common/input_base.py b/axlearn/common/input_base.py index 163a75dfc..d42fde14f 100644 --- a/axlearn/common/input_base.py +++ b/axlearn/common/input_base.py @@ -114,7 +114,7 @@ def maybe_constrain(path: str, value: Tensor): "specify `PartitionSpec.UNCONSTRAINED` explicitly." 
) - return jax.tree_util.tree_map(maybe_constrain, tree_paths(input_batch), input_batch) + return jax.tree.map(maybe_constrain, tree_paths(input_batch), input_batch) return fn diff --git a/axlearn/common/input_base_test.py b/axlearn/common/input_base_test.py index 21ac18486..370106b4b 100644 --- a/axlearn/common/input_base_test.py +++ b/axlearn/common/input_base_test.py @@ -123,7 +123,7 @@ def check_sharding(path, value): @partial(pjit, in_shardings=None) def fn(input_batch): output_batch = ds.dispatch_global_batch(input_batch) - jax.tree_util.tree_map(check_sharding, tree_paths(output_batch), output_batch) + jax.tree.map(check_sharding, tree_paths(output_batch), output_batch) return output_batch fn.lower(input_batch).compile() diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py index dfa62c129..a017085a2 100644 --- a/axlearn/common/learner_test.py +++ b/axlearn/common/learner_test.py @@ -1168,7 +1168,7 @@ def loss_fn(model_params, inputs): result = jax.tree_util.tree_reduce(lambda x, y: x.sum() + y.sum(), model_params) return ForwardOutputs(loss=result, aux={}, output_collection=output_collection) - grads = jax.tree_util.tree_map(lambda p: jnp.ones_like(p.value), params) + grads = jax.tree.map(lambda p: jnp.ones_like(p.value), params) if method == "update": inputs = [ diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 0d2219be2..ebfc82151 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -468,7 +468,7 @@ def update_fn( path=context.path() if context else None, ) - updates = jax.tree_util.tree_map( + updates = jax.tree.map( # Apply the scaling to each update. lambda g, m: g * m, updates, @@ -1647,7 +1647,7 @@ def update_fn(updates, state, params=None): del params mu = optax.update_moment(updates, state.mu, b2, 1) if mu_dtype is not None: - mu = jax.tree_util.tree_map(lambda x: x.astype(mu_dtype), mu) + mu = jax.tree.map(lambda x: x.astype(mu_dtype), mu) count_inc = optax.safe_int32_increment(state.count) updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu) return updates, ScaleByLionState(count=count_inc, mu=mu) diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index cb51bd0cf..f07ff3bcb 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -684,7 +684,7 @@ def test_weight_scaling(self, optimizer_cfg, param_scale): ) state = optimizer.init(params) - grads = jax.tree_util.tree_map(jnp.ones_like, opt_param_values(params)) + grads = jax.tree.map(jnp.ones_like, opt_param_values(params)) updates, _ = optimizer.update(grads, state=state, params=params) updated_value = optax.apply_updates(opt_param_values(params), updates) diff --git a/axlearn/common/state_builder.py b/axlearn/common/state_builder.py index 38d7f5060..f67272b8b 100644 --- a/axlearn/common/state_builder.py +++ b/axlearn/common/state_builder.py @@ -291,7 +291,7 @@ def target_to_source(self, target: Builder.State) -> tuple[Builder.State, Builde def source_to_target(self, source: Builder.State, aux: Builder.State) -> Builder.State: """Source is newly loaded state, aux is original state.""" - new_trainer_state = jax.tree_util.tree_map( + new_trainer_state = jax.tree.map( self._selector, utils.tree_paths(aux.trainer_state), aux.trainer_state, @@ -866,7 +866,7 @@ def _copy_leaf( for target_scope, source_scope in self.scopes.items(): orig_source_model = utils.get_recursively(source.trainer_state.model, source_scope) - source_model = 
jax.tree_util.tree_map_with_path( + source_model = jax.tree.map_with_path( lambda path, leaf, source_scope=source_scope: _copy_leaf( path, leaf, source_scope=source_scope ), diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index 0cf470e35..4914a2ba8 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -976,7 +976,7 @@ def _init_state(*args): ref_repeat = torch.nn.Linear(in_features=2, out_features=3, bias=True) ref_repeat_params = torch_to_axlearn(ref_repeat) # Tile the params across repeat dim. - ref_repeat_params = jax.tree_util.tree_map( + ref_repeat_params = jax.tree.map( lambda x: jnp.tile(x, [repeat_cfg.num_layers] + [1] * x.ndim), ref_repeat_params ) diff --git a/axlearn/common/t5_test.py b/axlearn/common/t5_test.py index 4986f946c..1f2a4b7d1 100644 --- a/axlearn/common/t5_test.py +++ b/axlearn/common/t5_test.py @@ -379,12 +379,12 @@ def _forward(model_parameters, forward_input_batch): return dict( prng_key=new_prng_key, - model=jax.tree_util.tree_map(lambda x, y: x + y, state["model"], grads), + model=jax.tree.map(lambda x, y: x + y, state["model"], grads), ) state_partition_specs = dict( prng_key=None, - model=jax.tree_util.tree_map( + model=jax.tree.map( lambda spec: spec.mesh_axes, layer.create_parameter_specs_recursively(), ), diff --git a/axlearn/common/test_utils.py b/axlearn/common/test_utils.py index 48eae4b31..a7768c124 100644 --- a/axlearn/common/test_utils.py +++ b/axlearn/common/test_utils.py @@ -469,16 +469,14 @@ def replace_keys(v, mapping): # Complete the param_init_specs to match the params treedef. # Replace with Nones so that jax doesn't treat them as leaves. - params_with_nones = jax.tree_util.tree_map( + params_with_nones = jax.tree.map( partial(replace_keys, mapping={k: None for k in delegates}), params, is_leaf=is_leaf ) _, treedef = jax.tree_util.tree_flatten(params_with_nones) inits_with_nones = jax.tree_util.tree_unflatten(treedef, param_init_specs) # Replace the Nones with a delegate. - return jax.tree_util.tree_map( - partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf - ) + return jax.tree.map(partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf) def read_param_init_specs_recursively( diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 49ef4c6ed..21c54403a 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -405,7 +405,7 @@ def tree_paths( Note that None is not considered a leaf by jax.tree_util, hence also preserved by tree_paths. 
""" - return jax.tree_util.tree_map_with_path( + return jax.tree.map_with_path( lambda kp, _: separator.join(_key_entry_to_str(k) for k in kp), tree, is_leaf=is_leaf ) From 9d83d7cf6f5b43b63810d8ce0c97dac69ef0d168 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 29 Apr 2025 12:20:16 +0100 Subject: [PATCH 07/21] conflicts to fix --- axlearn/common/decoding_test.py | 1 + axlearn/common/optimizers.py | 3 +++ axlearn/common/state_builder_test.py | 1 + 3 files changed, 5 insertions(+) diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index 689fce667..b2af6f883 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -828,6 +828,7 @@ def tokens_to_scores(tokens, cache): brevity_penalty=[None, decoding.brevity_penalty_fn(bp_type="hf", alpha=1.0)], ) # pylint: disable=R0917 + # pylint: disable=R0917 def test_beam_search_prefill( self, prompt_length: Sequence[int], diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index ebfc82151..19d2f63a0 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -166,6 +166,9 @@ def copy_partition( memory_kind=( memory_kind if pattern and re.fullmatch(pattern, path) else spec.memory_kind ), + memory_kind=( + memory_kind if pattern and re.fullmatch(pattern, path) else spec.memory_kind + ), ), tree_paths(specs), specs, diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index 4914a2ba8..c57ae1ee6 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -701,6 +701,7 @@ def _run_builder( converted_weight = replicate_to_local_data(converted_weight) return source_weight, converted_weight + # pylint: disable=R0917 # pylint: disable=R0917 def _dummy_model_config( self, From a089d43f940314d0bb03942bd798623e621140db Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Apr 2025 09:22:43 +0100 Subject: [PATCH 08/21] rebase with main to solve conflicts --- axlearn/common/array_serialization.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 1350cc921..a2140a9ec 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -168,6 +168,8 @@ async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo]): # Note: jax.lax.slice_in_dim in _slice_fn will be cached in jit cache after first call. shard_data = jax.tree.map(_slice_fn, shard_infos) shard_data = jax.tree.map(_transfer_to_host, shard_data) + shard_data = jax.tree.map(_slice_fn, shard_infos) + shard_data = jax.tree.map(_transfer_to_host, shard_data) await asyncio.sleep(0) # Allow other D2Hs to launch. 
@@ -448,7 +450,12 @@ async def _run_deserializer(): byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes) future_arrays = jax.tree.map( - functools.partial(serialization.async_deserialize, byte_limiter=byte_limiter), + functools.partial( + _async_deserialize, + byte_limiter=byte_limiter, + h2d_limiter=h2d_limiter, + single_thread_pool=self._single_thread_pool, + ), shardings, tensorstore_specs, [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, From 13a8451bae3dc72d0d0c33f59d8c60f4c831ea68 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Apr 2025 23:40:35 +0100 Subject: [PATCH 09/21] modify R0917 --- axlearn/common/decoding_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index b2af6f883..36cc09cf8 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -827,8 +827,7 @@ def tokens_to_scores(tokens, cache): prefix_merger=[None, _TokenSumPrefixMerger()], brevity_penalty=[None, decoding.brevity_penalty_fn(bp_type="hf", alpha=1.0)], ) - # pylint: disable=R0917 - # pylint: disable=R0917 + # pylint: disable=too-many-arguments def test_beam_search_prefill( self, prompt_length: Sequence[int], From 8bc6f642d5b6fd907f54b2b60d7deace269d02ac Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Apr 2025 23:42:15 +0100 Subject: [PATCH 10/21] modify R0917 --- axlearn/common/optimizers.py | 2 +- axlearn/common/state_builder_test.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 19d2f63a0..a5f5500ec 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -1672,7 +1672,7 @@ def partition_fn( return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) -# pylint: disable=R0917 +# pylint: disable=too-many-arguments def lion_optimizer( learning_rate: schedule.Schedule, b1: float, diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index c57ae1ee6..d742d7b70 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -701,8 +701,7 @@ def _run_builder( converted_weight = replicate_to_local_data(converted_weight) return source_weight, converted_weight - # pylint: disable=R0917 - # pylint: disable=R0917 + # pylint: disable=too-many-arguments def _dummy_model_config( self, patch_size: tuple[int, ...], @@ -723,7 +722,7 @@ def _dummy_model_config( dtype=jnp.float32, ) - # pylint: disable=R0917 + # pylint: disable=too-many-arguments def _mock_image_config( self, patch_size: tuple[int, ...], From 16e4c04eb85a0021edbfa5874c06023535874546 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Apr 2025 10:22:51 +0100 Subject: [PATCH 11/21] Fix comments --- axlearn/common/array_serialization.py | 1 - axlearn/common/decoding_test.py | 1 - axlearn/common/optimizers.py | 6 ------ axlearn/common/optimizers_test.py | 4 ---- axlearn/common/state_builder_test.py | 2 -- 5 files changed, 14 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index a2140a9ec..dea44abab 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -431,7 +431,6 @@ async def _run_serializer(): # Copied from (with modifications) # https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429 - # pylint: 
disable=too-many-arguments def deserialize( self, shardings: Sequence[Union[jax.sharding.Sharding, layout.Layout]], diff --git a/axlearn/common/decoding_test.py b/axlearn/common/decoding_test.py index 36cc09cf8..bf4a34942 100644 --- a/axlearn/common/decoding_test.py +++ b/axlearn/common/decoding_test.py @@ -827,7 +827,6 @@ def tokens_to_scores(tokens, cache): prefix_merger=[None, _TokenSumPrefixMerger()], brevity_penalty=[None, decoding.brevity_penalty_fn(bp_type="hf", alpha=1.0)], ) - # pylint: disable=too-many-arguments def test_beam_search_prefill( self, prompt_length: Sequence[int], diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index a5f5500ec..a618f46f5 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -707,7 +707,6 @@ def sgd_optimizer( ) -# pylint: disable=R0913 def adamw_optimizer( learning_rate: schedule.Schedule, *, @@ -763,7 +762,6 @@ def adamw_optimizer( return chain(*tx) -# pylint: disable=R0913 def adamw_decoupled_optimizer( learning_rate: float, *, @@ -827,7 +825,6 @@ def adamw_decoupled_optimizer( return chain(*tx) -# pylint: disable=R0913 def adam_optimizer( learning_rate: schedule.Schedule, *, @@ -1027,7 +1024,6 @@ def get_scale_partition(param_spec: ParameterSpec) -> OptStateSpec: return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) -# pylint: disable=R0913 def adafactor_optimizer( learning_rate: schedule.Schedule, *, @@ -1331,7 +1327,6 @@ def _moment( new_square_ema = decay * norm_square_ema + (1 - decay) * (val**2) return new_norm_ema, new_square_ema - # pylint: disable=R0913 def _is_valid_step( g_norm: Tensor, drop_norm: Union[float, DropNormThresholdFn], @@ -1672,7 +1667,6 @@ def partition_fn( return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn) -# pylint: disable=too-many-arguments def lion_optimizer( learning_rate: schedule.Schedule, b1: float, diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index f07ff3bcb..b0bb66504 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -1281,8 +1281,6 @@ def test_scale_by_schedule(self): update_schedule=(0.1,), weight_decay=(1e-4,), ) - - # pylint: disable=R0917 def test_adastar_vs_adamw_decoupled( self, learning_rate, b1, b2, eps, update_schedule, weight_decay ): @@ -1328,7 +1326,6 @@ def test_adastar_vs_adamw_decoupled( clipping_threshold=(None, 1e-2, 1.0), weight_decay=(1e-4,), ) - # pylint: disable=R0917 def test_adastar_vs_adafactor( self, learning_rate, @@ -1426,7 +1423,6 @@ def compute_loss(param_values): weight_decay=3e-4, ), ) - # pylint: disable=R0917 def test_adastar_summaries( self, learning_rate, diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index d742d7b70..45935468f 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -701,7 +701,6 @@ def _run_builder( converted_weight = replicate_to_local_data(converted_weight) return source_weight, converted_weight - # pylint: disable=too-many-arguments def _dummy_model_config( self, patch_size: tuple[int, ...], @@ -722,7 +721,6 @@ def _dummy_model_config( dtype=jnp.float32, ) - # pylint: disable=too-many-arguments def _mock_image_config( self, patch_size: tuple[int, ...], From 9316c5c9aab78430d4fcdaa788c12f533bb5bd35 Mon Sep 17 00:00:00 2001 From: Steboss Date: Wed, 16 Apr 2025 11:05:59 +0100 Subject: [PATCH 12/21] use flatten_one_level_with_keys --- axlearn/common/utils.py | 27 
+++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 79ed9b707..6a9cbd4aa 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -31,10 +31,12 @@ from typing import ( Any, Callable, + List, Literal, NamedTuple, Optional, Protocol, + Tuple, TypeVar, Union, runtime_checkable, @@ -42,14 +44,13 @@ import attr import jax -import jax.flatten_util import numpy as np from absl import logging from jax import numpy as jnp from jax._src.ad_checkpoint import name_p from jax._src.lax import lax as lax_internal from jax._src.mesh import thread_resources -from jax._src.tree_util import KeyEntry, KeyPath +from jax._src.tree_util import KeyEntry, KeyPath, flatten_one_level_with_keys from jax.ad_checkpoint import Offloadable, Recompute, Saveable from jax.experimental import mesh_utils, multihost_utils from jax.extend.core import Primitive @@ -1853,7 +1854,7 @@ def thread_stack_traces() -> Sequence[Sequence[str]]: return grouped_lines -def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: +def pytree_children(node: Any) -> List[Tuple[KeyEntry, Any]]: """Generate the (key, value) pairs for the immediate children of a pytree `node`. The returned children match those returned by @@ -1866,23 +1867,13 @@ def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])] ``` """ - # pylint: disable-next=protected-access - registry_with_keypaths = jax._src.tree_util._registry_with_keypaths - - key_handler = registry_with_keypaths.get(type(node)) - if key_handler: - key_children, _ = key_handler.flatten_with_keys(node) - return key_children - - flat = jax.tree_util.default_registry.flatten_one_level(node) - if flat is None: + try: + # pylint: disable-next=protected-access + key_child_pairs, _ = flatten_one_level_with_keys(node) + return list(key_child_pairs) + except ValueError: return [] - if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node): - # Handle namedtuple as a special case, based on heuristic. - return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields] - return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])] - def find_cycles(tree: Nested) -> dict[str, KeyPath]: """Find a cycle in pytree `tree` if one exists. From 8de926aecdf613280b918ee7c0a52fd389fd9d37 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 28 Apr 2025 16:25:49 +0100 Subject: [PATCH 13/21] the spmd has been removed in commit 7634230cdcd2d3cb42d1093f6ab255f47f9869d5 --- axlearn/common/utils_spmd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 62330d786..1917e99a3 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -43,8 +43,6 @@ def setup( """ # Use a GSPMD-friendly PRNG implementation. jax.config.update("jax_default_prng_impl", "rbg") - # This allows replicated jax.Arrays to be used for computation on the host. 
- jax.config.update("jax_spmd_mode", "allow_all") global _jax_distributed_initialized # pylint: disable=global-statement if not _jax_distributed_initialized: From c3a46d4e4c03b8544ed4f4bdc25bcb5da4d5422b Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 29 Apr 2025 12:25:03 +0100 Subject: [PATCH 14/21] fix the array_serialization --- axlearn/common/array_serialization.py | 177 +++++++++++++++++++++++--- 1 file changed, 158 insertions(+), 19 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index dea44abab..9006eab22 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -16,20 +16,25 @@ https://github.com/google/orbax/blob/3cc343c63c769e4b2df44f3e57f6b5b43569df32/checkpoint/orbax/checkpoint/serialization.py https://github.com/google/jax/blob/595a620804e810335a870e93975a78504b2e95e5/jax/experimental/array_serialization/serialization.py """ - import asyncio import functools +import math +import os import threading import time from collections import defaultdict from concurrent import futures +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Any, Callable, Optional, Sequence, Union import jax +import jax.numpy as jnp import numpy as np +import tensorstore as ts from absl import logging -from jax._src import array, config, layout, typing +from jax._src import array, config, typing +from jax._src.layout import Layout from jax.experimental.array_serialization import serialization from axlearn.common.utils import Tensor @@ -168,8 +173,6 @@ async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo]): # Note: jax.lax.slice_in_dim in _slice_fn will be cached in jit cache after first call. shard_data = jax.tree.map(_slice_fn, shard_infos) shard_data = jax.tree.map(_transfer_to_host, shard_data) - shard_data = jax.tree.map(_slice_fn, shard_infos) - shard_data = jax.tree.map(_transfer_to_host, shard_data) await asyncio.sleep(0) # Allow other D2Hs to launch. @@ -346,6 +349,136 @@ async def _run_serializer( raise e +def _blocking_device_put(out: Tensor, layout: Layout) -> Tensor: + return jax.block_until_ready(jax.device_put(out, layout)) + + +async def _async_deserialize( + user_in_sharding: jax.sharding.Sharding | Layout, + tensorstore_spec: dict[str, Any], + global_shape: Optional[Sequence[int]], + dtype: Optional[typing.DTypeLike], + *, + h2d_limiter: serialization._LimitInFlightBytes, + byte_limiter: serialization._LimitInFlightBytes, + single_thread_pool: ThreadPoolExecutor, +): + """Modified from + https://github.com/jax-ml/jax/blob/e7ec418eba9ada336f755613948cbdf4a9e97d59/jax/experimental/array_serialization/serialization.py#L345 + + Changes: + 1. ts.cast is used rather than np.astype to allow casting on-the-fly. + 2. Avoid allocating a zero array if the global shape is the same as the shape of the tensor + stored in the checkpoint, which should be true for majority of the cases. + 3. Limit in flight padded H2D size to be smaller than premapped buffer size on TPU, so all H2Ds + can fit in the pre-mapped buffer. This is to avoid the significant runtime cost of + allocating large DMA buffers on-demand and to avoid having extra memory footprint for extra + DMA buffers. For tensors whose size exceed the entirety of the premapped buffer, their H2D + will be serialized using a single threaded threadpool. For non TPU backend, no limit on + in flight H2D is imposed. 
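The H2D throttling described in this docstring reuses the same wait_for_bytes / release_bytes protocol as the read-side byte limiter. A simplified asyncio sketch of that protocol, assuming (unlike the real serialization._LimitInFlightBytes, which raises the "Requested more bytes than we reserved" error handled further down) that over-sized requests simply never complete:

    import asyncio

    class BytesLimiter:
        """Caps the total bytes in flight across concurrent awaiters."""

        def __init__(self, max_bytes: int):
            self._available = max_bytes
            self._cv = asyncio.Condition()

        async def wait_for_bytes(self, n: int):
            async with self._cv:
                await self._cv.wait_for(lambda: self._available >= n)
                self._available -= n

        async def release_bytes(self, n: int):
            async with self._cv:
                self._available += n
                self._cv.notify_all()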
+ + Combination of these optimizations speed up the loading of checkpoints as much as 5x if it's + not network-bound. + + ## Background on TPU H2D + + Each H2D consists of the following steps: + + Host buffer -> linearize -> (map DMA buffers) -> PCIe Copy, where linearization is the + conversion from host native layout to TPU native tiled layout. + + If there is sufficient capacity in the premapped DMA buffers, the map DMA step can be skipped, + and we linearize to a section of the pre-mapped DMA buffer directly. If there is sufficient + capacity in the pre-mapped buffer, we can perform several linearization concurrently for + improved performance. However, if there isn't sufficient capacity in the premapped buffer, + on-demand DMA buffer mapping is needed, and this is often very slow. Additionally, concurrently + mapping DMA buffers are neither faster (due to OS overhead) nor memory-efficient. Transparent + huge pages (THP) can help, but it's only for jax 0.5.1+. + """ + in_sharding = ( + user_in_sharding.sharding if isinstance(user_in_sharding, Layout) else user_in_sharding + ) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + "sharding passed to deserialization should be specified, concrete and" + f" an instance of `jax.sharding.Sharding`. Got {in_sharding}" + ) + dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=False, + context=serialization.TS_CONTEXT, + ) + shape = tuple(t.shape if global_shape is None else global_shape) + new_shard_shape = in_sharding.shard_shape(shape) + loop = asyncio.get_running_loop() + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain) + # Limit the bytes read for every shard. + await byte_limiter.wait_for_bytes(requested_bytes) + read_ts = t[restricted_domain] + # Use ts.cast rather than np.astype since ts can perform casting on-the-fly. + if dtype is not None: + read_ts = ts.cast(read_ts, dtype) + if tuple(t.shape) == shape: + # If the restore shape is the same as shape in ckpt, we can avoid the cost of + # allocating a zero array first. + out = np.empty(new_shard_shape, read_ts.dtype.numpy_dtype) + else: + # This maybe needed because the shape the array was saved with is smaller + # than the requested shape of the array in which it will be reloaded. So + # the extra values will be filled with 0s. + out = np.zeros(new_shard_shape, read_ts.dtype.numpy_dtype) + + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write( + read_ts + ) + + # Convert to jnp array so that layouts are initialized properly for + # sub-byte dtypes. + # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to + # make this work. + if out.dtype == jnp.int4: + out = jnp.asarray(out) # type: ignore + + out_size = out.size * out.dtype.itemsize + # Pad to next 256mb. This is a very conservative padding. 
+ mb_256 = 256 * 1024 * 1024 + out_size = math.ceil(out_size / mb_256) * mb_256 + + layout = Layout(dll, jax.sharding.SingleDeviceSharding(device)) + try: + await h2d_limiter.wait_for_bytes(out_size) + result = await loop.run_in_executor(None, _blocking_device_put, out, layout) + await h2d_limiter.release_bytes(out_size) + except ValueError as e: + if "Requested more bytes than we reserved" not in str(e): + raise e # Raise if it's not the type of error we expect. + logging.log_first_n( + logging.WARNING, + "Tensor shard for tensor %s (padded size %d bytes) exceeded " + "premapped buffer size %d. Consider allocating larger premapped buffer using " + "TPU_PREMAPPED_BUFFER_SIZE for improved H2D performance.", + 32, + str(out.shape), + out_size, + # pylint: disable-next=protected-access + h2d_limiter._max_bytes, + ) + result = await loop.run_in_executor( + single_thread_pool, _blocking_device_put, out, layout + ) + + await byte_limiter.release_bytes(requested_bytes) + return result + + return await serialization.create_async_array_from_callback(shape, in_sharding, cb) + + # Reference: # https://github.com/google/orbax/blob/ebb3e6d75f9ccb52bf862f1740943a45b18f4dac/checkpoint/orbax/checkpoint/future.py#L49 class _ThreadRaisingException(threading.Thread): @@ -379,6 +512,14 @@ def result(self, timeout: Optional[int] = None) -> Any: return self._t.join(timeout=timeout) +def _get_premapped_buffer_size(): + if jax.default_backend() == "tpu": + # If TPU_PREMAPPED_BUFFER_SIZE is not set, default is 4GB. + return int(os.getenv("TPU_PREMAPPED_BUFFER_SIZE", "4294967296")) + # On all other backends, use 1TB (effectively unlimited). + return 1099511627776 + + class GlobalAsyncCheckpointManager(serialization.GlobalAsyncCheckpointManager): """Similar to GlobalAsyncCheckpointManager but allows passing additional futures to be awaited while asynchronously serializing tensors. @@ -389,10 +530,16 @@ def __init__(self, *args, **kwargs): self._loop = asyncio.new_event_loop() self._loop_thread = threading.Thread(target=self._loop.run_forever, daemon=True) self._loop_thread.start() + self._single_thread_pool = ThreadPoolExecutor(1) - def __del__(self): + def stop(self): + """Cleans up any internal threads.""" self._loop.call_soon_threadsafe(self._loop.stop) self._loop_thread.join() + self._single_thread_pool.shutdown() + + def __del__(self): + self.stop() return super().__del__() def serialize( @@ -433,7 +580,7 @@ async def _run_serializer(): # https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429 def deserialize( self, - shardings: Sequence[Union[jax.sharding.Sharding, layout.Layout]], + shardings: Sequence[Union[jax.sharding.Sharding, Layout]], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Optional[Sequence[array.Shape]] = None, dtypes: Optional[Sequence[typing.DTypeLike]] = None, @@ -445,8 +592,9 @@ def deserialize( async def _run_deserializer(): # Object should be created once per process. - # pylint: disable-next=protected-access + # pylint: disable=protected-access byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes) + h2d_limiter = serialization._LimitInFlightBytes(_get_premapped_buffer_size()) future_arrays = jax.tree.map( functools.partial( @@ -469,16 +617,14 @@ async def _run_deserializer(): class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager): """Similar to GlobalAsyncCheckpointManager but with few improvements: - 1. 
-       memory usage while also reduces blocking time of the checkpointing process.
-    2. Tensorstore calls now run in a background event loop, hiding the cost of `ts.open` and
+    1. Tensorstore calls now run in a background event loop, hiding the cost of `ts.open` and
        `ts.copy`. Now, only D2H blocks training while serialization is fully asynchronous.
-    3. Added additional sharding along data-parallel axis during save to further reduce host memory
+    2. Added additional sharding along data-parallel axis during save to further reduce host memory
        overhead and improves D2H time. It's achieved by sharding the first dim that's divisible by
        the data-parallel dim. We manipulate shard.index to match the sliced shard, so to tensorstore
        it behaves as if we're sharding along the data-parallel axis. If no such dim is found, we use
        the old way to save-restore, i.e. using the first (0th) replica to do the save only.
-    4. Optionally one can specify max_concurrent_gb to limit in-flight host memory during
+    3. Optionally one can specify max_concurrent_gb to limit in-flight host memory during
        device-to-host transfers and tensorstore writes.
 
     Args:
@@ -602,10 +748,3 @@ def serialize(
         logging.info("D2H during save took %fs. Starting async commit.", time.time() - start_t)
 
         self._start_async_commit(on_commit_callback)
-
-    def stop(self):
-        """Disposes and cleanup any internal resources."""
-
-    def __del__(self):
-        super().__del__()
-        self.stop()

From 0ddb4573520c13c57868dec24c91f3f9fd0780aa Mon Sep 17 00:00:00 2001
From: Steboss
Date: Tue, 29 Apr 2025 12:27:18 +0100
Subject: [PATCH 15/21] fix conflicts for optimizers.py

---
 axlearn/common/optimizers.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index a618f46f5..aec2117fe 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -166,9 +166,6 @@ def copy_partition(
             memory_kind=(
                 memory_kind if pattern and re.fullmatch(pattern, path) else spec.memory_kind
             ),
-            memory_kind=(
-                memory_kind if pattern and re.fullmatch(pattern, path) else spec.memory_kind
-            ),
         ),
         tree_paths(specs),
         specs,

From 835ab56014debbcfa327c71266746484d635e348 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Tue, 29 Apr 2025 17:59:46 +0100
Subject: [PATCH 16/21] Update axlearn/common/utils.py

Co-authored-by: apghml <143655008+apghml@users.noreply.github.com>
---
 axlearn/common/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 6a9cbd4aa..5413a5046 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1854,7 +1854,7 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
     return grouped_lines
 
 
-def pytree_children(node: Any) -> List[Tuple[KeyEntry, Any]]:
+def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
     """Generate the (key, value) pairs for the immediate children of a pytree `node`.
 
     The returned children match those returned by

From 2c56be8a022fc5850871b2626a1efd9ce4a59447 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Tue, 29 Apr 2025 18:02:46 +0100
Subject: [PATCH 17/21] check pylint

---
 axlearn/common/utils.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 5413a5046..bc6436bb9 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -31,12 +31,10 @@
 from typing import (
     Any,
     Callable,
-    List,
     Literal,
     NamedTuple,
     Optional,
     Protocol,
-    Tuple,
     TypeVar,
     Union,
     runtime_checkable,
@@ -1868,7 +1866,6 @@ def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
         ```
     """
     try:
-        # pylint: disable-next=protected-access
         key_child_pairs, _ = flatten_one_level_with_keys(node)
         return list(key_child_pairs)
     except ValueError:

From d2774c39185e557ad938c7805d463dc8303d002a Mon Sep 17 00:00:00 2001
From: Steboss
Date: Tue, 29 Apr 2025 21:56:58 +0100
Subject: [PATCH 18/21] fix docstring

---
 axlearn/common/utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index bc6436bb9..6389156d0 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1858,8 +1858,6 @@ def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
     The returned children match those returned by
     `jax.tree_util.default_registry.flatten_one_level()`.
 
-    Reference: jax._src.tree_util.generate_key_paths()
-
     Example:
         ```
         assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]

From 7e9ca87fe16d1ca938137161bb71e400de0052c6 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Wed, 30 Apr 2025 10:44:21 +0100
Subject: [PATCH 19/21] use public API

---
 axlearn/common/utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 6389156d0..1f4c5614a 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -48,11 +48,12 @@
 from jax._src.ad_checkpoint import name_p
 from jax._src.lax import lax as lax_internal
 from jax._src.mesh import thread_resources
-from jax._src.tree_util import KeyEntry, KeyPath, flatten_one_level_with_keys
+from jax._src.tree_util import KeyEntry, KeyPath
 from jax.ad_checkpoint import Offloadable, Recompute, Saveable
 from jax.experimental import mesh_utils, multihost_utils
 from jax.extend.core import Primitive
 from jax.sharding import PartitionSpec
+from jax.tree_util import default_registry
 
 from axlearn.common import serialization
 from axlearn.common.config import (
@@ -1864,7 +1865,7 @@ def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
         ```
     """
     try:
-        key_child_pairs, _ = flatten_one_level_with_keys(node)
+        key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
         return list(key_child_pairs)
     except ValueError:
         return []

From c860d7fc03fbe428c645088272799fc16a22ec01 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Thu, 1 May 2025 11:06:56 +0100
Subject: [PATCH 20/21] fix tryexcept

---
 axlearn/common/utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1f4c5614a..2cda98018 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1864,11 +1864,9 @@ def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
         assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
         ```
     """
-    try:
-        key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
-        return list(key_child_pairs)
-    except ValueError:
-        return []
+    key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
+    # if node is None key_child_pairs = (), so we're returning an empty list
+    return list(key_child_pairs)
 
 
 def find_cycles(tree: Nested) -> dict[str, KeyPath]:

From 06054321342db493649b0d0f1dad59f54714ca3d Mon Sep 17 00:00:00 2001
From: Steboss
Date: Fri, 2 May 2025 17:22:59 +0100
Subject: [PATCH 21/21] make sure pytree contains all the previous cases + fix tests

---
 axlearn/common/utils.py      | 24 +++++++++++++++++++++---
 axlearn/common/utils_test.py | 19 +++++++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 2cda98018..e323a891c 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1864,9 +1864,27 @@ def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
         assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
         ```
     """
-    key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
-    # if node is None key_child_pairs = (), so we're returning an empty list
-    return list(key_child_pairs)
+    # If node is a NamedTuple, key its children by field name.
+    if isinstance(node, tuple) and hasattr(node, "_fields"):
+        return [(jax.tree_util.GetAttrKey(name), getattr(node, name)) for name in node._fields]
+
+    # If node is not a NamedTuple but exposes a `_fields` attribute.
+    if hasattr(node, "_fields") and not isinstance(node, tuple):
+        return [(jax.tree_util.GetAttrKey(name), getattr(node, name)) for name in node._fields]
+    # Standard JAX pytree nodes with key-path support.
+    try:
+        key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
+        if key_child_pairs:
+            return list(key_child_pairs)
+    except (ValueError, TypeError):
+        pass
+
+    # Fall back to positional keys; flatten_one_level() returns None for leaves.
+    flat = jax.tree_util.default_registry.flatten_one_level(node)
+    if flat is None:
+        return []
+
+    return [(jax.tree_util.FlattenedIndexKey(i), child) for i, child in enumerate(flat[0])]
 
 
 def find_cycles(tree: Nested) -> dict[str, KeyPath]:
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 2411759c3..0b833906f 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -334,6 +334,25 @@ class TestUnstructured:
             [(jax.tree_util.FlattenedIndexKey(k), v) for k, v in enumerate(original_tree.values())],
         )
 
+        # E.g. OutputCollection(summaries={}, state_updates={}, module_outputs={}).
+        class CustomWithFields:
+            _fields = ("a", "b", "c")
+
+            def __init__(self, a, b, c):
+                self.a = a
+                self.b = b
+                self.c = c
+
+        tree = CustomWithFields(**original_tree)
+        self.assertSequenceEqual(
+            pytree_children(tree),
+            [(jax.tree_util.GetAttrKey(k), getattr(tree, k)) for k in CustomWithFields._fields],
+        )
+
+        # A bare object() is a leaf with no children.
+        obj = object()
+        self.assertSequenceEqual(pytree_children(obj), [])
+
         # No children
         self.assertSequenceEqual(pytree_children([]), [])
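
For reference, a minimal usage sketch of `pytree_children` as it stands after PATCH 21/21, mirroring the cases exercised in `axlearn/common/utils_test.py`; the `Point` NamedTuple is a hypothetical type introduced only for illustration, and the leaf behaviour assumes a JAX version whose `jax.tree_util.default_registry.flatten_one_level` returns `None` for leaves, which the implementation above relies on:

```
from typing import NamedTuple

import jax

from axlearn.common.utils import pytree_children


class Point(NamedTuple):  # Hypothetical NamedTuple, used only for illustration.
    x: int
    y: int


# NamedTuple children are keyed by field name via GetAttrKey.
assert pytree_children(Point(1, 2)) == [
    (jax.tree_util.GetAttrKey("x"), 1),
    (jax.tree_util.GetAttrKey("y"), 2),
]

# Registered pytree nodes (e.g. dict) keep their key-path keys.
assert pytree_children(dict(a=[1, 2])) == [(jax.tree_util.DictKey("a"), [1, 2])]

# Leaves and empty containers have no children.
assert pytree_children(object()) == []
assert pytree_children([]) == []
```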