enable Context Parallel #592
base: gh/XilunWu/6/base
[ghstack-poisoned]
ghstack-source-id: b76e0d183826dad8c4c76426fe62abaf9ad43f2f Pull Request resolved: #592
[ghstack-poisoned]
ghstack-source-id: 90f1bde378561c9bd1dee3ac82990f9d91ba59ab Pull Request resolved: #592
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
# Note: removed the 2x relaxing in CP enablement
self.model_args.max_seq_len,
cc @tianyu-l: I want to understand whether this is okay.
For the general use case, we could also extend CP to support a stride-like feature.
Could you please elaborate a bit on why this change was needed by CP?
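One possible reading of the question above, offered as an assumption rather than something confirmed in this thread: if CP shards both the input tokens and the precomputed frequency buffer along the sequence dimension, a buffer of length 2 * max_seq_len would hand each rank positions that do not line up with its token shard. The pure-Python sketch below illustrates that with contiguous sharding; `shard_positions` is a hypothetical helper, not part of the PR or of PyTorch.

```python
def shard_positions(total_len, cp_degree, rank):
    """Positions owned by `rank` under contiguous sharding (illustrative assumption)."""
    assert total_len % cp_degree == 0, "length must divide evenly across CP ranks"
    chunk = total_len // cp_degree
    return list(range(rank * chunk, (rank + 1) * chunk))


max_seq_len, cp_degree = 8, 2

for rank in range(cp_degree):
    # Token positions this rank holds after sharding the input sequence.
    tokens = shard_positions(max_seq_len, cp_degree, rank)
    # Buffer rows this rank would use if the buffer has 2x the sequence length:
    # the local shard is twice as long, and only its first len(tokens) rows are read.
    buf_2x = shard_positions(max_seq_len * 2, cp_degree, rank)[: len(tokens)]
    # Buffer rows this rank holds if the buffer length equals max_seq_len.
    buf_1x = shard_positions(max_seq_len, cp_degree, rank)
    print(rank, tokens, buf_2x, buf_1x)
```

Under these assumptions, rank 0 happens to line up in both cases, while rank 1 only lines up when the buffer length equals max_seq_len.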
if parallel_dims.cp_enabled:
    dp_dim_names = dp_mesh.mesh_dim_names
    assert isinstance(dp_dim_names, Tuple)
    dp_mesh = world_mesh[(*dp_dim_names, "cp")]._flatten()
I remember we want to initialize all the process groups (PGs) at the very beginning. Can we move this to parallel_dims.py and use mesh_dim_name to rename it to dp_shard_cp?
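To make the flattening discussed above concrete, here is a minimal pure-Python sketch of which global ranks would end up in one flattened dp_shard_cp group. It does not call DeviceMesh at all; `flattened_group`, the 2x2x2 mesh shape, the dimension names, and the row-major rank layout are all assumptions chosen for illustration.

```python
from itertools import product


def flattened_group(mesh_shape, dim_names, flatten_dims, coord):
    """Global ranks that share `coord` on every dim except those in `flatten_dims`.

    Assumes ranks are laid out row-major over `mesh_shape` (an illustrative assumption).
    """
    # Row-major strides: the last mesh dimension varies fastest.
    strides = []
    stride = 1
    for size in reversed(mesh_shape):
        strides.insert(0, stride)
        stride *= size

    axes = [dim_names.index(name) for name in flatten_dims]
    ranks = []
    # Enumerate every coordinate combination along the flattened dims.
    for combo in product(*(range(mesh_shape[a]) for a in axes)):
        c = list(coord)
        for axis, value in zip(axes, combo):
            c[axis] = value
        ranks.append(sum(i * s for i, s in zip(c, strides)))
    return ranks


names = ("dp_shard", "cp", "tp")
# Ranks sharing tp index 0 once dp_shard and cp are merged into one dimension.
group = flattened_group((2, 2, 2), names, ("dp_shard", "cp"), (0, 0, 0))
print(group)
```

In this sketch the merged group spans dp_shard * cp ranks, which is the set that the renamed dp_shard_cp dimension would cover for a fixed tp index.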
[ghstack-poisoned]
ghstack-source-id: f9bc24ff92c0abce98dc3a0f847fc874fa77788c Pull Request resolved: #592
[ghstack-poisoned]
ghstack-source-id: 51288d0a142c839291d6035e6dddcc915e5e5a08 Pull Request resolved: #592
[ghstack-poisoned]
ghstack-source-id: 6126585b13e49131e8b2d9e05a5ef1f736a0c4d9 Pull Request resolved: #592
Stack from ghstack (oldest at bottom):