
Commit e79dd7a

Revert "[style] minor: remove subclass" (#1441)
1 parent 856458a commit e79dd7a

6 files changed: +97 additions, -96 deletions


slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 2 deletions

@@ -31,7 +31,7 @@
 from ..training_utils.loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function
 from . import checkpoint
 from .lr_scheduler import get_lr_scheduler
-from .parallel import create_fsdp_parallel_state
+from .parallel import FSDPParallelState
 from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor

 logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
         super().init(args, role, with_ref)

         # Setup ParallelState for both CP and non-CP cases
-        self.parallel_state = create_fsdp_parallel_state(args)
+        self.parallel_state = FSDPParallelState(args)

         torch.manual_seed(args.seed)

slime/backends/fsdp_utils/parallel.py

Lines changed: 42 additions & 44 deletions

@@ -12,47 +12,45 @@
 logger = logging.getLogger(__name__)


-def create_fsdp_parallel_state(args: Namespace) -> ParallelState:
-    """Create a ParallelState instance for FSDP configuration."""
-    world_size = dist.get_world_size()
-    rank = dist.get_rank()
-
-    cp_size = args.context_parallel_size
-    dp_rank = rank // cp_size
-    cp_rank = rank % cp_size
-
-    mesh = init_device_mesh("cuda", mesh_shape=(world_size // cp_size, cp_size), mesh_dim_names=("dp", "cp"))
-
-    logger.info(
-        f"[Rank {rank}] Device mesh (2D): world_size={world_size}, "
-        f"cp_size={cp_size}, dp_size={world_size // cp_size}"
-    )
-    logger.info(f"[Rank {rank}] Mesh shape: {mesh.shape}, " f"dp_rank={dp_rank}, cp_rank={cp_rank}")
-
-    # Setup Ring Flash Attention with CP group from mesh (only when cp_size > 1)
-    if cp_size > 1:
-        substitute_hf_flash_attn(mesh.get_group("cp"), heads_k_stride=1)
-        logger.info(f"[Rank {rank}] CP initialized via device mesh")
-    else:
-        logger.info(f"[Rank {rank}] Pure DP mode (cp_size=1)")
-
-    parallel_state = ParallelState(
-        dp_rank=dp_rank,
-        dp_src_rank=dp_rank // world_size,
-        dp_size=world_size // cp_size,
-        cp_rank=cp_rank,
-        cp_size=cp_size,
-        dp_cp_rank=rank,
-        dp_cp_size=world_size,
-        dp_group=mesh.get_group("dp"),
-        dp_cp_group=dist.group.WORLD,
-        dp_cp_group_gloo=get_gloo_group(),
-        cp_group=mesh.get_group("cp"),
-        tp_size=1,
-        tp_rank=0,
-        tp_group=dist.new_group([rank]),
-    )
-
-    parallel_state.dp_mesh = mesh["dp"]
-
-    return parallel_state
+class FSDPParallelState(ParallelState):
+    def __init__(self, args: Namespace):
+        super().__init__()
+
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        self.cp_size = args.context_parallel_size
+        self.dp_size = world_size // self.cp_size
+        self.dp_cp_size = world_size
+
+        self.dp_rank = rank // self.cp_size
+        self.cp_rank = rank % self.cp_size
+        self.dp_cp_rank = rank
+        self.dp_src_rank = self.dp_rank // world_size
+
+        self.tp_size = 1
+        self.tp_rank = 0
+        self.tp_group = dist.new_group([rank])
+
+        self.mesh = init_device_mesh(
+            "cuda", mesh_shape=(world_size // self.cp_size, self.cp_size), mesh_dim_names=("dp", "cp")
+        )
+        self.dp_mesh = self.mesh["dp"]
+
+        self.dp_group = self.mesh.get_group("dp")
+        self.cp_group = self.mesh.get_group("cp")
+        self.dp_cp_group = dist.group.WORLD
+        self.dp_cp_group_gloo = get_gloo_group()
+
+        logger.info(
+            f"[Rank {rank}] Device mesh (2D): world_size={world_size}, "
+            f"cp_size={self.cp_size}, dp_size={world_size // self.cp_size}"
+        )
+        logger.info(f"[Rank {rank}] Mesh shape: {self.mesh.shape}, " f"dp_rank={self.dp_rank}, cp_rank={self.cp_rank}")
+
+        # Setup Ring Flash Attention with CP group from mesh (only when cp_size > 1)
+        if self.cp_size > 1:
+            substitute_hf_flash_attn(self.cp_group, heads_k_stride=1)
+            logger.info(f"[Rank {rank}] CP initialized via device mesh")
+        else:
+            logger.info(f"[Rank {rank}] Pure DP mode (cp_size=1)")

slime/backends/megatron_utils/actor.py

Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@
 from .checkpoint import load_checkpoint
 from .initialize import init, is_megatron_main_rank
 from .model import forward_only, initialize_model_and_optimizer, save, train
-from .parallel import create_megatron_parallel_state
+from .parallel import MegatronParallelState
 from .update_weight.common import named_params_and_buffers
 from .update_weight.update_weight_from_distributed import UpdateWeightFromDistributed
 from .update_weight.update_weight_from_tensor import UpdateWeightFromTensor
@@ -92,7 +92,7 @@ def init(
             args, role
         )

-        self.parallel_state = create_megatron_parallel_state(model=self.model)
+        self.parallel_state = MegatronParallelState(model=self.model)

        if role == "critic":
            if self.args.offload_train:

slime/backends/megatron_utils/model.py

Lines changed: 4 additions & 5 deletions

@@ -28,10 +28,9 @@
 from ..training_utils.data import DataIterator, get_batch
 from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step
 from ..training_utils.loss import loss_function
-from ..training_utils.parallel import ParallelState
 from .checkpoint import load_checkpoint, save_checkpoint
 from .model_provider import get_model_provider_func
-from .parallel import get_packed_seq_params
+from .parallel import MegatronParallelState, get_packed_seq_params

 logger = logging.getLogger(__name__)

@@ -157,7 +156,7 @@ def forward_only(
     model: Sequence[DDP],
     data_iterator: Sequence[DataIterator],
     num_microbatches: Sequence[int],
-    parallel_state: ParallelState,
+    parallel_state: MegatronParallelState,
     store_prefix: str = "",
 ) -> dict[str, list[torch.Tensor]]:
     """Run forward passes only and collect non-loss outputs (e.g., logprobs).
@@ -297,7 +296,7 @@ def train_one_step(
     optimizer: MegatronOptimizer,
     opt_param_scheduler: OptimizerParamScheduler,
     num_microbatches: int,
-    parallel_state: ParallelState,
+    parallel_state: MegatronParallelState,
 ) -> tuple[dict[str, float], float]:
     """Execute a single pipeline-parallel training step.
@@ -482,7 +481,7 @@ def train(
     opt_param_scheduler: OptimizerParamScheduler,
     data_iterator: Sequence[DataIterator],
     num_microbatches: Sequence[int],
-    parallel_state: ParallelState,
+    parallel_state: MegatronParallelState,
 ) -> None:
     """Run training over a rollout consisting of multiple steps.

slime/backends/megatron_utils/parallel.py

Lines changed: 37 additions & 36 deletions

@@ -12,44 +12,45 @@
 logger = logging.getLogger(__name__)


-def create_megatron_parallel_state(
-    model: torch.nn.Module | Sequence[torch.nn.Module] | None = None,
-) -> ParallelState:
-    vpp_size_value = mpu.get_virtual_pipeline_model_parallel_world_size()
-    if vpp_size_value is None:
-        vpp_size = 1
-        microbatch_group_size_per_vp_stage = None
-    elif vpp_size_value > 1:
-        assert model is not None
-        model_to_check = model[0] if isinstance(model, Sequence) else model
-        config = get_model_config(model_to_check)
-        vpp_size = vpp_size_value
-        microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage
-    else:
-        vpp_size = 1
-        microbatch_group_size_per_vp_stage = None
+class MegatronParallelState(ParallelState):
+    """
+    ParallelState for Megatron backend, initialized from mpu module.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module | Sequence[torch.nn.Module] | None = None,
+    ):
+        super().__init__()
+
+        self.dp_rank = mpu.get_data_parallel_rank(with_context_parallel=False)
+        self.cp_rank = mpu.get_context_parallel_rank()
+        self.tp_rank = mpu.get_tensor_model_parallel_rank()
+        self.dp_cp_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
+        self.dp_src_rank = mpu.get_data_parallel_src_rank(with_context_parallel=True)
+
+        self.dp_size = mpu.get_data_parallel_world_size(with_context_parallel=False)
+        self.dp_cp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
+        self.cp_size = mpu.get_context_parallel_world_size()
+        self.tp_size = mpu.get_tensor_model_parallel_world_size()

-    parallel_state = ParallelState(
-        dp_rank=mpu.get_data_parallel_rank(with_context_parallel=False),
-        dp_src_rank=mpu.get_data_parallel_src_rank(with_context_parallel=True),
-        dp_size=mpu.get_data_parallel_world_size(with_context_parallel=False),
-        cp_rank=mpu.get_context_parallel_rank(),
-        cp_size=mpu.get_context_parallel_world_size(),
-        dp_cp_rank=mpu.get_data_parallel_rank(with_context_parallel=True),
-        dp_cp_size=mpu.get_data_parallel_world_size(with_context_parallel=True),
-        dp_group=mpu.get_data_parallel_group(with_context_parallel=False),
-        dp_cp_group=mpu.get_data_parallel_group(with_context_parallel=True),
-        dp_cp_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True),
-        cp_group=mpu.get_context_parallel_group(),
-        tp_size=mpu.get_tensor_model_parallel_world_size(),
-        tp_rank=mpu.get_tensor_model_parallel_rank(),
-        tp_group=mpu.get_tensor_model_parallel_group(),
-        is_pp_last_stage=mpu.is_pipeline_last_stage(),
-        vpp_size=vpp_size,
-        microbatch_group_size_per_vp_stage=microbatch_group_size_per_vp_stage,
-    )
+        self.dp_group = mpu.get_data_parallel_group(with_context_parallel=False)
+        self.dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True)
+        self.dp_cp_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True)
+        self.cp_group = mpu.get_context_parallel_group()
+        self.tp_group = mpu.get_tensor_model_parallel_group()

-    return parallel_state
+        self.is_pp_last_stage = mpu.is_pipeline_last_stage()
+        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
+        if vpp_size is None:
+            self.vpp_size = 1
+            self.microbatch_group_size_per_vp_stage = None
+        elif vpp_size > 1:
+            assert model is not None
+            model_to_check = model[0] if isinstance(model, Sequence) else model
+            config = get_model_config(model_to_check)
+            self.vpp_size = vpp_size
+            self.microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage


 def get_packed_seq_params(batch: dict[str, torch.Tensor], args: Namespace) -> PackedSeqParams:

slime/backends/training_utils/parallel.py

Lines changed: 10 additions & 7 deletions

@@ -4,10 +4,6 @@

 @dataclass
 class ParallelState:
-    """Core parallel state shared across all backends.
-    Required by the general training utils.
-    """
-
     dp_rank: int
     dp_src_rank: int
     dp_size: int
@@ -22,6 +18,13 @@ class ParallelState:
     tp_size: int
     tp_rank: int
     tp_group: dist.ProcessGroup | None
-    is_pp_last_stage: bool = True
-    vpp_size: int | None = 1
-    microbatch_group_size_per_vp_stage: int | None = None
+    dp_mesh: dist.DeviceMesh | None
+    cp_mesh: dist.DeviceMesh | None
+    is_pp_last_stage: bool
+    vpp_size: int | None
+    microbatch_group_size_per_vp_stage: int | None
+
+    def __init__(self):
+        self.vpp_size = 1
+        self.microbatch_group_size_per_vp_stage = None
+        self.is_pp_last_stage = True
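
Background note (not part of the diff): because the class body defines __init__ explicitly, @dataclass does not generate one, so the annotations above act as declarations that each backend subclass (FSDPParallelState, MegatronParallelState) fills in after calling super().__init__(). A minimal self-contained sketch of that pattern, using a hypothetical ToyParallelState for illustration:

from dataclasses import dataclass


@dataclass
class ParallelState:
    # Field annotations declare the shared interface; since the class defines
    # its own __init__ below, @dataclass skips generating one from them.
    dp_rank: int
    dp_size: int
    vpp_size: int | None
    is_pp_last_stage: bool

    def __init__(self):
        # Backend-agnostic defaults only; subclasses assign the rest.
        self.vpp_size = 1
        self.is_pp_last_stage = True


class ToyParallelState(ParallelState):
    """Hypothetical single-process backend, for illustration only."""

    def __init__(self):
        super().__init__()
        self.dp_rank = 0
        self.dp_size = 1


state = ToyParallelState()
print(state.dp_size, state.vpp_size, state.is_pp_last_stage)  # 1 1 True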
