
Commit ec37d24

add hsdp + ep support on fully_shard model api
1 parent: 4de1eef


2 files changed: +13 −4 lines


megatron/core/distributed/fsdp/src/README.md

Lines changed: 8 additions & 4 deletions
@@ -144,13 +144,14 @@ device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
 # Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
 device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
 hsdp_group = device_mesh["hsdp"].get_group()
-# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
+
+# Initialize DeviceMesh for expert parallel (EP) modules when using HSDP + EP.
 expert_device_mesh = torch.distributed.device_mesh.init_device_mesh(
     "cuda",
-    mesh_shape=(expt_dp_shard_size, expt_tp_size),
-    mesh_dim_names=("dp_shard", "tp"),
+    mesh_shape=(dp_outer_size, expt_dp_shard_size, expt_tp_size),
+    mesh_dim_names=("dp_outer", "dp_shard_cp", "tp"),
 )
-
+hsdp_expt_group = expert_device_mesh[("dp_outer", "dp_shard_cp")].get_group()
 """
 Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP
 class that schedules the sharding lifecycle of the model parameters and gradients
@@ -174,6 +175,8 @@ model = fully_shard_model(
     tp_dim="tp",
     # Only required when using HSDP. Otherwise, set this to None.
     hybrid_fsdp_group=hsdp_group,
+    # Only required when using HSDP + EP. Otherwise, set this to None.
+    hybrid_fsdp_expt_group=hsdp_expt_group,
     # Only required for FSDP + EP. Otherwise, set this to None.
     expt_device_mesh=expt_device_mesh,
     # FSDP Sharding Strategy: no_shard (0) / optim (1) / optim_grads (2) / optim_grads_params (3)
@@ -260,6 +263,7 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
 - `tp_dim` is the name of the sub-mesh used for tensor parallelism (TP), which is required for `(FSDP, TP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` TP.
   - For more information about tensor parallelism, refer to: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053).
 - `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded mesh coordinates for the weight and gradient buffers. Required for HSDP.
+- `hybrid_fsdp_expt_group` is the equivalent `ProcessGroup` for Expert Parallel (EP) modules. It also contains all ranks from these flattened sub-meshes to establish the same `(DP-Outer, DP-Shard)` sharding coordinates, and is required for combined HSDP + EP.
 - `expt_device_mesh` is another [`torch.distributed.DeviceMesh`](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh) tailored for the expert parallel (EP) modules in `MegatronFSDP`.
   - `dp_shard_dim` is the name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
   - `tp_dim` is the name of the sub-mesh used for expert tensor parallelism (ETP), which is required for `(FSDP, ETP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` ETP.
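Pulling the context and added lines of this README diff together, the HSDP + EP mesh setup reads as the sketch below. The `_flatten` calls, dimension names, and group lookups are taken directly from the snippet above; the `*_size` values, the four-dimensional shape of the dense `device_mesh`, and the assumption of a `torchrun` launch with a matching world size are illustrative placeholders, not part of the commit.

```python
import torch
import torch.distributed.device_mesh  # provides init_device_mesh

# Illustrative parallelism sizes (assumed): each product must equal the world
# size of the torchrun launch, e.g. 2 * 2 * 1 * 2 == 2 * 2 * 2 == 8 GPUs.
dp_outer_size, dp_shard_size, cp_size, tp_size = 2, 2, 1, 2
expt_dp_shard_size, expt_tp_size = 2, 2

# Dense-model mesh (assumed ordering): (dp_outer, dp_shard, cp, tp).
device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    mesh_shape=(dp_outer_size, dp_shard_size, cp_size, tp_size),
    mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
)
# FSDP sharding dimension, plus the flattened HSDP group (only needed for HSDP).
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()

# Expert-parallel mesh with the dp_outer leading dimension added by this commit.
expert_device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    mesh_shape=(dp_outer_size, expt_dp_shard_size, expt_tp_size),
    mesh_dim_names=("dp_outer", "dp_shard_cp", "tp"),
)
hsdp_expt_group = expert_device_mesh[("dp_outer", "dp_shard_cp")].get_group()
```

`hsdp_group` and `hsdp_expt_group` are then handed to `fully_shard_model` as `hybrid_fsdp_group` and the new `hybrid_fsdp_expt_group`, as the second hunk of this file shows.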

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 5 additions & 0 deletions
@@ -77,6 +77,7 @@ def fully_shard_model(
     dp_outer_dim: Optional[str] = None,
     tp_dim: Optional[str] = None,
     hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
+    hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
     expt_device_mesh: Optional[DeviceMesh] = None,
     fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
     zero_dp_strategy: str | int = 3,
@@ -341,6 +342,8 @@ class that schedules the sharding lifecycle of the model parameters and gradient
         tp_dim=tp_dim,
         # Only required for HSDP.
         hybrid_fsdp_group=hybrid_fsdp_group,
+        # Only required for HSDP + EP.
+        hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
         # Access to flattened DP rank assignments for HSDP.
         hsdp_outer_dp_shard=_outer_fsdp_sharding,
         # Only required for Megatron-FSDP + EP.
@@ -509,6 +512,7 @@ def fully_shard(
     dp_outer_dim: Optional[str] = None,
     tp_dim: Optional[str] = None,
     hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
+    hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
     expt_device_mesh: Optional[DeviceMesh] = None,
     fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
     zero_dp_strategy: str | int = 3,
@@ -555,6 +559,7 @@ def fully_shard(
         dp_outer_dim=dp_outer_dim,
         tp_dim=tp_dim,
         hybrid_fsdp_group=hybrid_fsdp_group,
+        hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
         expt_device_mesh=expt_device_mesh,
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=zero_dp_strategy,
