DDP refactoring: Extract parameter layout computation into optimizer classmethod#3812
Draft
deepakn94 wants to merge 1 commit intoNVIDIA:mainfrom
Draft
DDP refactoring: Extract parameter layout computation into optimizer classmethod#3812deepakn94 wants to merge 1 commit intoNVIDIA:mainfrom
deepakn94 wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
…classmethod Move the optimizer-specific parameter layout logic (padding, bucket splitting) out of _ParamAndGradBuffer.__init__ and into DistributedOptimizer.compute_param_layout(). This decouples the buffer from optimizer-specific assumptions, allowing future optimizer implementations to define custom parameter layouts by overriding the classmethod. Introduces ParamLayout dataclass and _default_param_layout() for the non-distributed case. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
deepakn94
commented
Mar 11, 2026
| Pads end index of bucket if using distributed optimizer (to ensure uniform sharding). | ||
| """ | ||
| if self.ddp_config.use_distributed_optimizer: | ||
| # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. |
Contributor
Author
There was a problem hiding this comment.
Don't delete this.
deepakn94
commented
Mar 11, 2026
| self.bucket_indices = layout.bucket_indices | ||
| per_bucket_numel_unpadded = layout.per_bucket_numel_unpadded | ||
|
|
||
| def _pad(number_to_be_padded: int, divisor: int) -> int: |
Contributor
Author
There was a problem hiding this comment.
Is this method used at all?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Move the optimizer-specific parameter layout logic (padding, bucket splitting) out of _ParamAndGradBuffer.init and into DistributedOptimizer.compute_param_layout(). This decouples the buffer from optimizer-specific assumptions, allowing future optimizer implementations to define custom parameter layouts by overriding the classmethod.
Introduces ParamLayout dataclass and _default_param_layout() for the non-distributed case.