DDP refactoring: Extract parameter layout computation into optimizer classmethod#3812

Draft
deepakn94 wants to merge 1 commit into NVIDIA:main from deepakn94:dnarayanan/refactor_param_mapping

Conversation

@deepakn94
Contributor

Move the optimizer-specific parameter layout logic (padding, bucket splitting) out of _ParamAndGradBuffer.__init__ and into DistributedOptimizer.compute_param_layout(). This decouples the buffer from optimizer-specific assumptions, allowing future optimizer implementations to define custom parameter layouts by overriding the classmethod.

Introduces ParamLayout dataclass and _default_param_layout() for the non-distributed case.

…classmethod


Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Deepak Narayanan <dnarayanan@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


Code under review (diff excerpt):

    Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
    """
    if self.ddp_config.use_distributed_optimizer:
        # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
Review comment from deepakn94 (Contributor, Author):

Don't delete this.

Code under review (diff excerpt):

    self.bucket_indices = layout.bucket_indices
    per_bucket_numel_unpadded = layout.per_bucket_numel_unpadded

    def _pad(number_to_be_padded: int, divisor: int) -> int:
Review comment from deepakn94 (Contributor, Author):

Is this method used at all?
