[Jax API update] Remove jax_spmd_mode #1136
base: main
Conversation
Force-pushed from 52c92e4 to 8de926a
@apghml I've rebased the PR with the current
Co-authored-by: apghml <[email protected]>
hey @apghml I0502 06:21:33.000848 140737350427392 utils.py:1878] ((GetAttrKey(name='output_collection'), OutputCollection(summaries={}, state_updates={}, module_outputs={})), (GetAttrKey(name='parent'), None), (GetAttrKey(name='prng_key'), Traced<ShapedArray(uint32[4])>with<DynamicJaxprTrace>), (GetAttrKey(name='state'), {'optimizer': (EmptyState(), (ScaleByAdamState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>, mu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<ShapedArray(float32[131072,3072])>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<ShapedArray(float32[3072])>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<ShapedArray(float32[28,8192,3072])>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<ShapedArray(float32[28,3072,40,128])>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<ShapedArray(float32[28,3072,24,128])>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}, nu={'decoder': {'emb': {'dropout': {}, 'token_emb': {'weight': Traced<ShapedArray(float32[131072,3072])>with<DynamicJaxprTrace>}}, 'output_dropout': {}, 'output_norm': {'scale': Traced<ShapedArray(float32[3072])>with<DynamicJaxprTrace>}, 'transformer': {'repeat': VDict({'layer': {'feed_forward': {'dropout1': {}, 'dropout2': {}, 'linear1_0': {'weight': 
Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear1_1': {'weight': Traced<ShapedArray(float32[28,3072,8192])>with<DynamicJaxprTrace>}, 'linear2': {'weight': Traced<ShapedArray(float32[28,8192,3072])>with<DynamicJaxprTrace>}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}, 'self_attention': {'attention': {'dropout': {}, 'i_proj': {'i_proj': {'qkv_proj': {'weight': Traced<ShapedArray(float32[28,3072,40,128])>with<DynamicJaxprTrace>}}, 'rope_pos_emb_layer': {}}, 'kv_cache': {}, 'o_proj': {'weight': Traced<ShapedArray(float32[28,3072,24,128])>with<DynamicJaxprTrace>}, 'scale_key': {}, 'scale_query': {}}, 'dropout': {}, 'norm': {'scale': Traced<ShapedArray(float32[28,3072])>with<DynamicJaxprTrace>}, 'stochastic_depth': {}}}})}}, 'metrics': {'aux': {}, 'lm': {}}}), ScaleByScheduleState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>), AddDecayedWeightsState(count=None), ScaleByScheduleState(count=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>), EmptyState()))}))
I0502 06:21:33.000930 140737350427392 utils.py:1888] [(GetAttrKey(name='summaries'), {}), (GetAttrKey(name='state_updates'), {}), (GetAttrKey(name='module_outputs'), {})] and this is the output of the new implementation
I tested by running the test suite. I've updated it so that we have:

```python
obj = object()
self.assertSequenceEqual(pytree_children(obj), [])
```

Finally, I checked the performance related to the removal of
```python
tree = CustomWithFields(**original_tree)
self.assertSequenceEqual(
    pytree_children(tree),
```
Hm, this seems to not match the behavior of `jax.tree.leaves()`?

```python
import jax

class CustomWithFields:
    _fields = ("a", "b", "c")

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

print(jax.tree.leaves(CustomWithFields(1, 2, 3)))  # Prints the entire object as a single leaf.
```
```python
key_children, _ = key_handler.flatten_with_keys(node)
return key_children
# If node is a NamedTuple
if isinstance(node, tuple) and hasattr(node, "_fields"):
```
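The heuristic in that snippet can be exercised in isolation (a small sketch; `looks_like_namedtuple` and `Point` are made-up names for illustration):

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])  # example class for illustration

def looks_like_namedtuple(node):
    # Same heuristic as the diff: NamedTuples are tuples with a `_fields` attribute.
    return isinstance(node, tuple) and hasattr(node, "_fields")

print(looks_like_namedtuple(Point(1, 2)))  # True
print(looks_like_namedtuple((1, 2)))       # False
```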
Do we need this? It looks like the output is already correct?

```python
from typing import NamedTuple
import jax

class C(NamedTuple):
    a: int

print(jax.tree_util.default_registry.flatten_one_level_with_keys(C(1)))
```
Also, why did we move this earlier in the function body? IIUC, this will change the behavior in the case where someone creates a NamedTuple subclass and registers a custom tree-flattening handler for it?
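To make that concern concrete, here is a hedged sketch (the `Pair` class is made up for illustration): a NamedTuple subclass with its own registered flattener, which a `_fields`-based check performed before consulting the registry would bypass:

```python
from typing import NamedTuple
import jax

class Pair(NamedTuple):  # hypothetical NamedTuple subclass
    a: int
    b: int

# Register a custom flattener that yields the children in reversed order.
jax.tree_util.register_pytree_node(
    Pair,
    lambda p: ((p.b, p.a), None),                          # children: b first
    lambda aux, children: Pair(children[1], children[0]),  # invert the order back
)

# jax consults the registry first, so the custom order wins:
print(jax.tree.leaves(Pair(1, 2)))
# A _fields-first heuristic would instead produce attribute order:
print([getattr(Pair(1, 2), f) for f in Pair._fields])
```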
```python
# Handle namedtuple as a special case, based on heuristic.
return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields]
return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]
return [(jax.tree_util.FlattenedIndexKey(i), child) for i, child in enumerate(flat[0])]
```
Overall I'm a bit nervous that the various changes in this function may cause us to deviate from what jax does, even if it has the same behavior in existing axlearn code. Was there a reason we needed to change this function?
Alternatively, would it work to do something like:

```python
jax.tree_util.tree_map_with_path(lambda *args: args, node, is_leaf=lambda x: x is not node)
```

and then pull the (key, value) pairs out of the result of that?
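A sketch of that alternative (a hypothetical helper, not the current implementation), pulling the (key, value) pairs out of a one-level `tree_map_with_path`:

```python
import jax

def pytree_children_alt(node):
    """Collects one-level (key, child) pairs by treating every child of
    `node` (but not `node` itself) as a leaf."""
    pairs = []

    def record(path, child):
        if path:  # an empty path means `node` itself is a leaf: no children
            pairs.append((path[0], child))

    jax.tree_util.tree_map_with_path(record, node, is_leaf=lambda x: x is not node)
    return pairs

print(pytree_children_alt({"a": 1, "b": (2, 3)}))
```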
```python
key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
if key_child_pairs:
    return list(key_child_pairs)
except (ValueError, TypeError):
```
Could you explain why we still need to catch these errors?
@apghml thanks for your comments.
```diff
@@ -1853,35 +1853,38 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
     return grouped_lines


-def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
+def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
```
Does this example from the jax docs fail with the new implementation?

```python
import jax.numpy as jnp
import jax.tree
from jax.tree_util import GetAttrKey

class MyContainer:
    def __init__(self):
        self.x = jnp.zeros(2)
        self.y = jnp.ones(2)

def flatten_with_keys(obj):
    children = [(GetAttrKey('x'), obj.x),
                (GetAttrKey('y'), obj.y)]
    aux_data = ()  # aux_data must contain static, hashable data.
    return children, aux_data

def unflatten(aux_data, children):
    obj = object.__new__(MyContainer)
    obj.x, obj.y = children
    return obj

jax.tree_util.register_pytree_with_keys(MyContainer, flatten_with_keys, unflatten)
pytree_children(MyContainer())
```
And I'll create a new PR for utils, so I'll check whether we really need to restructure the `pytree_children` function in `utils.py` completely, or whether JAX can give us some alternatives.
What do you think?
Sounds good.
This PR is a continuation of #1106, and it refers to issue #1126.
In particular, the major change here is in `axlearn/common/utils_spmd.py`, where the config no longer needs to be updated, as `jax_spmd_mode` has been removed in JAX 0.6.0 (here).
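For illustration only (a hypothetical sketch, not the actual axlearn diff): the kind of version-guarded config update that the removal makes unnecessary:

```python
import jax

# Parse the running JAX version; jax_spmd_mode was removed in JAX 0.6.0.
jax_version = tuple(int(p) for p in jax.__version__.split(".")[:3])
if jax_version < (0, 6, 0):
    # Older JAX required opting in to non-jit cross-host computations.
    jax.config.update("jax_spmd_mode", "allow_all")
```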