[Jax API update] Remove jax_spmd_mode #1136


Open
Steboss wants to merge 22 commits into main from sbosisio/tree_util
Conversation


@Steboss Steboss commented Apr 28, 2025

This PR is a continuation of #1106 and refers to issue #1126.
In particular, the major change here is in axlearn/common/utils_spmd.py, where the config no longer needs to be updated, as jax_spmd_mode has been removed in JAX 0.6.0 (here).
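For context, a minimal sketch of the kind of config update this change makes unnecessary. This is not the exact axlearn code; the helper name and the "allow_all" value are assumptions for illustration only:

import jax

def _maybe_set_spmd_mode():
    # Hypothetical helper: before JAX 0.6.0 the option could be set via
    # jax.config; on 0.6.0+ it no longer exists, so the update is skipped
    # (and in this PR, removed entirely from utils_spmd.py).
    if hasattr(jax.config, "jax_spmd_mode"):
        jax.config.update("jax_spmd_mode", "allow_all")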

@Steboss Steboss force-pushed the sbosisio/tree_util branch from 52c92e4 to 8de926a on April 29, 2025, 11:22

Steboss commented Apr 29, 2025

@apghml I've rebased the PR onto the current main :)


Steboss commented May 2, 2025

Hey @apghml,
I investigated the pytree_children function so that it matches the previous implementation exactly. For simplicity, I'll copy and paste only part of the initial output here. This is the current implementation:

I0502 06:21:33.000848 140737350427392 utils.py:1878] ((GetAttrKey(name='output_collection'), OutputCollection(summaries={}, state_updates={}, module_outputs={})), (GetAttrKey(name='parent'), None), (GetAttrKey(name='prng_key'), Traced<ShapedArray(uint32[4])>with<DynamicJaxprTrace>), (GetAttrKey(name='state'), {'optimizer': (EmptyState(), (ScaleByAdamState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>, mu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<ShapedArray(float32[131072,3072])>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<ShapedArray(float32[3072])>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<ShapedArray(float32[28,8192,3072])>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<ShapedArray(float32[28,3072,40,128])>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<ShapedArray(float32[28,3072,24,128])>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}, nu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<ShapedArray(float32[131072,3072])>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<ShapedArray(float32[3072])>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<ShapedArray(float32[28,8192,3072])>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<ShapedArray(float32[28,3072,40,128])>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<ShapedArray(float32[28,3072,24,128])>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}), ScaleByScheduleState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>), AddDecayedWeightsState(count=None), ScaleByScheduleState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>), EmptyState()))}))


I0502 06:21:33.000930 140737350427392 utils.py:1888] [(GetAttrKey(name='summaries'), {}), (GetAttrKey(name='state_updates'), {}), (GetAttrKey(name='module_outputs'), {})]

and this is the output of the new implementation

I0502 09:17:33.186500 140737350427392 utils.py:1893] [(GetAttrKey(name='output_collection'), OutputCollection(summaries={}, state_updates={}, module_outputs={})), (GetAttrKey(name='parent'), None), (GetAttrKey(name='prng_key'), Traced<uint32[4]>with<DynamicJaxprTrace>), (GetAttrKey(name='state'), {'optimizer': (EmptyState(), (ScaleByAdamState(count=Traced<int32[]>with<DynamicJaxprTrace>, mu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<float32[131072,3072]>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<float32[3072]>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': Traced<float32[28,3072,8192]>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<float32[28,3072,8192]>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<float32[28,8192,3072]>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<float32[28,3072]>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<float32[28,3072,40,128]>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<float32[28,3072,24,128]>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<float32[28,3072]>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}, nu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<float32[131072,3072]>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<float32[3072]>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': Traced<float32[28,3072,8192]>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<float32[28,3072,8192]>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<float32[28,8192,3072]>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<float32[28,3072]>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<float32[28,3072,40,128]>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<float32[28,3072,24,128]>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<float32[28,3072]>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}), ScaleByScheduleState(count=Traced<int32[]>with<DynamicJaxprTrace>), AddDecayedWeightsState(count=None), ScaleByScheduleState(count=Traced<int32[]>with<DynamicJaxprTrace>), EmptyState()))})]

I0502 09:17:33.186563 140737350427392 utils.py:1876] [(GetAttrKey(name='summaries'), {}), (GetAttrKey(name='state_updates'), {}), (GetAttrKey(name='module_outputs'), {})]

I tested by running fuji-3B-v3-flash on the c4 dataset with ICI FSDP=8, global batch size 16, and sequence length 4096. The only difference in the output is that Traced<ShapedArray(int32[])> is now printed by JAX as Traced<int32[]>, so there is nothing to worry about.

I've updated the test suite so that we have:

  • a custom object with a _fields attribute, like OutputCollection(summaries={}, state_updates={}, module_outputs={}), used as the input node
  • a test with a plain object():

obj = object()
self.assertSequenceEqual(pytree_children(obj), [])

The test on pytree_children works fine.

Finally, I checked the performance impact of removing jax_spmd_mode, and the overall results are essentially the same:

Metrics                 | This PR implementation | Previous AXLearn implementation
Tokens per sec per GPU  | 9288                   | 8904
Seqs per sec per GPU    | 2.26                   | 2.17
Average time step       | 0.88                   | 0.91
TFLOPS per sec per GPU  | 218.80                 | 209.74


tree = CustomWithFields(**original_tree)
self.assertSequenceEqual(
pytree_children(tree),
apghml commented:

Hm, this doesn't seem to match the behavior of jax.tree.leaves()?

import jax

class CustomWithFields:
    _fields = ("a", "b", "c")

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

print(jax.tree.leaves(CustomWithFields(1,2,3)))  # Prints the entire object as a single leaf.

key_children, _ = key_handler.flatten_with_keys(node)
return key_children
# If node is a NamedTuple
if isinstance(node, tuple) and hasattr(node, "_fields"):
apghml commented:

Do we need this? It looks like the output is already correct?

from typing import NamedTuple

import jax

class C(NamedTuple):
    a: int

print(jax.tree_util.default_registry.flatten_one_level_with_keys(C(1)))

apghml commented:

Also, why did we move this earlier in the function body? IIUC, this will change the behavior when someone creates a NamedTuple subclass and registers a custom tree-flattening handler for it?
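For illustration, a hedged sketch of that scenario (Point is a hypothetical class, not axlearn code): a NamedTuple subclass with its own registered flattening handler, which a _fields heuristic that runs before the registry lookup would flatten differently:

from typing import NamedTuple

import jax

class Point(NamedTuple):
    x: int
    y: int

# Register a custom handler that exposes only `x` as a child and keeps `y`
# as static aux data.
jax.tree_util.register_pytree_with_keys(
    Point,
    lambda p: (((jax.tree_util.GetAttrKey("x"), p.x),), p.y),
    lambda aux, children: Point(children[0], aux),
)

# With the custom registration, JAX flattens Point to a single leaf, whereas
# a heuristic that checks _fields first would report both x and y as children.
print(jax.tree_util.tree_leaves(Point(1, 2)))  # [1]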

# Handle namedtuple as a special case, based on heuristic.
return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields]
return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]
return [(jax.tree_util.FlattenedIndexKey(i), child) for i, child in enumerate(flat[0])]
@apghml apghml commented May 2, 2025

Overall I'm a bit nervous that the various changes in this function may cause us to deviate from what jax does, even if it has the same behavior in existing axlearn code. Was there a reason we needed to change this function?

Alternatively, would it work to do something like:

jax.tree_util.tree_map_with_path(lambda *args: args, node, is_leaf=lambda x: x is not node)

and then pull the k,v pairs out of the result of that?
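For reference, a minimal sketch of that idea via tree_flatten_with_path (a close relative of the tree_map_with_path suggestion above); the function name and the leaf guard are illustrative, not the final axlearn implementation:

from typing import Any

import jax

def pytree_children_sketch(node: Any) -> list[tuple[Any, Any]]:
    # Treat everything other than `node` itself as a leaf, so flattening
    # stops after one level and yields (path, child) pairs for the
    # immediate children only.
    pairs, _ = jax.tree_util.tree_flatten_with_path(
        node, is_leaf=lambda x: x is not node
    )
    # A bare leaf comes back as a single pair with an empty path; report it
    # as having no children, matching the object() test above.
    if len(pairs) == 1 and len(pairs[0][0]) == 0:
        return []
    # Each remaining path holds exactly one key entry (GetAttrKey, DictKey,
    # SequenceKey, ...), which is the child's key at this level.
    return [(path[0], child) for path, child in pairs]

print(pytree_children_sketch({"a": 1, "b": {"c": 2}}))
# [(DictKey(key='a'), 1), (DictKey(key='b'), {'c': 2})]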

key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
if key_child_pairs:
return list(key_child_pairs)
except (ValueError, TypeError):
apghml commented:

Could you explain why we still need to catch these errors?


Steboss commented May 2, 2025

@apghml thanks for your comments.
Let's do it this way:

  • I'll create a PR only for the jax_spmd_mode removal
  • and I'll create a new PR for utils, where I'll check whether we really need to completely restructure the pytree_children function in utils.py, or whether JAX can offer some alternatives.

What do you think?

@@ -1853,35 +1853,38 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
return grouped_lines


def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
@apghml apghml commented May 2, 2025

Does this example from the jax docs fail with the new implementation?

import jax.numpy as jnp
import jax.tree
from jax.tree_util import GetAttrKey

# pytree_children is the axlearn.common.utils function under review.
from axlearn.common.utils import pytree_children


class MyContainer:

    def __init__(self):
        self.x = jnp.zeros(2)
        self.y = jnp.ones(2)


def flatten_with_keys(obj):
    children = [(GetAttrKey('x'), obj.x),
                (GetAttrKey('y'), obj.y)]
    aux_data = ()  # aux_data must contain static, hashable data.
    return children, aux_data


def unflatten(aux_data, children):
    obj = object.__new__(MyContainer)
    obj.x, obj.y = children
    return obj


jax.tree_util.register_pytree_with_keys(MyContainer, flatten_with_keys, unflatten)

pytree_children(MyContainer())

@apghml apghml left a comment

and I'll create a new PR for utils, where I'll check whether we really need to completely restructure the pytree_children function in utils.py, or whether JAX can offer some alternatives.
What do you think?

Sounds good.
