Skip to content

Commit 4de1eef

Browse files
committed
Lint code by running `bash ./tools/autoformat.sh`
1 parent 1c448e7 commit 4de1eef

File tree

3 files changed: +19 additions, -6 deletions (lines changed)

3 files changed: +19 additions, -6 deletions (lines changed)

megatron/core/distributed/fsdp/mcore_fsdp_adapter.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -251,7 +251,9 @@ def _init_dist_index(self, pg_collection):
251251

252252
if enable_hsdp:
253253
if expt_dp_group is not None:
254-
expt_mesh = _get_hsdp_tp_mesh(outer_fsdp_group, expt_dp_group, expt_tp_group, ep_size=ep_group.size())
254+
expt_mesh = _get_hsdp_tp_mesh(
255+
outer_fsdp_group, expt_dp_group, expt_tp_group, ep_size=ep_group.size()
256+
)
255257
expt_device_mesh = DeviceMesh.from_group(
256258
[outer_fsdp_group, expt_dp_group, expt_tp_group],
257259
device_type="cuda",
@@ -274,7 +276,7 @@ def _init_dist_index(self, pg_collection):
274276
tp_dim="tp",
275277
hybrid_fsdp_group=hybrid_fsdp_group,
276278
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
277-
expt_device_mesh=expt_device_mesh
279+
expt_device_mesh=expt_device_mesh,
278280
)
279281
else:
280282
if ep_group is not None:
@@ -385,7 +387,10 @@ def _get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group, ep_size=1):
385387
assert (
386388
len(dp_tp_meshes) == 1
387389
), f"[Megatron-FSDP] Current rank {rank} is not unique in the mesh ranks {mesh.tolist()}."
388-
assert len(dp_tp_meshes[0].reshape(-1).tolist()) == outer_fsdp_dp_group.size() * dp_cp_group.size() * tp_group.size(), (
390+
assert (
391+
len(dp_tp_meshes[0].reshape(-1).tolist())
392+
== outer_fsdp_dp_group.size() * dp_cp_group.size() * tp_group.size()
393+
), (
389394
f"[Megatron-FSDP] DP-TP mesh size {len(dp_tp_meshes[0].reshape(-1).tolist())} "
390395
f"does not match expected size {outer_fsdp_dp_group.size() * dp_cp_group.size() * tp_group.size()}."
391396
)

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1869,7 +1869,9 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
18691869
hsdp_buf_dp_group = self.dist_index.get_fsdp_group(
18701870
is_expert_parallel=group.is_expert_param
18711871
)
1872-
main_buf_extra_kwargs["dp_rank"] = self.dist_index.get_logical_hybrid_fsdp_rank(is_expert_parallel=group.is_expert_param)
1872+
main_buf_extra_kwargs["dp_rank"] = self.dist_index.get_logical_hybrid_fsdp_rank(
1873+
is_expert_parallel=group.is_expert_param
1874+
)
18731875
else:
18741876
main_buf_dp_group = self.dist_index.get_fsdp_group(
18751877
is_expert_parallel=group.is_expert_param

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -929,7 +929,11 @@ def get_logical_hybrid_fsdp_rank(self, is_expert_parallel: bool = False):
929929
self.hsdp_outer_dp_shard
930930
), "get_logical_hybrid_fsdp_rank is only valid when full-shard hybrid FSDP is enabled."
931931

932-
_hybrid_fsdp_group_name = "_hybrid_fsdp_group_ranks" if not is_expert_parallel else "_hybrid_fsdp_expt_group_ranks"
932+
_hybrid_fsdp_group_name = (
933+
"_hybrid_fsdp_group_ranks"
934+
if not is_expert_parallel
935+
else "_hybrid_fsdp_expt_group_ranks"
936+
)
933937

934938
if not hasattr(self, _hybrid_fsdp_group_name):
935939
dp_world_size = self.get_dp_group(is_expert_parallel).size()
@@ -944,7 +948,9 @@ def get_logical_hybrid_fsdp_rank(self, is_expert_parallel: bool = False):
944948
setattr(self, _hybrid_fsdp_group_name, mesh.tolist())
945949

946950
# Find the index for the current rank in the hybrid group
947-
return getattr(self, _hybrid_fsdp_group_name).index(self.get_dp_group(is_expert_parallel).rank())
951+
return getattr(self, _hybrid_fsdp_group_name).index(
952+
self.get_dp_group(is_expert_parallel).rank()
953+
)
948954

949955

950956
class GlobalMemoryBuffer:

0 commit comments

Comments
 (0)